timing log and separate optim for g and d
AnniLi1212 committed Feb 14, 2024
1 parent 51814e8 commit d25ea25
Showing 2 changed files with 174 additions and 32 deletions.
80 changes: 65 additions & 15 deletions setup_training.py
@@ -204,7 +204,26 @@ def parse_optimization_args(parser):
         type=str,
         default="rmsprop",
         help="pick optimizer",
-        choices=["adam", "rmsprop", "adadelta", "agcd"],
+        choices=["adam", "rmsprop", "adadelta", "agcd", "sgd"],
     )
+    parser.add_argument(
+        "--use_different_optimizers",
+        action="store_true",
+        help="Use different optimizers for generator and discriminator",
+    )
+    parser.add_argument(
+        "--optimizer-G",
+        type=str,
+        default="rmsprop",
+        help="pick optimizer for generator",
+        choices=["adam", "rmsprop", "adadelta", "agcd", "sgd"],
+    )
+    parser.add_argument(
+        "--optimizer-D",
+        type=str,
+        default="rmsprop",
+        help="pick optimizer for discriminator",
+        choices=["adam", "rmsprop", "adadelta", "agcd", "sgd"],
+    )
     parser.add_argument(
         "--loss",
@@ -248,6 +267,12 @@ def parse_optimization_args(parser):
         default=1,
         help="number of generator updates for each critic update (num-critic must be 1 for this to apply)",
     )
+    parser.add_argument(
+        "--sgd-momentum",
+        type=float,
+        default=0.9,
+        help="momentum for the SGD optimizer",
+    )
 
 
 def parse_regularization_args(parser):
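
For reference, a minimal self-contained sketch (not part of this commit) of how the new flags behave when parsed: argparse rewrites "-" in option names to "_", so --optimizer-G and --sgd-momentum surface as args.optimizer_G and args.sgd_momentum, which is what the optimizers() changes below read.

import argparse

# Hypothetical parser mirroring only the flags added in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--use_different_optimizers", action="store_true")
parser.add_argument("--optimizer-G", type=str, default="rmsprop",
                    choices=["adam", "rmsprop", "adadelta", "agcd", "sgd"])
parser.add_argument("--optimizer-D", type=str, default="rmsprop",
                    choices=["adam", "rmsprop", "adadelta", "agcd", "sgd"])
parser.add_argument("--sgd-momentum", type=float, default=0.9)

# Dashes in option names become underscores in the parsed namespace.
args = parser.parse_args(
    ["--use_different_optimizers", "--optimizer-G", "adam",
     "--optimizer-D", "sgd", "--sgd-momentum", "0.5"]
)
assert args.optimizer_G == "adam" and args.sgd_momentum == 0.5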
@@ -1506,7 +1531,6 @@ def get_model_args(args):

     return model_train_args, model_eval_args, extra_args
 
-
 def optimizers(args, G, D):
     if args.spectral_norm_gen:
         G_params = filter(lambda p: p.requires_grad, G.parameters())
@@ -1518,19 +1542,45 @@ def optimizers(args, G, D):
     else:
         D_params = D.parameters()
 
-    if args.optimizer == "rmsprop":
-        G_optimizer = optim.RMSprop(G_params, lr=args.lr_gen)
-        D_optimizer = optim.RMSprop(D_params, lr=args.lr_disc)
-    elif args.optimizer == "adadelta":
-        G_optimizer = optim.Adadelta(G_params, lr=args.lr_gen)
-        D_optimizer = optim.Adadelta(D_params, lr=args.lr_disc)
-    elif args.optimizer == "adam" or args.optimizer == "None":
-        G_optimizer = optim.Adam(
-            G_params, lr=args.lr_gen, weight_decay=5e-4, betas=(args.beta1, args.beta2)
-        )
-        D_optimizer = optim.Adam(
-            D_params, lr=args.lr_disc, weight_decay=5e-4, betas=(args.beta1, args.beta2)
-        )
+    if args.use_different_optimizers:
+        if args.optimizer_G == "rmsprop":
+            G_optimizer = optim.RMSprop(G_params, lr=args.lr_gen)
+        elif args.optimizer_G == "adadelta":
+            G_optimizer = optim.Adadelta(G_params, lr=args.lr_gen)
+        elif args.optimizer_G == "adam":
+            G_optimizer = optim.Adam(
+                G_params, lr=args.lr_gen, weight_decay=5e-4, betas=(args.beta1, args.beta2)
+            )
+        elif args.optimizer_G == "sgd":
+            G_optimizer = optim.SGD(G_params, lr=args.lr_gen, momentum=args.sgd_momentum)
+
+        if args.optimizer_D == "rmsprop":
+            D_optimizer = optim.RMSprop(D_params, lr=args.lr_disc)
+        elif args.optimizer_D == "adadelta":
+            D_optimizer = optim.Adadelta(D_params, lr=args.lr_disc)
+        elif args.optimizer_D == "adam":
+            D_optimizer = optim.Adam(
+                D_params, lr=args.lr_disc, weight_decay=5e-4, betas=(args.beta1, args.beta2)
+            )
+        elif args.optimizer_D == "sgd":
+            D_optimizer = optim.SGD(D_params, lr=args.lr_disc, momentum=args.sgd_momentum)
+    else:
+        if args.optimizer == "rmsprop":
+            G_optimizer = optim.RMSprop(G_params, lr=args.lr_gen)
+            D_optimizer = optim.RMSprop(D_params, lr=args.lr_disc)
+        elif args.optimizer == "adadelta":
+            G_optimizer = optim.Adadelta(G_params, lr=args.lr_gen)
+            D_optimizer = optim.Adadelta(D_params, lr=args.lr_disc)
+        elif args.optimizer == "adam" or args.optimizer == "None":
+            G_optimizer = optim.Adam(
+                G_params, lr=args.lr_gen, weight_decay=5e-4, betas=(args.beta1, args.beta2)
+            )
+            D_optimizer = optim.Adam(
+                D_params, lr=args.lr_disc, weight_decay=5e-4, betas=(args.beta1, args.beta2)
+            )
+        elif args.optimizer == "sgd":
+            G_optimizer = optim.SGD(G_params, lr=args.lr_gen, momentum=args.sgd_momentum)
+            D_optimizer = optim.SGD(D_params, lr=args.lr_disc, momentum=args.sgd_momentum)
 
     if args.load_model:
         G_optimizer.load_state_dict(
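
One compact way to express the same selection (a sketch only, not the repository's code) is a single helper shared by both networks; it also makes the "agcd" choice, which the CLI accepts but no branch above handles, fail loudly instead of leaving G_optimizer or D_optimizer unbound:

import torch.optim as optim

def make_optimizer(name, params, lr, args):
    # Dispatch on the optimizer name; mirrors the branches in the diff above.
    if name == "adam":
        return optim.Adam(params, lr=lr, weight_decay=5e-4,
                          betas=(args.beta1, args.beta2))
    if name == "rmsprop":
        return optim.RMSprop(params, lr=lr)
    if name == "adadelta":
        return optim.Adadelta(params, lr=lr)
    if name == "sgd":
        return optim.SGD(params, lr=lr, momentum=args.sgd_momentum)
    # "agcd" is in the CLI choices but has no implementation in this diff.
    raise ValueError(f"unsupported optimizer: {name}")

# Hypothetical usage inside optimizers(args, G, D):
# G_optimizer = make_optimizer(args.optimizer_G, G_params, args.lr_gen, args)
# D_optimizer = make_optimizer(args.optimizer_D, D_params, args.lr_disc, args)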
