diff --git a/train_pl.py b/train_pl.py index d967be5..cba4906 100644 --- a/train_pl.py +++ b/train_pl.py @@ -37,7 +37,7 @@ def __init__(self, args, logger: logging.Logger): requires_grad(self.ema, False) # Load pretrained model if specified - if args.pretrained and args.resume_from_checkpoint is not None: + if args.pretrained: # Load old checkpoint, only load EMA self._load_pretrained_parameters(args) self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}") diff --git a/train_with_img_pl.py b/train_with_img_pl.py index ff6da7a..426e07d 100644 --- a/train_with_img_pl.py +++ b/train_with_img_pl.py @@ -37,7 +37,7 @@ def __init__(self, args, logger: logging.Logger): requires_grad(self.ema, False) # Load pretrained model if specified - if args.pretrained and args.resume_from_checkpoint is not None: + if args.pretrained: # Load old checkpoint, only load EMA self._load_pretrained_parameters(args) self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")