trainer.py
"""Trainer and Evaluator for DETR.
Modifications from Scenic:
* Optax APIs instead of deprecated flax.optim API.
* Refactored `train_step`.
* Addl. debug logging.
"""
import os
from concurrent import futures
import functools
import time
from typing import Any, Callable, Optional, Tuple
from absl import logging
from clu import metric_writers, periodic_actions
import flax
from flax import jax_utils
from flax.training.checkpoints import (
    restore_checkpoint as flax_restore_checkpoint)
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
from dataset_lib import dataset_utils
import detr_train_utils
from models import detr
from train_lib import pretrain_utils, train_utils
import utils as u
def get_train_step(apply_fn: Callable, loss_and_metrics_fn: Callable,
update_batch_stats: bool, tx: optax.GradientTransformation):
"""Returns a function that runs a single step of training.
  Given the training state and a batch of data, the returned function computes
  the loss and updates the parameters of the model.

  Buffers of the first (train_state) argument are donated to the computation.
Args:
apply_fn: Flax model apply function.
loss_and_metrics_fn: Function to calculate loss and metrics.
update_batch_stats: bool; whether to update BN stats during training.
    tx: An `optax.GradientTransformation`.
Returns:
Train step function that takes a `train_state` and `batch` and returns
`new_train_state`, `metrics` and `predictions`.
"""
def train_step(train_state, batch):
def loss_fn(params):
new_rng, rng = jax.random.split(train_state.rng)
# Bind the rng to the host/device we are on.
model_rng = train_utils.bind_rng_to_host_device(
rng, axis_name='batch', bind_to='device')
variables = {'params': params, **train_state.model_state}
predictions, new_model_state = apply_fn(
variables,
batch['inputs'],
padding_mask=batch['padding_mask'],
update_batch_stats=update_batch_stats,
mutable=train_state.model_state.keys(),
train=True,
rngs={'dropout': model_rng})
loss, metrics = loss_and_metrics_fn(
predictions, batch, model_params=params)
return loss, (new_model_state, metrics, predictions, new_rng)
new_global_step = train_state.global_step + 1
(_, (new_model_state, metrics, predictions,
new_rng)), grads = jax.value_and_grad(
loss_fn, has_aux=True)(
train_state.params)
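    # Average gradients across devices on the 'batch' pmap axis so that every
    # replica applies the same parameter update.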
grads = jax.lax.pmean(grads, axis_name='batch')
updates, new_opt_state = tx.update(
grads, train_state.opt_state, params=train_state.params)
new_params = optax.apply_updates(train_state.params, updates)
train_state = train_state.replace(
global_step=new_global_step,
params=new_params,
opt_state=new_opt_state,
model_state=new_model_state,
rng=new_rng)
    # Debug measurements: global L2 norms of gradients, parameters and updates.
gs = jax.tree.leaves(grads)
metrics['l2_grads'] = (jnp.sqrt(sum([jnp.vdot(g, g) for g in gs])), 1)
ps = jax.tree.leaves(new_params)
metrics['l2_params'] = (jnp.sqrt(sum([jnp.vdot(p, p) for p in ps])), 1)
us = jax.tree.leaves(updates)
metrics['l2_updates'] = (jnp.sqrt(sum([jnp.vdot(u, u) for u in us])), 1)
return train_state, metrics, predictions
return train_step
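# A minimal usage sketch for `get_train_step` (hedged; the actual wiring with
# pmap and a replicated train state is in `train_and_evaluate` below):
#
#   train_step = get_train_step(apply_fn, loss_and_metrics_fn, True, tx)
#   p_train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
#   train_state, metrics, predictions = p_train_step(train_state, batch)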
def get_eval_step(flax_model,
loss_and_metrics_fn,
logits_to_probs_fn,
metrics_only=False):
"""Runs a single step of evaluation.
Note that in this code, the buffer of the second argument (batch) is donated
to the computation.
Args:
flax_model: Flax model (an instance of nn.Module).
loss_and_metrics_fn: A function that given model predictions, a batch and
parameters of the model calculates the loss as well as metrics.
    logits_to_probs_fn: Function that converts logits to probabilities.
    metrics_only: bool; whether to return only metrics and skip gathering
      predictions and targets.
Returns:
Eval step function which returns predictions and calculated metrics.
"""
def metrics_fn(train_state, batch, predictions):
_, metrics = loss_and_metrics_fn(
predictions, batch, model_params=train_state.params)
if metrics_only:
return None, None, metrics
pred_probs = logits_to_probs_fn(predictions['pred_logits'])
# Collect necessary predictions and target information from all hosts.
predictions_out = {
'pred_probs': pred_probs,
'pred_logits': predictions['pred_logits'],
'pred_boxes': predictions['pred_boxes']
}
labels = {
'image/id': batch['label']['image/id'],
'size': batch['label']['size'],
'orig_size': batch['label']['orig_size']
}
to_copy = [
'labels', 'boxes', 'not_exhaustive_category_ids', 'neg_category_ids'
]
for name in to_copy:
if name in batch['label']:
labels[name] = batch['label'][name]
targets = {'label': labels, 'batch_mask': batch['batch_mask']}
predictions_out = jax.lax.all_gather(predictions_out, 'batch')
targets = jax.lax.all_gather(targets, 'batch')
return targets, predictions_out, metrics
def eval_step(train_state, batch):
variables = {'params': train_state.params, **train_state.model_state}
predictions = flax_model.apply(
variables,
batch['inputs'],
padding_mask=batch['padding_mask'],
train=False,
mutable=False)
return metrics_fn(train_state, batch, predictions)
return eval_step
def make_optimizer(
config: ml_collections.ConfigDict, params, *, sched_kw: dict
) -> Tuple[optax.GradientTransformation, Callable[..., float]]:
"""Makes an Optax optimizer for DETR."""
oc = config.optimizer_configs
def is_bn(path):
# For DETR we need to skip the BN affine transforms as well.
if not config.freeze_backbone_batch_stats:
return False
names = ['bn1', 'bn2', 'bn3', 'downsample/1']
for s in names:
if s in path:
return True
return False
def is_early_layer(path):
if not config.load_pretrained_backbone:
return False
for name in ["backbone/conv1", "backbone/bn1", "backbone/layer1"]:
if name in path:
return True
return False
backbone_traversal = flax.traverse_util.ModelParamTraversal(
lambda path, _: ('backbone' in path) and not is_bn(path))
bn_traversal = flax.traverse_util.ModelParamTraversal(
lambda path, _: is_bn(path))
early_layer_traversal = flax.traverse_util.ModelParamTraversal(
lambda path, _: is_early_layer(path))
weight_decay_traversal = flax.traverse_util.ModelParamTraversal(
lambda path, _: path.endswith('kernel'))
all_false = jax.tree_util.tree_map(lambda _: False, params)
def get_mask(traversal: flax.traverse_util.ModelParamTraversal):
return traversal.update(lambda _: True, all_false)
# Masks
bn_mask = get_mask(bn_traversal)
backbone_mask = get_mask(backbone_traversal)
early_layer_mask = get_mask(early_layer_traversal)
weight_decay_mask = get_mask(weight_decay_traversal)
# LR Schedule
sched_fn = u.create_learning_rate_schedule(
**sched_kw, **oc.schedule, base=oc.base_lr)
# Optimizer
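  # The chain applies, in order: global-norm gradient clipping, AdamW (weight
  # decay masked to kernel parameters), a scaled-down learning rate for the
  # backbone, and zeroed updates for frozen BN affines and early backbone
  # layers.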
tx = optax.chain(
optax.clip_by_global_norm(oc.grad_clip_norm),
optax.adamw(
learning_rate=sched_fn,
mask=weight_decay_mask,
**oc.optax_kw,
), optax.masked(optax.scale(oc.backbone_lr_reduction), backbone_mask),
optax.masked(optax.set_to_zero(), bn_mask),
optax.masked(optax.set_to_zero(), early_layer_mask))
return tx, sched_fn
def train_and_evaluate(*, rng: jnp.ndarray, dataset: dataset_utils.Dataset,
config: ml_collections.ConfigDict, workdir: str,
writer: metric_writers.MetricWriter):
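  """Runs the main training and evaluation loop.

  Args:
    rng: JAX PRNG key.
    dataset: Dataset with train/valid iterators and metadata.
    config: Experiment configuration.
    workdir: Directory for checkpoints and experiment artifacts.
    writer: CLU metric writer used for logging summaries.

  Returns:
    A tuple of (train_state, train_summary, eval_summary).
  """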
lead_host = jax.process_index() == 0
# Store a copy of the experiment config.
if lead_host:
with open(os.path.join(workdir, 'config.json'), 'w') as f:
f.write(config.to_json())
def info(s, *a):
if lead_host:
logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
# This pool is used for async I/O operations like logging metrics
pool = futures.ThreadPoolExecutor(max_workers=2)
# Calculate total train steps using available information.
ntrain_img = dataset.meta_data['num_train_examples']
total_steps = u.steps(
'total', config, data_size=ntrain_img, batch_size=config.batch_size)
info('Running for %d steps (%f epochs)', total_steps,
total_steps / (ntrain_img / config.batch_size))
# Initialize model, loss_fn
model = detr.DETRModel(config, dataset.meta_data)
rng, init_rng = jax.random.split(rng)
(params, model_state, num_trainable_params,
gflops) = train_utils.initialize_model(
model=model.flax_model,
input_spec=[(dataset.meta_data['input_shape'],
dataset.meta_data.get('input_dtype', jnp.float32))],
config=config,
rngs=init_rng)
# Create optimizer.
tx, sched_fn = make_optimizer(
config,
params,
sched_kw=dict(
total_steps=total_steps,
batch_size=config.batch_size,
data_size=ntrain_img))
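  # Initialize the optimizer state on the CPU backend so the unreplicated state
  # is not allocated on an accelerator before replication.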
opt_state = jax.jit(tx.init, backend='cpu')(params)
sched_fn_cpu = jax.jit(sched_fn, backend='cpu')
# Build TrainState
rng, train_rng = jax.random.split(rng)
train_state = train_utils.TrainState(
global_step=0,
params=params,
opt_state=opt_state,
model_state=model_state,
rng=train_rng)
# Load checkpoint/pretrained weights
start_step = train_state.global_step
if config.checkpoint:
train_state, start_step = train_utils.restore_checkpoint(
workdir, train_state)
if (start_step == 0 # Which means no checkpoint was restored.
and config.get('init_from') is not None):
init_checkpoint_path = config.init_from.get('checkpoint_path')
restored_train_state = flax_restore_checkpoint(
init_checkpoint_path, target=None)
train_state = pretrain_utils.init_from_pretrain_state(
train_state,
restored_train_state,
ckpt_prefix_path=config.init_from.get('ckpt_prefix_path'),
model_prefix_path=config.init_from.get('model_prefix_path'),
name_mapping=config.init_from.get('name_mapping'),
skip_regex=config.init_from.get('skip_regex'))
del restored_train_state
elif start_step == 0 and config.get('load_pretrained_backbone', False):
# Only load pretrained backbone if we are at the beginning of training.
bb_ckpt_path = config.pretrained_backbone_configs.get('checkpoint_path')
bb_train_state = flax_restore_checkpoint(bb_ckpt_path, target=None)
train_state = pretrain_utils.init_from_pretrain_state(
train_state, bb_train_state, model_prefix_path=['backbone'])
  # Number of steps per epoch (used as the default eval/checkpoint cadence).
steps_per_epoch = ntrain_img // config.batch_size
update_batch_stats = not config.get('freeze_backbone_batch_stats', False)
if not update_batch_stats:
if not config.load_pretrained_backbone:
      raise ValueError(
          'Freezing backbone batch stats without a pretrained backbone '
          'is not supported. Please check your config.')
# Replicate.
train_state = flax.jax_utils.replicate(train_state)
del params
train_step = get_train_step(
apply_fn=model.flax_model.apply,
loss_and_metrics_fn=model.loss_function,
update_batch_stats=update_batch_stats,
tx=tx)
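  # donate_argnums=(0,) lets XLA reuse the old train_state buffers for the
  # updated state.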
train_step_pmapped = jax.pmap(
train_step, axis_name='batch', donate_argnums=(0,))
# Evaluation code.
eval_step = get_eval_step(
flax_model=model.flax_model,
loss_and_metrics_fn=model.loss_function,
logits_to_probs_fn=model.logits_to_probs)
eval_step_pmapped = jax.pmap(
eval_step, axis_name='batch', donate_argnums=(1,))
# Ceil rounding such that we include the last incomplete batch.
eval_batch_size = config.get('eval_batch_size', config.batch_size)
total_eval_steps = int(
np.ceil(dataset.meta_data['num_eval_examples'] / eval_batch_size))
steps_per_eval = config.get('steps_per_eval') or total_eval_steps
metrics_normalizer_fn = functools.partial(
detr_train_utils.normalize_metrics_summary,
object_detection_loss_keys=model.loss_terms_weights.keys())
def evaluate(train_state, step):
"""Runs evaluation code."""
future = None
def _wait(future: Optional[futures.Future]) -> Any:
if future is None:
return future
return future.result()
def _add_examples(predictions, labels):
for pred, label in zip(predictions, labels):
global_metrics_evaluator.add_example(prediction=pred, target=label)
eval_metrics = []
if global_metrics_evaluator is not None:
global_metrics_evaluator.clear()
for eval_step in range(steps_per_eval):
eval_batch = next(dataset.valid_iter)
# Do the eval step.
(eval_batch_all_hosts, eval_predictions_all_hosts,
e_metrics) = eval_step_pmapped(train_state, eval_batch)
      # Auxiliary decoder outputs are not needed for evaluation.
eval_predictions_all_hosts.pop('aux_outputs', None)
# Collect local metrics (returned by the loss function).
eval_metrics.append(train_utils.unreplicate_and_get(e_metrics))
if global_metrics_evaluator is not None:
# Unreplicate the output of eval_step_pmapped (used `lax.all_gather`).
eval_batch_all_hosts = jax_utils.unreplicate(eval_batch_all_hosts)
eval_predictions_all_hosts = jax_utils.unreplicate(
eval_predictions_all_hosts)
# Collect preds and labels to be sent for computing global metrics.
predictions = detr_train_utils.process_and_fetch_to_host(
eval_predictions_all_hosts, eval_batch_all_hosts['batch_mask'])
predictions = jax.tree_util.tree_map(np.asarray, predictions)
labels = detr_train_utils.process_and_fetch_to_host(
eval_batch_all_hosts['label'], eval_batch_all_hosts['batch_mask'])
labels = jax.tree_util.tree_map(np.asarray, labels)
if eval_step == 0:
logging.info('Pred keys: %s', list(predictions[0].keys()))
logging.info('Labels keys: %s', list(labels[0].keys()))
# Add to evaluator.
_wait(future)
future = pool.submit(_add_examples, predictions, labels)
del predictions, labels
del eval_batch, eval_batch_all_hosts, eval_predictions_all_hosts
eval_global_metrics_summary_future = None
if global_metrics_evaluator is not None:
_wait(future)
logging.info('Number of eval examples: %d', len(global_metrics_evaluator))
eval_global_metrics_summary_future = pool.submit(
global_metrics_evaluator.compute_metrics, clear_annotations=False)
return (step, eval_metrics), eval_global_metrics_summary_future
#####################################################
log_eval_steps = config.get('log_eval_steps') or steps_per_epoch
log_summary_steps = config.get('log_summary_steps', 25)
log_large_summary_steps = config.get('log_large_summary_steps', 0)
checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps
  # Global detection metrics are computed only on the lead host.
  global_metrics_evaluator = None
if lead_host:
global_metrics_evaluator = detr_train_utils.DetrGlobalEvaluator(
config.dataset_configs.name, annotations_loc=config.annotations_loc)
train_metrics, extra_training_logs = [], []
train_summary, eval_summary = None, None
info('Starting training loop at step %d.', start_step + 1)
report_progress = periodic_actions.ReportProgress(
num_train_steps=total_steps,
writer=writer,
every_secs=None,
every_steps=log_summary_steps,
)
hooks = []
if lead_host:
hooks.append(report_progress)
if config.get('xprof', True) and lead_host:
hooks.append(periodic_actions.Profile(num_profile_steps=5, logdir=workdir))
if start_step == 0:
step0_log = {'num_trainable_params': num_trainable_params}
if gflops:
step0_log['gflops'] = gflops
writer.write_scalars(1, step0_log)
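  # Evaluation is asynchronous: keep the pending future and the step it belongs
  # to so the summary can be written once it resolves.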
(last_eval_step, last_eval_metrics), last_eval_future = (None, None), None
for step in range(start_step + 1, total_steps + 1):
with jax.profiler.StepTraceAnnotation('train', step_num=step):
train_batch = next(dataset.train_iter)
(train_state, t_metrics,
train_predictions) = train_step_pmapped(train_state, train_batch)
      # Accumulate metrics in host memory until they are logged; avoid this for
      # large outputs such as segmentation maps.
train_metrics.append(train_utils.unreplicate_and_get(t_metrics))
[h(step) for h in hooks]
if (log_large_summary_steps and step % log_large_summary_steps == 0 and
lead_host):
################# LOG EXPENSIVE TRAIN SUMMARY ################
# Visualize detections side-by-side using gt-pred images.
to_cpu = lambda x: jax.device_get(dataset_utils.unshard(x))
del train_batch['batch_mask']
train_pred_cpu = to_cpu(train_predictions)
train_batch_cpu = to_cpu(train_batch)
viz = detr_train_utils.draw_boxes_side_by_side(
train_pred_cpu,
train_batch_cpu,
label_map=dataset.meta_data['label_to_name'])
writer.write_images(step, {
f'sidebyside_{i}/detection': viz_[None, ...]
for i, viz_ in enumerate(viz)
})
del train_predictions
if (step % log_summary_steps == 0) or (step == total_steps - 1):
########## LOG TRAIN SUMMARY #########
      extra_training_logs.append({'global_schedule': sched_fn_cpu(step - 1)})
train_summary = train_utils.log_train_summary(
step,
writer=writer,
train_metrics=train_metrics,
extra_training_logs=extra_training_logs,
metrics_normalizer_fn=metrics_normalizer_fn)
# Reset for next round.
train_metrics, extra_training_logs = [], []
######################################
if (step % log_eval_steps == 0) or (step == total_steps):
# First wait for the previous eval to finish and write summary.
if last_eval_future is not None:
train_utils.log_eval_summary(
step=last_eval_step,
eval_metrics=last_eval_metrics,
extra_eval_summary=last_eval_future.result(),
writer=writer,
metrics_normalizer_fn=metrics_normalizer_fn)
last_eval_future = None
      # Sync model state across replicas (needed when the model carries state,
      # e.g. batch statistics when using BatchNorm).
start_time = time.time()
with report_progress.timed('eval'):
train_state = train_utils.sync_model_state_across_replicas(train_state)
(last_eval_step,
last_eval_metrics), last_eval_future = evaluate(train_state, step)
duration = time.time() - start_time
      info('Done with async evaluation in %.4f sec.', duration)
writer.flush()
####################### CHECKPOINTING ####################
if ((step % checkpoint_steps == 0 and step > 0) or
(step == total_steps)) and config.checkpoint:
with report_progress.timed('checkpoint'):
# Sync model state across replicas.
train_state = train_utils.sync_model_state_across_replicas(train_state)
if lead_host:
train_utils.save_checkpoint(workdir,
jax_utils.unreplicate(train_state))
# Last eval (useful if training is skipped).
with report_progress.timed('eval'):
train_state = train_utils.sync_model_state_across_replicas(train_state)
(last_eval_step,
last_eval_metrics), last_eval_future = evaluate(train_state,
total_steps + 1)
# Wait until computations are done before exiting.
pool.shutdown()
if last_eval_future is not None:
train_utils.log_eval_summary(
step=last_eval_step,
eval_metrics=last_eval_metrics,
extra_eval_summary=last_eval_future.result(),
writer=writer,
metrics_normalizer_fn=metrics_normalizer_fn)
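  # A trivial blocking computation acts as a barrier so all pending device work
  # finishes before returning.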
jax.random.normal(jax.random.key(0), ()).block_until_ready()
return train_state, train_summary, eval_summary