-
Notifications
You must be signed in to change notification settings - Fork 338
/
main_imagenet.py
578 lines (517 loc) · 28.9 KB
/
main_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
import datetime
import os, sys
import time
import warnings
import registry
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
from engine.utils.imagenet_utils import presets, transforms, utils, sampler
import torch
import torch.utils.data
import torchvision
#from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
import torch_pruning as tp
from functools import partial
def get_args_parser(add_help=True):
    """Build the command-line parser for ImageNet training, pruning and fine-tuning.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; set to ``False``
            when this parser is used as a parent of another parser.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
    add = parser.add_argument

    # Data, model and basic optimization settings
    add("--data-path", default="~/Datasets/ImageNet/", type=str, help="dataset path")
    add("--model", default="resnet18", type=str, help="model name")
    add("--pretrained", action="store_true")
    add("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    add(
        "-b", "--batch-size", default=32, type=int,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    add("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    add(
        "-j", "--workers", default=16, type=int, metavar="N",
        help="number of data loading workers (default: 16)",
    )
    add("--opt", default="sgd", type=str, help="optimizer")
    add("--lr", default=0.1, type=float, help="initial learning rate")
    add("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    add(
        "--wd", "--weight-decay", default=1e-4, type=float, metavar="W",
        help="weight decay (default: 1e-4)", dest="weight_decay",
    )
    add(
        "--norm-weight-decay", default=None, type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    add(
        "--bias-weight-decay", default=None, type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    add(
        "--transformer-embedding-decay", default=None, type=float,
        help="weight decay for embedding parameters for vision transformer models "
        "(default: None, same value as --wd)",
    )
    add(
        "--label-smoothing", default=0.0, type=float,
        help="label smoothing (default: 0.0)", dest="label_smoothing",
    )
    add("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    add("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")

    # Learning-rate schedule
    add("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    add("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    add("--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)")
    add("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    add("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    add("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    add("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")

    # Logging, checkpoints and run mode
    add("--print-freq", default=100, type=int, help="print frequency")
    add("--output-dir", default=".", type=str, help="path to save outputs")
    add("--resume", default="", type=str, help="path of checkpoint")
    add("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    add(
        "--cache-dataset", dest="cache_dataset", action="store_true",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
    )
    add("--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true")
    add("--test-only", dest="test_only", help="Only test the model", action="store_true")

    # Data augmentation
    add("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    add("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    add("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    add("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    add("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    add("--world-size", default=1, type=int, help="number of distributed processes")
    add("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Exponential moving average of the weights
    add(
        "--model-ema", action="store_true",
        help="enable tracking Exponential Moving Average of model parameters",
    )
    add(
        "--model-ema-steps", type=int, default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    add(
        "--model-ema-decay", type=float, default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )

    add(
        "--use-deterministic-algorithms", action="store_true",
        help="Forces the use of deterministic algorithms only.",
    )
    add("--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)")
    add(
        "--val-resize-size", default=256, type=int,
        help="the resize size used for validation (default: 256)",
    )
    add(
        "--val-crop-size", default=224, type=int,
        help="the central crop size used for validation (default: 224)",
    )
    add(
        "--train-crop-size", default=224, type=int,
        help="the random crop size used for training (default: 224)",
    )
    add("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    add("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    add(
        "--ra-reps", default=3, type=int,
        help="number of repetitions for Repeated Augmentation (default: 3)",
    )
    add("--weights", default=None, type=str, help="the weights enum name to load")

    # pruning parameters
    add("--prune", action="store_true", help="prune the model before training (fine-tuning)")
    add(
        "--method", type=str, default='l1',
        help="pruning method: random, l1, lamp, slim, group_norm, group_greg or group_sl (default: l1)",
    )
    add("--global-pruning", default=False, action="store_true", help="use global pruning")
    add("--target-flops", type=float, default=2.0, help="GFLOPs of pruned model")
    add("--soft-keeping-ratio", type=float, default=0.0)
    add("--reg", type=float, default=1e-4, help="regularization strength for sparsity learning (default: 1e-4)")
    # "--delta_reg" was the only underscore-style flag; the hyphenated spelling
    # is now primary and the old one is kept as an alias for existing scripts.
    add(
        "--delta-reg", "--delta_reg", type=float, default=1e-4, dest="delta_reg",
        help="regularization increment for the group_greg method (default: 1e-4)",
    )
    add("--max-pruning-ratio", default=1.0, type=float, help="maximum channel pruning ratio")

    # Sparsity-learning stage; every None default falls back to the matching
    # plain training option when the stage is run.
    add(
        "--sl-epochs", type=int, default=None,
        help="epochs of sparsity learning (default: None, same value as --epochs)",
    )
    add(
        "--sl-resume", type=str, default=None,
        help="checkpoint of an already sparsified model; skips sparsity learning",
    )
    add(
        "--sl-lr", default=None, type=float,
        help="learning rate for sparsity learning (default: None, same value as --lr)",
    )
    add(
        "--sl-lr-step-size", default=None, type=int,
        help="lr step size for sparsity learning (default: None, same value as --lr-step-size)",
    )
    add(
        "--sl-lr-warmup-epochs", default=None, type=int,
        help="warmup epochs for sparsity learning (default: None, same value as --lr-warmup-epochs)",
    )
    return parser
def prune_to_target_flops(pruner, model, target_flops, example_inputs):
    """Step ``pruner`` until ``model`` costs at most ``target_flops`` GFLOPs.

    Args:
        pruner: Pruner returned by ``get_pruner``; ``pruner.step()`` is called
            once per iteration.
        model: Model to prune in place. It is switched to eval mode.
        target_flops: Budget in GFLOPs (the op count divided by 1e9).
        example_inputs: Input used by ``tp.utils.count_ops_and_params``.

    Returns:
        The op count of the model after the last pruning step.
    """
    model.eval()
    pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)

    # The original looked at the module-global ``args.model`` to detect ViTs,
    # which only exists when this file is run as a script. Detect the two
    # attributes that are actually touched instead.
    syncs_hidden_dim = hasattr(model, "hidden_dim") and hasattr(model, "conv_proj")

    # NOTE(review): assumes the pruner exposes ``iterative_steps`` and
    # ``current_step`` as torch_pruning's base pruner does; when either is
    # missing the guard below is simply inactive.
    max_steps = getattr(pruner, "iterative_steps", None)

    while pruned_ops / 1e9 > target_flops:
        if max_steps is not None and getattr(pruner, "current_step", 0) >= max_steps:
            # Every scheduled step has been used, so further step() calls would
            # not prune anything and the loop would spin forever.
            print(
                "Pruner exhausted its {} steps at {:.4f} GFLOPs; target {:.4f} GFLOPs not reached.".format(
                    max_steps, pruned_ops / 1e9, target_flops
                )
            )
            break
        pruner.step()
        if syncs_hidden_dim:
            # Keep the cached embedding width in sync with the pruned patch projection.
            model.hidden_dim = model.conv_proj.out_channels
        pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    return pruned_ops
def get_pruner(model, example_inputs, args):
    """Create the torch_pruning pruner selected by ``args.method``.

    Besides returning the pruner, this records two flags on ``args``:
    ``args.sparsity_learning`` (True for ``slim``, ``group_greg`` and
    ``group_sl``) and ``args.data_dependency`` (always False here).

    Args:
        model: Model to prune.
        example_inputs: Input handed to the pruner.
        args: Parsed options; reads ``method``, ``model``, ``global_pruning``,
            ``reg``, ``delta_reg`` and ``max_pruning_ratio``.

    Returns:
        The configured pruner.

    Raises:
        NotImplementedError: If ``args.method`` is not a known method.
    """
    is_vit = "vit" in args.model
    unwrapped_parameters = [model.encoder.pos_embedding, model.class_token] if is_vit else None

    def _entry(pruner_cls, **extra):
        # Bind the shared global-pruning switch plus method-specific options.
        return partial(pruner_cls, global_pruning=args.global_pruning, **extra)

    # method -> factory yielding (importance, pruner constructor, needs sparsity learning).
    # Factories are lazy so only the selected method touches torch_pruning.
    recipes = {
        "random": lambda: (
            tp.importance.RandomImportance(),
            _entry(tp.pruner.MagnitudePruner),
            False,
        ),
        "l1": lambda: (
            tp.importance.MagnitudeImportance(p=1),
            _entry(tp.pruner.MagnitudePruner),
            False,
        ),
        "lamp": lambda: (
            tp.importance.LAMPImportance(p=2),
            _entry(tp.pruner.MagnitudePruner),
            False,
        ),
        "slim": lambda: (
            tp.importance.BNScaleImportance(),
            _entry(tp.pruner.BNScalePruner, reg=args.reg),
            True,
        ),
        "group_norm": lambda: (
            tp.importance.GroupNormImportance(p=2),
            _entry(tp.pruner.GroupNormPruner),
            False,
        ),
        "group_greg": lambda: (
            tp.importance.GroupNormImportance(p=2),
            _entry(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg),
            True,
        ),
        "group_sl": lambda: (
            tp.importance.GroupNormImportance(p=2),
            _entry(tp.pruner.GroupNormPruner, reg=args.reg),
            True,
        ),
    }
    recipe = recipes.get(args.method)
    if recipe is None:
        raise NotImplementedError
    imp, pruner_entry, sparsity_learning = recipe()

    args.data_dependency = False
    args.sparsity_learning = sparsity_learning

    # Never prune the 1000-way classification head.
    ignored_layers = [
        layer
        for layer in model.modules()
        if isinstance(layer, torch.nn.Linear) and layer.out_features == 1000
    ]
    # For ViTs, channel counts are rounded to the number of attention heads.
    round_to = model.encoder.layers[0].num_heads if is_vit else None

    return pruner_entry(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=100,
        pruning_ratio=1.0,
        pruning_ratio_dict={},
        max_pruning_ratio=args.max_pruning_ratio,
        ignored_layers=ignored_layers,
        round_to=round_to,
        unwrapped_parameters=unwrapped_parameters,
    )
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, pruner=None, recover=None):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Model to train (possibly wrapped in DistributedDataParallel).
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable of ``(image, target)`` batches.
        device: Device the batches are moved to.
        epoch: Current epoch index; used for logging and EMA warmup handling.
        args: Parsed options; reads ``print_freq``, ``clip_grad_norm``,
            ``model_ema_steps`` and ``lr_warmup_epochs``.
        model_ema: Optional EMA wrapper updated every ``args.model_ema_steps`` batches.
        scaler: Optional ``GradScaler``; when given, mixed precision is used.
        pruner: Optional pruner whose ``regularize(model)`` is applied to the
            gradients after the backward pass.
        recover: Optional callable invoked with the unwrapped model after the
            backward pass. As in the original, it is only called on the
            non-AMP path.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
    header = f"Epoch: [{epoch}]"

    clip_grads = args.clip_grad_norm is not None

    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            # Regularization and clipping both need true (unscaled) gradients.
            # unscale_ may only be called once per optimizer step, so do it
            # up front for whichever of the two is active.
            if pruner is not None or clip_grads:
                scaler.unscale_(optimizer)
            if pruner is not None:
                pruner.regularize(model)
            if clip_grads:
                # Previously skipped on this path, so --clip-grad-norm had no
                # effect when --amp was set.
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if pruner is not None:
                pruner.regularize(model)
            if recover:
                # Unwrap DDP when present; a plain model has no ``.module``.
                recover(getattr(model, "module", model))
            if clip_grads:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

    # Growing regularization raises its penalty once per epoch.
    if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner):
        pruner.update_reg()
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Run ``model`` over ``data_loader`` without gradients and report accuracy.

    Args:
        model: Model (or EMA wrapper) to evaluate; switched to eval mode.
        criterion: Loss called as ``criterion(output, target)``.
        data_loader: Iterable of ``(image, target)`` batches.
        device: Device the batches are moved to.
        print_freq: Logging interval passed to the metric logger.
        log_suffix: Text appended to the ``Test:`` log header.

    Returns:
        The global average top-1 accuracy after synchronizing across processes.
    """
    model.eval()
    stats = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"

    seen = 0
    with torch.no_grad():
        for image, target in stats.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)
            top1, top5 = utils.accuracy(output, target, topk=(1, 5))

            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            count = image.shape[0]
            stats.update(loss=loss.item())
            stats.meters["acc1"].update(top1.item(), n=count)
            stats.meters["acc5"].update(top5.item(), n=count)
            seen += count

    # gather the stats from all processes
    seen = utils.reduce_across_processes(seen)
    if hasattr(data_loader.dataset, "__len__"):
        expected = len(data_loader.dataset)
        if expected != seen and torch.distributed.get_rank() == 0:
            # See FIXME above
            warnings.warn(
                f"It looks like the dataset has {expected} samples, but {seen} "
                "samples were used for the validation, which might bias the results. "
                "Try adjusting the batch size and / or the world size. "
                "Setting the world size to 1 is always a safe bet."
            )

    stats.synchronize_between_processes()
    top1_avg = stats.acc1.global_avg
    print(f"{header} Acc@1 {top1_avg:.3f} Acc@5 {stats.acc5.global_avg:.3f}")
    return stats.acc1.global_avg
