-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
390 lines (309 loc) · 13.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
import os
import shutil
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.multiprocessing import Process
from tqdm import tqdm
from util.args_parser import args_parser
from util.data_process import getCleanData, getMixedData
from util.diffusion_coefficients import get_sigma_schedule, get_time_schedule
from util.utility import broadcast_params, copy_source, q_sample_pairs, sample_from_model, sample_posterior, select_phi
class Diffusion_Coefficients:
    """Forward-process (noising) coefficients derived from the sigma schedule.

    Attributes (tensors placed on ``device``):
        sigmas: first value returned by ``get_sigma_schedule``.
        a_s: second value returned by ``get_sigma_schedule``.
        a_s_cum: cumulative product of the flattened ``a_s``.
        sigmas_cum: ``sqrt(1 - a_s_cum**2)``.
        a_s_prev: copy of ``a_s`` whose last entry is set to 1.
    """

    def __init__(self, args, device):
        self.sigmas, self.a_s, _ = get_sigma_schedule(args, device=device)
        # Compute with torch ops directly. Applying np.cumprod/np.sqrt to a torch
        # tensor only yields a tensor via NumPy's __array_wrap__ round-trip, which
        # is fragile across NumPy/PyTorch versions. reshape(-1) mirrors
        # np.cumprod's default axis=None (flatten) behaviour; a_s is presumably
        # already 1-D (one entry per step) — TODO confirm.
        a_s_cum = torch.cumprod(self.a_s.reshape(-1), dim=0)
        self.a_s_cum = a_s_cum.to(device)
        self.sigmas_cum = torch.sqrt(1 - a_s_cum**2).to(device)
        self.a_s_prev = self.a_s.clone()
        self.a_s_prev[-1] = 1
        self.a_s_prev = self.a_s_prev.to(device)
# %% posterior sampling
class Posterior_Coefficients:
    """Coefficients of the diffusion posterior used when sampling x_{t-1}.

    All tensors are float32. The schedule's index-0 entry is discarded, so every
    attribute is indexed by step starting from the first real beta.

    Attributes:
        betas, alphas, alphas_cumprod, alphas_cumprod_prev: schedule terms.
        posterior_variance: betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod).
        sqrt_alphas_cumprod, sqrt_recip_alphas_cumprod,
        sqrt_recipm1_alphas_cumprod: derived reparameterisation terms.
        posterior_mean_coef1, posterior_mean_coef2: weights of the posterior mean.
        posterior_log_variance_clipped: log of the variance, floored at 1e-20.
    """

    def __init__(self, args, device):
        schedule = get_sigma_schedule(args, device=device)
        # Only the third element (betas) is needed; its first entry is a
        # placeholder that we drop.
        betas = schedule[2].type(torch.float32)[1:]
        alphas = 1 - betas
        cumprod = torch.cumprod(alphas, 0)
        leading_one = torch.tensor([1.0], dtype=torch.float32, device=device)
        cumprod_prev = torch.cat((leading_one, cumprod[:-1]), 0)
        one_minus_cumprod = 1 - cumprod

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev

        variance = betas * (1 - cumprod_prev) / one_minus_cumprod
        self.posterior_variance = variance

        self.sqrt_alphas_cumprod = torch.sqrt(cumprod)
        self.sqrt_recip_alphas_cumprod = torch.rsqrt(cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1 / cumprod - 1)

        self.posterior_mean_coef1 = betas * torch.sqrt(cumprod_prev) / one_minus_cumprod
        self.posterior_mean_coef2 = (1 - cumprod_prev) * torch.sqrt(alphas) / one_minus_cumprod

        # The first posterior variance is 0; clamp before taking the log.
        self.posterior_log_variance_clipped = torch.log(variance.clamp(min=1e-20))
# %%
def _transport_cost(x_0_predict, x_tp1):
    """Return the per-sample squared L2 distance between ``x_0_predict`` and ``x_tp1``.

    ``x_tp1`` is detached, so gradients flow only through ``x_0_predict``.
    """
    diff = x_0_predict - x_tp1.detach()
    return torch.sum(diff.view(x_tp1.size(0), -1) ** 2, dim=1)


def _load_checkpoint(path, device):
    """Load a full training checkpoint (``content.pth``) written by ``train``.

    The checkpoint also pickles the ``args`` object, which PyTorch >= 2.6 rejects
    under its default ``weights_only=True``. The file is produced by this script,
    so full unpickling is requested explicitly. Never point this at an
    untrusted file: unpickling can execute arbitrary code.
    """
    try:
        return torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # PyTorch < 1.13 does not accept the ``weights_only`` argument.
        return torch.load(path, map_location=device)


