diff --git a/train_pl.py b/train_pl.py
index 7221207..cba4906 100644
--- a/train_pl.py
+++ b/train_pl.py
@@ -1,5 +1,6 @@
 import os
 import torch
+import math
 import logging
 import argparse
 from pytorch_lightning.callbacks import Callback
@@ -36,7 +37,7 @@ def __init__(self, args, logger: logging.Logger):
         requires_grad(self.ema, False)

         # Load pretrained model if specified
-        if args.pretrained and args.resume_from_checkpoint is not None:
+        if args.pretrained:
             # Load old checkpoint, only load EMA
             self._load_pretrained_parameters(args)
         self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
@@ -87,7 +88,7 @@ def training_step(self, batch, batch_idx):
         x = rearrange(x, "(b f) c h w -> b f c h w", b=b).contiguous()

         if self.args.extras == 78:  # text-to-video
-            raise ValueError('T2V training is not supported at this moment!')
+            raise ValueError("T2V training is not supported at this moment!")
         elif self.args.extras == 2:
             model_kwargs = dict(y=video_name)
         else:
@@ -114,6 +115,17 @@ def training_step(self, batch, batch_idx):
     def on_train_batch_end(self, *args, **kwargs):
         update_ema(self.ema, self.model)

+    def on_save_checkpoint(self, checkpoint):
+        super().on_save_checkpoint(checkpoint)
+        checkpoint_dir = self.trainer.checkpoint_callback.dirpath
+        epoch = self.trainer.current_epoch
+        step = self.trainer.global_step
+        state = {  # separate name so the hook's `checkpoint` dict is not shadowed
+            "model": self.model.state_dict(),
+            "ema": self.ema.state_dict(),
+        }
+        torch.save(state, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt")
+
     def configure_optimizers(self):
         self.lr_scheduler = get_scheduler(
             name="constant",
@@ -155,12 +167,20 @@ def main(args):
     seed = args.global_seed
     torch.manual_seed(seed)

-    # Setup an experiment folder and logger
-    experiment_dir, checkpoint_dir = create_experiment_directory(args)
-    logger = create_logger(experiment_dir)
+    # Determine if the current process is the main process (rank 0)
+    is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
+    # Set up the experiment folder and logger only on the main process
+    if is_main_process:
+        experiment_dir, checkpoint_dir = create_experiment_directory(args)
+        logger = create_logger(experiment_dir)
+        OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
+        logger.info(f"Experiment directory created at {experiment_dir}")
+    else:
+        experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path")
+        checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path")
+        logger = logging.getLogger(__name__)
+        logger.addHandler(logging.NullHandler())
     tb_logger = TensorBoardLogger(experiment_dir, name="latte")
-    OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
-    logger.info(f"Experiment directory created at {experiment_dir}")

     # Create the dataset and dataloader
     dataset = get_dataset(args)
@@ -172,11 +192,21 @@ def main(args):
         pin_memory=True,
         drop_last=True
     )
-    logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")
+    if is_main_process:
+        logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")

     sample_size = args.image_size // 8
     args.latent_size = sample_size

+    # Recalculate the total number of training steps, since the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(loader))
+    # Then recalculate the number of training epochs
+    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    # In multi-GPU mode, the effective batch size is local_batch_size * the number of GPUs
+    if is_main_process:
+        logger.info(f"One epoch corresponds to {num_update_steps_per_epoch} steps")
+        logger.info(f"Num train epochs: {num_train_epochs}")
+
     # Initialize the training module
     pl_module = LatteTrainingModule(args, logger)

@@ -193,18 +223,19 @@ def main(args):
         accelerator="gpu",
         # devices=[3],  # Specify GPU ids
         strategy="auto",
-        max_steps=args.max_train_steps,
+        max_epochs=num_train_epochs,
         logger=tb_logger,
         callbacks=[checkpoint_callback, LearningRateMonitor()],
         log_every_n_steps=args.log_every,
     )

-    trainer.fit(pl_module, train_dataloaders=loader, ckpt_path=args.resume_from_checkpoint if args.resume_from_checkpoint else None)
+    trainer.fit(pl_module, train_dataloaders=loader,
+                ckpt_path=args.resume_from_checkpoint if args.resume_from_checkpoint else None)

-    pl_module.model.eval()  # important! This disables randomized embedding dropout
-    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
-    logger.info("Done!")
+    pl_module.model.eval()
     cleanup()
+    if is_main_process:
+        logger.info("Done!")


 if __name__ == "__main__":
diff --git a/train_with_img_pl.py b/train_with_img_pl.py
index 282d0a6..426e07d 100644
--- a/train_with_img_pl.py
+++ b/train_with_img_pl.py
@@ -1,4 +1,5 @@
 import os
+import math
 import torch
 import logging
 import argparse
@@ -36,7 +37,7 @@ def __init__(self, args, logger: logging.Logger):
         requires_grad(self.ema, False)

         # Load pretrained model if specified
-        if args.pretrained and args.resume_from_checkpoint is not None:
+        if args.pretrained:
             # Load old checkpoint, only load EMA
             self._load_pretrained_parameters(args)
         self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
@@ -94,7 +95,7 @@ def training_step(self, batch, batch_idx):
         x = rearrange(x, "(b f) c h w -> b f c h w", b=b).contiguous()

         if self.args.extras == 78:  # text-to-video
-            raise ValueError('T2V training is not supported at this moment!')
+            raise ValueError("T2V training is not supported at this moment!")
         elif self.args.extras == 2:
             if self.args.dataset == "ucf101_img":
                 model_kwargs = dict(y=video_name, y_image=image_names, use_image_num=self.args.use_image_num)
@@ -124,6 +125,17 @@ def training_step(self, batch, batch_idx):
     def on_train_batch_end(self, *args, **kwargs):
         update_ema(self.ema, self.model)

+    def on_save_checkpoint(self, checkpoint):
+        super().on_save_checkpoint(checkpoint)
+        checkpoint_dir = self.trainer.checkpoint_callback.dirpath
+        epoch = self.trainer.current_epoch
+        step = self.trainer.global_step
+        state = {  # separate name so the hook's `checkpoint` dict is not shadowed
+            "model": self.model.state_dict(),
+            "ema": self.ema.state_dict(),
+        }
+        torch.save(state, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt")
+
     def configure_optimizers(self):
         self.lr_scheduler = get_scheduler(
             name="constant",
@@ -165,12 +177,20 @@ def main(args):
     seed = args.global_seed
     torch.manual_seed(seed)

-    # Setup an experiment folder and logger
-    experiment_dir, checkpoint_dir = create_experiment_directory(args)
-    logger = create_logger(experiment_dir)
+    # Determine if the current process is the main process (rank 0)
+    is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
+    # Set up the experiment folder and logger only on the main process
+    if is_main_process:
+        experiment_dir, checkpoint_dir = create_experiment_directory(args)
+        logger = create_logger(experiment_dir)
+        OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
+        logger.info(f"Experiment directory created at {experiment_dir}")
+    else:
+        experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path")
+        checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path")
+        logger = logging.getLogger(__name__)
+        logger.addHandler(logging.NullHandler())
     tb_logger = TensorBoardLogger(experiment_dir, name="latte")
-    OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
-    logger.info(f"Experiment directory created at {experiment_dir}")

     # Create the dataset and dataloader
     dataset = get_dataset(args)
@@ -182,11 +202,21 @@ def main(args):
         pin_memory=True,
         drop_last=True
     )
-    logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")
+    if is_main_process:
+        logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")

     sample_size = args.image_size // 8
     args.latent_size = sample_size

+    # Recalculate the total number of training steps, since the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(loader))
+    # Then recalculate the number of training epochs
+    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    # In multi-GPU mode, the effective batch size is local_batch_size * the number of GPUs
+    if is_main_process:
+        logger.info(f"One epoch corresponds to {num_update_steps_per_epoch} steps")
+        logger.info(f"Num train epochs: {num_train_epochs}")
+
     # Initialize the training module
     pl_module = LatteTrainingModule(args, logger)

@@ -201,20 +231,21 @@ def main(args):
     # Trainer
     trainer = Trainer(
         accelerator="gpu",
-        # devices=[0,1],  # Specify GPU ids
+        # devices=[3],  # Specify GPU ids
         strategy="auto",
-        max_steps=args.max_train_steps,
+        max_epochs=num_train_epochs,
         logger=tb_logger,
         callbacks=[checkpoint_callback, LearningRateMonitor()],
         log_every_n_steps=args.log_every,
     )

-    trainer.fit(pl_module, train_dataloaders=loader, ckpt_path=args.resume_from_checkpoint if args.resume_from_checkpoint else None)
+    trainer.fit(pl_module, train_dataloaders=loader,
+                ckpt_path=args.resume_from_checkpoint if args.resume_from_checkpoint else None)

-    pl_module.model.eval()  # important! This disables randomized embedding dropout
-    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
-    logger.info("Done!")
+    pl_module.model.eval()
     cleanup()
+    if is_main_process:
+        logger.info("Done!")


 if __name__ == "__main__":
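Note on the rank-0 handling in both files: the `else` branch reads `EXPERIMENT_DIR` and `CHECKPOINT_DIR` from the environment, but nothing in this diff exports them, so under `torchrun` the non-main ranks construct their `TensorBoardLogger` against `"default_path"` unless those variables are set externally. Below is a minimal sketch of one alternative, assuming the run directory can be derived deterministically from the config so every rank computes the same path without any env-var handoff (`results_dir` and `run_name` are hypothetical config fields, not arguments these scripts define):

```python
import os
import logging

from pytorch_lightning.utilities import rank_zero_only


def resolve_experiment_dir(args) -> str:
    """Compute the same experiment directory on every rank.

    The path depends only on the config, so all ranks agree on it without
    cross-process communication; only rank 0 creates the directories.
    """
    experiment_dir = os.path.join(args.results_dir, args.run_name)  # hypothetical fields
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        os.makedirs(os.path.join(experiment_dir, "checkpoints"), exist_ok=True)
    return experiment_dir


@rank_zero_only
def info(logger: logging.Logger, msg: str) -> None:
    # rank_zero_only makes this a no-op on non-zero ranks, removing the
    # need for explicit is_main_process guards at every call site.
    logger.info(msg)
```

With something like this, `TensorBoardLogger(resolve_experiment_dir(args), name="latte")` would point every rank at the same run directory.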
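On the `max_steps` → `max_epochs` switch: since `len(loader)` is already an `int`, `math.ceil(len(loader))` is a no-op and `num_update_steps_per_epoch` is simply `len(loader)`; the outer `ceil` in `num_train_epochs` does matter, and it means training can overshoot `args.max_train_steps` by up to one epoch's worth of steps. A worked example with illustrative numbers:

```python
import math

max_train_steps = 1_000_000  # illustrative stand-in for args.max_train_steps
steps_per_epoch = 2_700      # illustrative stand-in for len(loader)

num_train_epochs = math.ceil(max_train_steps / steps_per_epoch)  # 371
total_steps = num_train_epochs * steps_per_epoch                 # 1_001_700

# ceil rounds up, so the run executes 1,700 more optimizer steps than
# max_train_steps requested; with max_steps the Trainer stopped exactly there.
```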
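The new `on_save_checkpoint` hook writes a second, raw `{"model", "ema"}` file next to each Lightning checkpoint. A minimal sketch of consuming such a file at sampling time, assuming `model` and `ema_model` are constructed the same way as in training (the helper name and path below are illustrative, not part of these scripts):

```python
import torch
from torch import nn


def load_raw_checkpoint(path: str, model: nn.Module, ema_model: nn.Module) -> None:
    """Load a {"model", "ema"} file written by on_save_checkpoint."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])     # raw training weights
    ema_model.load_state_dict(state["ema"])   # EMA weights, typically used for sampling
    ema_model.eval()
```

Usage would follow the hook's naming pattern, e.g. `load_raw_checkpoint(f"{checkpoint_dir}/epoch9-step10000.ckpt", model, ema_model)`.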