def _get_cache_path(filepath):
    """Return the per-user cache file path for the dataset rooted at ``filepath``.

    The file name is the first 10 hex digits of the SHA-1 of ``filepath`` with
    a ``.pt`` suffix, placed under ``~/.torch/vision/datasets/imagefolder``.
    """
    import hashlib

    digest = hashlib.sha1(filepath.encode()).hexdigest()
    file_name = digest[:10] + ".pt"
    return os.path.expanduser(
        os.path.join("~", ".torch", "vision", "datasets", "imagefolder", file_name)
    )
def load_data(traindir, valdir, args):
    """Build the training and validation datasets and their samplers.

    With ``args.cache_dataset`` set, a dataset is read from its cache file
    when that file exists and written to it after being built otherwise.

    Args:
        traindir: Root folder of the training images.
        valdir: Root folder of the validation images.
        args: Parsed options; reads ``model``, ``cache_dataset``,
            ``distributed`` and, when present, ``auto_augment`` and
            ``random_erase``.

    Returns:
        ``(dataset, dataset_test, train_sampler, test_sampler)``.
    """
    # Data loading code
    print("Loading data...")
    if args.model == 'inception_v3':
        resize_size, crop_size = 342, 299
    else:
        resize_size, crop_size = 256, 224

    def _cached_or_built(root, label, build):
        # Reuse the cached dataset when allowed and present, else build (and maybe cache) it.
        cache_file = _get_cache_path(root)
        if args.cache_dataset and os.path.exists(cache_file):
            # Attention, as the transforms are also cached!
            print("Loading {} from {}".format(label, cache_file))
            loaded, _ = torch.load(cache_file)
            return loaded
        built = build()
        if args.cache_dataset:
            print("Saving {} to {}...".format(label, cache_file))
            utils.mkdir(os.path.dirname(cache_file))
            utils.save_on_master((built, root), cache_file)
        return built

    def _make_train():
        train_transform = presets.ClassificationPresetTrain(
            crop_size=crop_size,
            auto_augment_policy=getattr(args, "auto_augment", None),
            random_erase_prob=getattr(args, "random_erase", 0.0),
        )
        return torchvision.datasets.ImageFolder(traindir, train_transform)

    def _make_val():
        val_transform = presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size)
        return torchvision.datasets.ImageFolder(valdir, val_transform)

    print("Loading training data...")
    started = time.time()
    dataset = _cached_or_built(traindir, "dataset_train", _make_train)
    print("Data loading took", time.time() - started)

    print("Loading validation data...")
    dataset_test = _cached_or_built(valdir, "dataset_test", _make_val)

    print("Creating data loaders...")
    if args.distributed:
        distributed_sampler = torch.utils.data.distributed.DistributedSampler
        train_sampler = distributed_sampler(dataset)
        test_sampler = distributed_sampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    return dataset, dataset_test, train_sampler, test_sampler
