loader.py
import numpy as np
import os
import json
import random
import worker
import torch
import torch.distributed as dist
from datetime import datetime
from torch.backends import cudnn


def multi_gpu_setup(local_rank, args, gpus_per_node, port_number):
    # initialize the per-process NCCL group and bind this process to its GPU
    cudnn.benchmark, cudnn.deterministic = True, False
    dist.init_process_group(backend='nccl',
                            init_method='tcp://%s:%s' % ('localhost', str(port_number)),
                            rank=local_rank,
                            world_size=gpus_per_node)
    torch.cuda.set_device(local_rank)

def load_worker(local_rank, args, gpus_per_node, port_number):
    # set up multi-GPU processing
    multi_gpu_setup(local_rank, args, gpus_per_node, port_number)

    if args.phase == 'train':
        # record the run configuration alongside the checkpoints
        with open(os.path.join(args.model_name, 'args.txt'), 'w') as f:
            json.dump(args.__dict__, f, indent=2)
        gan_worker = worker.WORKER(args, local_rank, gpus_per_node)
        epoch = 0
        start_time = datetime.now()

        # load the epoch number from epoch.txt if it exists and resume training
        epoch_file_path = os.path.join(args.model_name, 'epoch.txt')
        if os.path.exists(epoch_file_path):
            with open(epoch_file_path, "r") as file:
                epoch = int(file.read().strip()) + 1
            print("restart training from:", epoch)
            gan_worker.load_model()
        dist.barrier(gan_worker.group)

        while epoch <= args.epoch:
            # update the generator with the discriminator frozen
            gan_worker.requires_grad(gan_worker.generator, True)
            gan_worker.requires_grad(gan_worker.discriminator, False)
            g_loss = gan_worker.train_generator(epoch)
            gan_worker.ema_update(epoch)

            # update the discriminator with the generator frozen
            gan_worker.requires_grad(gan_worker.generator, False)
            gan_worker.requires_grad(gan_worker.discriminator, True)
            if epoch >= args.freezeD_start:
                gan_worker.freeze_discriminator(args.freezeD_layer)
            d_loss = gan_worker.train_discriminator(epoch)

            # log losses on rank 0
            if epoch % args.print_interval == 0:
                elapsed = datetime.now() - start_time
                if local_rank == 0:
                    elapsed = str(elapsed).split(".")[0]
                    mode = "w" if epoch == 0 else "a"
                    with open(os.path.join(args.model_name, 'log.txt'), mode) as file:
                        file.write("epoch:{loop}, elapsed:{elapsed}, "
                                   "g_loss:{g_loss:.6f}, d_loss:{d_loss:.6f} \n"
                                   .format(loop=epoch, elapsed=elapsed,
                                           g_loss=g_loss, d_loss=d_loss))
                dist.barrier(gan_worker.group)

            # monitor intermediate samples on rank 0
            if epoch % args.show_interval == 0 and epoch > 0:
                if local_rank == 0:
                    gan_worker.monitor_current_result(num_explore=20, w_psi=args.w_psi,
                                                      epoch=epoch,
                                                      images_per_output=args.geo_noise_dim)
                dist.barrier(gan_worker.group)

            # checkpoint the model and the current epoch on rank 0
            if epoch % args.save_interval == 0 and epoch > 0:
                if local_rank == 0:
                    gan_worker.save_model()
                    with open(epoch_file_path, 'w') as f:
                        f.write(str(epoch))
                dist.barrier(gan_worker.group)
            epoch += 1
    elif args.phase == 'fid_eval':
        # FID evaluation phase
        print(args)
        gan_worker = worker.WORKER(args, local_rank, gpus_per_node)
        gan_worker.load_model()
        dist.barrier(gan_worker.group)
        fid_value = gan_worker.fid_evaluate()
        with open(os.path.join(args.model_name, 'fid.txt'), "w") as file:
            file.write("FID:{fid} \n".format(fid=fid_value))
    elif args.phase == 'fake_image_generation':
        gan_worker = worker.WORKER(args, local_rank, gpus_per_node)
        gan_worker.load_model()
        dist.barrier(gan_worker.group)
        gan_worker.fake_image_generation(num_images=args.num_fakes)

    elif args.phase == 'video_generation':
        gan_worker = worker.WORKER(args, local_rank, gpus_per_node)
        gan_worker.load_model()
        dist.barrier(gan_worker.group)
        ctrl_dim = args.ctrl_dim
        if ctrl_dim == -1:
            # sweep every geometry and appearance noise dimension
            for i in range(args.geo_noise_dim + args.app_noise_dim):
                gan_worker.demo_generation(controlled_dim=i, num_video=args.num_videos)
        else:
            gan_worker.demo_generation(controlled_dim=ctrl_dim, num_video=args.num_videos)
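
# Illustrative launch sketch: load_worker() expects to be spawned once per GPU,
# with the process index passed as its first argument by the spawner. A minimal,
# hedged example using torch.multiprocessing.spawn is sketched below; `args`
# stands for the argument namespace assumed to be built by the project's entry
# script, and 12355 is a hypothetical free TCP port for the NCCL rendezvous.
#
# if __name__ == '__main__':
#     import torch.multiprocessing as mp
#     gpus_per_node = torch.cuda.device_count()
#     port_number = 12355  # hypothetical free port
#     mp.spawn(load_worker,
#              args=(args, gpus_per_node, port_number),
#              nprocs=gpus_per_node,
#              join=True)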