def train(rank, gpu, args):
    """Run adversarial training for one process of a DDP job.

    Each batch performs one discriminator update followed by one generator
    update. Rank 0 additionally writes sample images and checkpoints under
    ``./saved_info/rduot/<dataset>...``.

    Args:
        rank: global rank of this process; offsets the RNG seed and selects the
            data shard via ``DistributedSampler``.
        gpu: CUDA device index used by this process.
        args: namespace returned by ``args_parser`` (plus ``world_size``).
    """
    from EMA import EMA
    from score_sde.models.discriminator import Discriminator_64, Discriminator_large, Discriminator_small
    from score_sde.models.ncsnpp_generator_adagn import NCSNpp

    torch.manual_seed(args.seed + rank)
    torch.cuda.manual_seed(args.seed + rank)
    torch.cuda.manual_seed_all(args.seed + rank)
    device = torch.device("cuda:{}".format(gpu))
    batch_size = args.batch_size
    nz = args.nz  # latent dimension

    # ---- data ----
    if args.perturb_dataset == "none":
        dataset = getCleanData(args.dataset, image_size=args.image_size)
    else:
        dataset = getMixedData(
            args.dataset,
            args.perturb_dataset,
            percentage=args.perturb_percent,
            image_size=args.image_size,
            shuffle=args.shuffle,
        )
    print("Finish loading dataset")
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=args.world_size, rank=rank)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True,
    )

    # ---- models ----
    netG = NCSNpp(args).to(device)
    if args.dataset in ["cifar10", "mnist", "stackmnist", "stl10", "celeba_64"]:
        print("using small discriminator")
        disc_cls = Discriminator_small
    elif args.dataset in ["clipart", "quickdraw", "sketch"]:
        print("using 64 discriminator")
        disc_cls = Discriminator_64
    else:
        print("using large discriminator")
        disc_cls = Discriminator_large
    # The discriminator sees (x_t or x_{t-1} sample) and x_{t+1} stacked, hence 2x channels.
    netD = disc_cls(
        nc=2 * args.num_channels,
        ngf=args.ngf,
        t_emb_dim=args.t_emb_dim,
        act=nn.LeakyReLU(0.2),
    ).to(device)
    broadcast_params(netG.parameters())
    broadcast_params(netD.parameters())

    optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas=(args.beta1, args.beta2))
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas=(args.beta1, args.beta2))
    if args.use_ema:
        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)
    schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.schedule, eta_min=1e-5)
    schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.schedule, eta_min=1e-5)

    netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
    netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])

    # ---- experiment directory ----
    exp = args.exp
    algo = "rduot"
    parent_dir = f"./saved_info/{algo}/{args.dataset}"
    if args.perturb_percent > 0:
        parent_dir += f"_{int(args.perturb_percent)}p_{args.perturb_dataset}"
    parent_dir += f"/{args.version}"
    exp_path = parent_dir if exp == "none" else os.path.join(parent_dir, exp)
    if rank == 0:
        # Snapshot the source only when the directory is first created
        # (copytree would raise if the destination already exists).
        if not os.path.exists(exp_path):
            os.makedirs(exp_path)
            copy_source(__file__, exp_path)
            shutil.copytree("score_sde/models", os.path.join(exp_path, "score_sde/models"))

    coeff = Diffusion_Coefficients(args, device)
    pos_coeff = Posterior_Coefficients(args, device)
    T = get_time_schedule(args, device)

    # ---- resume ----
    if args.resume:
        checkpoint_file = os.path.join(exp_path, "content.pth")
        checkpoint = _load_checkpoint(checkpoint_file, device)
        init_epoch = checkpoint["epoch"]
        netG.load_state_dict(checkpoint["netG_dict"])
        optimizerG.load_state_dict(checkpoint["optimizerG"])
        schedulerG.load_state_dict(checkpoint["schedulerG"])
        netD.load_state_dict(checkpoint["netD_dict"])
        optimizerD.load_state_dict(checkpoint["optimizerD"])
        schedulerD.load_state_dict(checkpoint["schedulerD"])
        global_step = checkpoint["global_step"]
        print("=> loaded checkpoint (epoch {})".format(checkpoint["epoch"]))
    else:
        global_step, init_epoch = 0, 0

    # Loss functions chosen by name via select_phi (presumably conjugate
    # potentials of the UOT objective — see util.utility.select_phi).
    phi_star1 = select_phi(args.phi1)
    phi_star2 = select_phi(args.phi2)

    for epoch in range(init_epoch, args.num_epoch + 1):
        train_sampler.set_epoch(epoch)
        for iteration, (x, _) in enumerate(tqdm(data_loader)):
            # ---------------- discriminator step ----------------
            for p in netD.parameters():
                p.requires_grad = True
            netD.zero_grad()

            real_data = x.to(device, non_blocking=True)
            t = torch.randint(0, args.num_timesteps, (real_data.size(0),), device=device)
            x_t, x_tp1, _ = q_sample_pairs(coeff, real_data, t)
            x_t.requires_grad = True

            # Real pairs. The graph is retained because the R1 penalty below
            # differentiates D_real again.
            D_real = netD(x_t, t, x_tp1.detach())
            errD_real = phi_star2(-D_real).mean()
            errD_real.backward(retain_graph=True)

            # R1 gradient penalty: every step, or only every `lazy_reg` steps.
            if args.lazy_reg is None or global_step % args.lazy_reg == 0:
                grad_real = torch.autograd.grad(outputs=D_real.sum(), inputs=x_t, create_graph=True)[0]
                grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
                grad_penalty = args.r1_gamma / 2 * grad_penalty
                grad_penalty.backward()

            # Fake pairs from the generator's posterior sample.
            latent_z = torch.randn(batch_size, nz, device=device)
            x_0_predict = netG(x_tp1.detach(), t, latent_z)
            x_pos_sample, _ = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
            output = netD(x_pos_sample, t, x_tp1.detach())
            errD_fake = phi_star1(output - args.tau * _transport_cost(x_0_predict, x_tp1)).mean()
            errD_fake.backward()
            errD = errD_real + errD_fake
            optimizerD.step()

            # ---------------- generator step ----------------
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()

            t = torch.randint(0, args.num_timesteps, (real_data.size(0),), device=device)
            x_t, x_tp1, _ = q_sample_pairs(coeff, real_data, t)
            latent_z = torch.randn(batch_size, nz, device=device)
            x_0_predict = netG(x_tp1.detach(), t, latent_z)
            x_pos_sample, _ = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
            output = netD(x_pos_sample, t, x_tp1.detach())
            errG = (args.tau * _transport_cost(x_0_predict, x_tp1) - output).mean()
            errG.backward()
            optimizerG.step()

            global_step += 1
            if iteration % 100 == 0 and rank == 0:
                print(
                    "epoch {} iteration{}, G Loss: {}, D Loss: {}".format(
                        epoch, iteration, errG.item(), errD.item()
                    )
                )

        if not args.no_lr_decay:
            schedulerG.step()
            schedulerD.step()

        if rank == 0:
            if epoch % 10 == 0:
                torchvision.utils.save_image(
                    x_pos_sample,
                    os.path.join(exp_path, "xpos_epoch_{}.png".format(epoch)),
                    normalize=True,
                )
                # Full reverse-process sampling is expensive; only run it on
                # epochs where its output is actually written to disk.
                x_t_1 = torch.randn_like(real_data)
                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args)
                torchvision.utils.save_image(
                    fake_sample,
                    os.path.join(exp_path, "sample_discrete_epoch_{}.png".format(epoch)),
                    normalize=True,
                )

            if args.save_content and epoch % args.save_content_every == 0:
                print("Saving content.")
                content = {
                    "epoch": epoch + 1,  # resume starts at the next epoch
                    "global_step": global_step,
                    "args": args,
                    "netG_dict": netG.state_dict(),
                    "optimizerG": optimizerG.state_dict(),
                    "schedulerG": schedulerG.state_dict(),
                    "netD_dict": netD.state_dict(),
                    "optimizerD": optimizerD.state_dict(),
                    "schedulerD": schedulerD.state_dict(),
                }
                torch.save(content, os.path.join(exp_path, "content.pth"))

            if epoch % args.save_ckpt_every == 0:
                # With EMA, presumably swap the averaged weights in for the
                # snapshot and the second call swaps them back — verify in EMA.py.
                if args.use_ema:
                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
                torch.save(
                    netG.state_dict(),
                    os.path.join(exp_path, "netG_{}.pth".format(epoch)),
                )
                if args.use_ema:
                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