def main(args):
    """Train a model, or prune it to a FLOPs budget and fine-tune it.

    Steps: set up the output directory and distributed mode, build the data
    loaders, create the model, optionally run sparsity learning and pruning
    (``--prune``), then train or fine-tune.

    Args:
        args: Namespace produced by ``get_args_parser``. It is mutated:
            ``get_pruner`` adds ``sparsity_learning``/``data_dependency`` and
            unset ``sl_*`` options are filled from their training
            counterparts. ``args.distributed`` is read by ``load_data`` and is
            presumably set by ``utils.init_distributed_mode`` - TODO confirm.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)
    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    # The default --data-path starts with "~", which is only meaningful once
    # expanded to the user's home directory.
    data_root = os.path.expanduser(args.data_path)
    train_dir = os.path.join(data_root, "train")
    val_dir = os.path.join(data_root, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            # Apply mixup/cutmix to the whole collated batch.
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = registry.get_model(num_classes=1000, name=args.model, pretrained=args.pretrained, target_dataset='imagenet')
    model.eval()
    print("="*16)
    print(model)
    example_inputs = torch.randn(1, 3, 224, 224)
    base_ops, base_params = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    print("Params: {:.4f} M".format(base_params / 1e6))
    print("ops: {:.4f} G".format(base_ops / 1e9))
    print("="*16)

    if args.prune:
        pruner = get_pruner(model, example_inputs=example_inputs, args=args)
        if args.sparsity_learning:
            if args.sl_resume:
                print("Loading sparse model from {}...".format(args.sl_resume))
                model.load_state_dict(torch.load(args.sl_resume, map_location='cpu')['model'])
            else:
                print("Sparsifying model...")
                # Unset sparsity-learning options inherit the training values.
                if args.sl_lr is None:
                    args.sl_lr = args.lr
                if args.sl_lr_step_size is None:
                    args.sl_lr_step_size = args.lr_step_size
                if args.sl_lr_warmup_epochs is None:
                    args.sl_lr_warmup_epochs = args.lr_warmup_epochs
                if args.sl_epochs is None:
                    args.sl_epochs = args.epochs
                train(model, args.sl_epochs,
                      lr=args.sl_lr, lr_step_size=args.sl_lr_step_size, lr_warmup_epochs=args.sl_lr_warmup_epochs,
                      train_sampler=train_sampler, data_loader=data_loader, data_loader_test=data_loader_test,
                      device=device, args=args, pruner=pruner, state_dict_only=True)

        # Pruning and op counting run on the CPU copy of the model.
        model = model.to('cpu')
        print("Pruning model...")
        prune_to_target_flops(pruner, model, args.target_flops, example_inputs)
        pruned_ops, pruned_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        print("="*16)
        print("After pruning:")
        print(model)
        print("Params: {:.2f} M => {:.2f} M ({:.2f}%)".format(base_params / 1e6, pruned_size / 1e6, pruned_size / base_params * 100))
        print("Ops: {:.2f} G => {:.2f} G ({:.2f}%, {:.2f}X )".format(base_ops / 1e9, pruned_ops / 1e9, pruned_ops / base_ops * 100, base_ops / pruned_ops))
        print("="*16)

    # The datasets/samplers are deliberately NOT re-created here: the data
    # loaders above are bound to the original samplers, so a fresh
    # train_sampler would receive set_epoch() calls in train() that the
    # loader never sees (and the dataset would be indexed a second time).
    print("Finetuning..." if args.prune else "Training...")
    train(model, args.epochs,
          lr=args.lr, lr_step_size=args.lr_step_size, lr_warmup_epochs=args.lr_warmup_epochs,
          train_sampler=train_sampler, data_loader=data_loader, data_loader_test=data_loader_test,
          device=device, args=args, pruner=None, state_dict_only=(not args.prune))
def _build_optimizer(parameters, lr, weight_decay, args):
    """Create the optimizer named by ``args.opt`` (SGD, RMSprop or AdamW)."""
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        return torch.optim.SGD(
            parameters,
            lr=lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "rmsprop":
        return torch.optim.RMSprop(
            parameters, lr=lr, momentum=args.momentum, weight_decay=weight_decay, eps=0.0316, alpha=0.9
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=lr, weight_decay=weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")


def _build_lr_scheduler(optimizer, epochs, lr_step_size, lr_warmup_epochs, args):
    """Create the main LR scheduler, preceded by a warmup scheduler when requested."""
    # The lower-cased name is written back to args, as the original code did.
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs - lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if lr_warmup_epochs <= 0:
        return main_lr_scheduler

    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=lr_warmup_epochs
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=lr_warmup_epochs
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[lr_warmup_epochs]
    )


def train(
    model,
    epochs,
    lr, lr_step_size, lr_warmup_epochs,
    train_sampler, data_loader, data_loader_test,
    device, args, pruner=None, state_dict_only=True, recover=None):
    """Train and evaluate ``model`` for ``epochs`` epochs, checkpointing each epoch.

    Args:
        model: Model to train; moved to ``device``.
        epochs: Total number of epochs (training starts at ``args.start_epoch``).
        lr: Initial learning rate.
        lr_step_size: Step size for the ``steplr`` scheduler.
        lr_warmup_epochs: Number of warmup epochs (0 disables warmup).
        train_sampler: Sampler of ``data_loader``; ``set_epoch`` is called on
            it each epoch in distributed mode.
        data_loader: Training data loader.
        data_loader_test: Validation data loader.
        device: Device to train on.
        args: Parsed options. ``args.lr_scheduler`` is lower-cased in place and
            ``args.start_epoch`` is overwritten when resuming.
        pruner: Optional pruner used for sparsity learning. When given, all
            weight decay is set to 0 and checkpoints get a ``regularized_``
            prefix.
        state_dict_only: Save ``state_dict()`` in checkpoints when True, the
            whole module object otherwise.
        recover: Optional callable forwarded to ``train_one_epoch``.

    Returns:
        The trained model without the DDP wrapper, or ``None`` when
        ``args.test_only`` is set.

    Raises:
        RuntimeError: For an unsupported optimizer, scheduler or warmup method.
    """
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.label_smoothing > 0:
        criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    # With a pruner, all weight decay is switched off.
    weight_decay = args.weight_decay if pruner is None else 0
    bias_weight_decay = args.bias_weight_decay if pruner is None else 0
    norm_weight_decay = args.norm_weight_decay if pruner is None else 0

    custom_keys_weight_decay = []
    if bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        weight_decay,
        norm_weight_decay=norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    optimizer = _build_optimizer(parameters, lr, weight_decay, args)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None
    lr_scheduler = _build_lr_scheduler(optimizer, epochs, lr_step_size, lr_warmup_epochs, args)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        saved_model = checkpoint["model"]
        # Checkpoints written with state_dict_only=False hold the whole module
        # object, which load_state_dict cannot consume directly.
        if isinstance(saved_model, nn.Module):
            saved_model = saved_model.state_dict()
        model_without_ddp.load_state_dict(saved_model)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    start_time = time.time()
    best_acc = 0
    prefix = '' if pruner is None else 'regularized_{:e}_'.format(args.reg)
    for epoch in range(args.start_epoch, epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, pruner, recover=recover)
        lr_scheduler.step()

        acc = evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            # When EMA is enabled, its accuracy is the one used for model selection.
            acc = evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")

        # Track the best accuracy even when nothing is written to disk, so the
        # progress line below stays meaningful.
        is_best = acc > best_acc
        if is_best:
            best_acc = acc

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict() if state_dict_only else model_without_ddp,
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            if is_best:
                utils.save_on_master(checkpoint, os.path.join(args.output_dir, prefix+"best.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, prefix+"latest.pth"))
        print("Epoch {}/{}, Current Best Acc = {:.6f}".format(epoch, epochs, best_acc))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    return model_without_ddp
# Script entry point: parse the command line and run training/pruning.
if __name__ == "__main__":
    # NOTE: ``args`` is deliberately a module-level name here;
    # prune_to_target_flops reads it as a global, so do not rename it.
    args = get_args_parser().parse_args()
    main(args)