Skip to content

Commit

Permalink
Add flexible load for lr scheduler.
Browse files Browse the repository at this point in the history
  • Loading branch information
veichta committed Oct 10, 2023
1 parent f7b587e commit 29d332e
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions gluefactory/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@
"optimizer_options": {}, # optional arguments passed to the optimizer
"lr": 0.001, # learning rate
"lr_schedule": {
"type": None,
"type": None, # can be any of [factor, exp] or torch.optim.lr_scheduler
"start": 0,
"exp_div_10": 0,
"on_epoch": False,
"factor": 1.0,
"options": None, # add lr_scheduler arguments here
},
"lr_scaling": [(100, ["dampingnet.const"])],
"eval_every_iter": 1000, # interval for evaluation on the validation set
Expand Down Expand Up @@ -141,6 +142,26 @@ def filter_fn(x):
return params


def get_lr_scheduler(optimizer, conf):
"""Get lr scheduler specified by conf.train.lr_schedule."""
if conf.type not in ["factor", "exp", None]:
return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)

# backward compatibility
def lr_fn(it): # noqa: E306
if conf.type is None:
return 1
if conf.type == "factor":
return 1.0 if it < conf.start else conf.factor
if conf.type == "exp":
gam = 10 ** (-1 / conf.exp_div_10)
return 1.0 if it < conf.start else gam
else:
raise ValueError(conf.type)

return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)


def pack_lr_parameters(params, base_lr, lr_scaling):
"""Pack each group of parameters with the respective scaled learning rate."""
filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
Expand Down Expand Up @@ -310,22 +331,7 @@ def sigint_handler(signal, frame):

results = None # fix bug with it saving

def lr_fn(it): # noqa: E306
if conf.train.lr_schedule.type is None:
return 1
if conf.train.lr_schedule.type == "factor":
return (
1.0
if it < conf.train.lr_schedule.start
else conf.train.lr_schedule.factor
)
if conf.train.lr_schedule.type == "exp":
gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
return 1.0 if it < conf.train.lr_schedule.start else gam
else:
raise ValueError(conf.train.lr_schedule.type)

lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_scheduler)
if args.restore:
optimizer.load_state_dict(init_cp["optimizer"])
if "lr_scheduler" in init_cp:
Expand Down

0 comments on commit 29d332e

Please sign in to comment.