def init_processes(rank, size, fn, args):
    """Initialize the distributed environment, run ``fn``, then tear it down.

    Args:
        rank: global rank of this process.
        size: total number of processes in the job (world size).
        fn: entry point, called as ``fn(rank, gpu, args)``.
        args: namespace providing ``master_address``, ``master_port`` and
            ``local_rank`` (used as the CUDA device index).
    """
    # os.environ only accepts str values; tolerate numeric address/port.
    os.environ["MASTER_ADDR"] = str(args.master_address)
    os.environ["MASTER_PORT"] = str(args.master_port)
    torch.cuda.set_device(args.local_rank)
    gpu = args.local_rank
    dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=size)
    try:
        fn(rank, gpu, args)
        dist.barrier()
    finally:
        # Release the process group even when training raises.
        cleanup()
def cleanup():
    """Destroy the default ``torch.distributed`` process group."""
    return dist.destroy_process_group()
# %%
if __name__ == "__main__":
    args = args_parser()
    args.world_size = args.num_proc_node * args.num_process_per_node
    size = args.num_process_per_node

    if size > 1:
        # Spawn one training process per local GPU.
        processes = []
        for rank in range(size):
            args.local_rank = rank
            global_rank = rank + args.node_rank * args.num_process_per_node
            args.global_rank = global_rank
            print("Node rank %d, local proc %d, global proc %d" % (args.node_rank, rank, global_rank))
            p = Process(target=init_processes, args=(global_rank, args.world_size, train, args))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
    else:
        # Single process on this node. Its global rank is still offset by the
        # node index and the group spans every node, so multi-node runs with one
        # process per node join a single job. (For one node this is rank 0,
        # world size 1.)
        global_rank = args.node_rank * args.num_process_per_node
        args.global_rank = global_rank
        init_processes(global_rank, args.world_size, train, args)