Add AMP to ImageNet classification and segmentation scripts + auto layout #1201
base: master
Conversation
assert not opt.auto_layout or opt.amp, "--auto-layout needs to be used with --amp"

if opt.amp:
    amp.init(layout_optimization=opt.auto_layout)
Referring to the definition of amp.init() here, it seems there is no argument named layout_optimization?
It's an internal feature; it will be added soon.
Thanks for the clarification.
Just curious: when setting both --amp and --dtype float16, what will happen?
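For reference, a minimal sketch of the documented public mxnet.contrib.amp float16 workflow (MXNet >= 1.6); note that amp.init() takes no layout_optimization argument in the public API, consistent with the comment above. The model, data, and hyperparameters below are illustrative only.

```python
import mxnet as mx
from mxnet import autograd, gluon
from mxnet.contrib import amp

# Patch MXNet operators so eligible ones run in float16 (targets GPU training).
amp.init()

net = gluon.nn.Dense(10)
net.initialize()

# AMP requires update_on_kvstore=False so loss scaling can control updates.
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1}, update_on_kvstore=False)
amp.init_trainer(trainer)  # enable dynamic loss scaling on this trainer

data = mx.nd.random.uniform(shape=(4, 8))
label = mx.nd.zeros((4,))
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

with autograd.record():
    out = net(data)
    loss = loss_fn(out, label)
    # Scale the loss before backward to avoid float16 gradient underflow.
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
trainer.step(4)
```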
@@ -105,6 +106,10 @@ def parse_args():
                        help='name of training log file')
    parser.add_argument('--use-gn', action='store_true',
                        help='whether to use group norm.')
    parser.add_argument('--amp', action='store_true',
                        help='Use MXNet AMP for mixed precision training.')
    parser.add_argument('--auto-layout', action='store_true',
Could you also add an option like --target-dtype, since AMP now supports not only float16 but also bfloat16? Then we can pass target-dtype to amp.init() to enable float16/bfloat16 training on GPU and CPU respectively. Thanks.
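A sketch of what the reviewer's proposal could look like; the --target-dtype flag is a suggestion, not existing code, while target_dtype is a public parameter of amp.init() in recent MXNet versions.

```python
import argparse
from mxnet.contrib import amp

parser = argparse.ArgumentParser()
parser.add_argument('--amp', action='store_true', default=False,
                    help='Use MXNet AMP for mixed precision training.')
# Proposed flag: which low-precision dtype AMP should cast to.
parser.add_argument('--target-dtype', type=str, default='float16',
                    choices=['float16', 'bfloat16'],
                    help='dtype AMP casts to: float16 on GPU, bfloat16 on CPU.')
opt = parser.parse_args()

if opt.amp:
    amp.init(target_dtype=opt.target_dtype)
```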
if opt.resume_states is not '':
    trainer.load_states(opt.resume_states)

if opt.amp:
This may need to change to: if opt.amp and opt.target_dtype == 'float16':
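A minimal sketch of the suggested guard in the context of the hunk above, assuming the proposed --target-dtype flag exists and that the body of this branch calls amp.init_trainer():

```python
# Dynamic loss scaling (amp.init_trainer) is only needed for float16;
# bfloat16 keeps float32's exponent range, so no scaling is required.
if opt.amp and opt.target_dtype == 'float16':
    amp.init_trainer(trainer)
```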
@@ -404,8 +417,13 @@ def train(ctx):
                    p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
        else:
            loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
        for l in loss:
            l.backward()
        if opt.amp:
This may need to change to: if opt.amp and opt.target_dtype == 'float16':
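A sketch of the suggested change in the context of the hunk above (again assuming the proposed --target-dtype flag). Loss scaling guards float16 against gradient underflow; bfloat16 and float32 can use the plain backward pass.

```python
if opt.amp and opt.target_dtype == 'float16':
    # Scale all losses jointly, then backpropagate the scaled versions.
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
else:
    for l in loss:
        l.backward()
```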
@@ -210,7 +216,12 @@ def __init__(self, args, logger):
                v.wd_mult = 0.0

        self.optimizer = gluon.Trainer(self.net.module.collect_params(), args.optimizer,
-                                      optimizer_params, kvstore=kv)
+                                      optimizer_params, update_on_kvstore=(False if args.amp else None))
May I know why kvstore=kv was deleted? Could you add it back? Thanks.
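For reference, gluon.Trainer accepts both kvstore and update_on_kvstore, so the original argument could be restored alongside the AMP-related one; whether that is the intended behavior here is an assumption.

```python
# AMP wants update_on_kvstore=False so dynamic loss scaling can skip the
# update step when a gradient overflow is detected; kvstore=kv can coexist.
self.optimizer = gluon.Trainer(self.net.module.collect_params(), args.optimizer,
                               optimizer_params, kvstore=kv,
                               update_on_kvstore=(False if args.amp else None))
```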
@@ -95,6 +96,11 @@ def parse_args():
    # synchronized Batch Normalization
    parser.add_argument('--syncbn', action='store_true', default=False,
                        help='using Synchronized Cross-GPU BatchNorm')
    # performance related
    parser.add_argument('--amp', action='store_true',
We usually add default=False for arguments. Could you add it? Thank you.
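A sketch of the requested change; note that action='store_true' already implies default=False, so spelling it out is purely for consistency with the other arguments in this script (e.g. --syncbn above).

```python
parser.add_argument('--amp', action='store_true', default=False,
                    help='Use MXNet AMP for mixed precision training.')
```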
@Kh4L any update on this PR?
@Kh4L Any update on this? BTW, do you have numbers for the improvement?