Commit c1650af

Merge branch 'main' of github.com:Vchitect/Latte

maxin-cn committed Sep 3, 2024
2 parents eb93d57 + 586a76d

Showing 2 changed files with 89 additions and 27 deletions.
57 changes: 44 additions & 13 deletions train_pl.py
@@ -1,5 +1,6 @@
import os
import torch
+import math
import logging
import argparse
from pytorch_lightning.callbacks import Callback
@@ -36,7 +37,7 @@ def __init__(self, args, logger: logging.Logger):
        requires_grad(self.ema, False)

        # Load pretrained model if specified
-        if args.pretrained and args.resume_from_checkpoint is not None:
+        if args.pretrained:
            # Load old checkpoint, only load EMA
            self._load_pretrained_parameters(args)
        self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
@@ -87,7 +88,7 @@ def training_step(self, batch, batch_idx):
        x = rearrange(x, "(b f) c h w -> b f c h w", b=b).contiguous()

        if self.args.extras == 78: # text-to-video
-            raise ValueError('T2V training is not supported at this moment!')
+            raise ValueError("T2V training is not supported at this moment!")
        elif self.args.extras == 2:
            model_kwargs = dict(y=video_name)
        else:
@@ -114,6 +115,17 @@ def training_step(self, batch, batch_idx):
    def on_train_batch_end(self, *args, **kwargs):
        update_ema(self.ema, self.model)

+    def on_save_checkpoint(self, checkpoint):
+        super().on_save_checkpoint(checkpoint)
+        checkpoint_dir = self.trainer.checkpoint_callback.dirpath
+        epoch = self.trainer.current_epoch
+        step = self.trainer.global_step
+        checkpoint = {
+            "model": self.model.state_dict(),
+            "ema": self.ema.state_dict(),
+        }
+        torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt")
+
    def configure_optimizers(self):
        self.lr_scheduler = get_scheduler(
            name="constant",
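
The `on_save_checkpoint` hook added above writes a plain state-dict payload (`epoch{E}-step{S}.ckpt`) alongside Lightning's own checkpoints; note that rebinding the local name `checkpoint` only affects this side file, not the dict Lightning itself serializes. A minimal loading sketch, assuming the payload layout above (the path and model objects are illustrative placeholders, not from the repo):

```python
import torch

def load_side_checkpoint(model, ema_model, path):
    """Load the payload written by on_save_checkpoint above."""
    payload = torch.load(path, map_location="cpu")
    model.load_state_dict(payload["model"])    # raw training weights
    ema_model.load_state_dict(payload["ema"])  # EMA copy, usually preferred for sampling
    ema_model.eval()
```
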
@@ -155,12 +167,20 @@ def main(args):
    seed = args.global_seed
    torch.manual_seed(seed)

-    # Setup an experiment folder and logger
-    experiment_dir, checkpoint_dir = create_experiment_directory(args)
-    logger = create_logger(experiment_dir)
+    # Determine if the current process is the main process (rank 0)
+    is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
+    # Set up the experiment folder and logger only on the main process
+    if is_main_process:
+        experiment_dir, checkpoint_dir = create_experiment_directory(args)
+        logger = create_logger(experiment_dir)
+        OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
+        logger.info(f"Experiment directory created at {experiment_dir}")
+    else:
+        experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path")
+        checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path")
+        logger = logging.getLogger(__name__)
+        logger.addHandler(logging.NullHandler())
    tb_logger = TensorBoardLogger(experiment_dir, name="latte")
-    OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
-    logger.info(f"Experiment directory created at {experiment_dir}")

    # Create the dataset and dataloader
    dataset = get_dataset(args)
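
The rank-0 gating above keys off `LOCAL_RANK`, which identifies a process within a single node, so on multi-node runs the first process of every node passes the check; non-main ranks fall back to the `EXPERIMENT_DIR`/`CHECKPOINT_DIR` environment variables. A sketch of a stricter guard using the global `RANK` that torchrun also sets (an assumption; the fallback keeps single-process runs on the main path):

```python
import os

def is_global_main_process() -> bool:
    # RANK is the global rank across all nodes (set by torchrun);
    # LOCAL_RANK is only unique within one node.
    return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) == 0
```
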
@@ -172,11 +192,21 @@ def main(args):
        pin_memory=True,
        drop_last=True
    )
-    logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")
+    if is_main_process:
+        logger.info(f"Dataset contains {len(dataset)} videos ({args.data_path})")

    sample_size = args.image_size // 8
    args.latent_size = sample_size

+    # Recalculate the total training steps, as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(loader))
+    # Afterwards, recalculate the number of training epochs.
+    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    # In multi-GPU mode, the effective batch size is local_batch_size * number of GPUs.
+    if is_main_process:
+        logger.info(f"One epoch contains {num_update_steps_per_epoch} steps")
+        logger.info(f"Num train epochs: {num_train_epochs}")
+
    # Initialize the training module
    pl_module = LatteTrainingModule(args, logger)

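The two `math.ceil` lines convert a step budget into an epoch count for the Trainer (`len(loader)` is already an integer, so the first `ceil` is a no-op; under DDP it is also the per-process batch count). A worked example with illustrative numbers, not taken from the repo:

```python
import math

max_train_steps = 1_000_000          # hypothetical step budget
num_update_steps_per_epoch = 2_500   # e.g., len(loader) for one process
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
print(num_train_epochs)  # 400
```
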
@@ -193,18 +223,19 @@ def main(args):
        accelerator="gpu",
        # devices=[3], # Specify GPU ids
        strategy="auto",
-        max_steps=args.max_train_steps,
+        max_epochs=num_train_epochs,
        logger=tb_logger,
        callbacks=[checkpoint_callback, LearningRateMonitor()],
        log_every_n_steps=args.log_every,
    )

    trainer.fit(pl_module, train_dataloaders=loader, ckpt_path=args.resume_from_checkpoint if args.resume_from_checkpoint else None)

-    pl_module.model.eval()  # important! This disables randomized embedding dropout
-    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
-    logger.info("Done!")
+    pl_module.model.eval()
+    cleanup()
+    if is_main_process:
+        logger.info("Done!")


if __name__ == "__main__":
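
One side effect of switching the Trainer from `max_steps` to a precomputed `max_epochs`: when `max_train_steps` does not divide evenly by the steps per epoch, the last epoch overshoots the step budget. A sketch that keeps both limits, reusing the script's names (Lightning stops at whichever limit is reached first; this is a suggestion, not what the commit does):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    accelerator="gpu",
    strategy="auto",
    max_epochs=num_train_epochs,
    max_steps=args.max_train_steps,  # hard step cap on top of the epoch limit
    logger=tb_logger,
    callbacks=[checkpoint_callback, LearningRateMonitor()],
    log_every_n_steps=args.log_every,
)
```
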
59 changes: 45 additions & 14 deletions train_with_img_pl.py
@@ -1,4 +1,5 @@
import os
+import math
import torch
import logging
import argparse
@@ -36,7 +37,7 @@ def __init__(self, args, logger: logging.Logger):
        requires_grad(self.ema, False)

        # Load pretrained model if specified
-        if args.pretrained and args.resume_from_checkpoint is not None:
+        if args.pretrained:
            # Load old checkpoint, only load EMA
            self._load_pretrained_parameters(args)
        self.logging.info(f"Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
@@ -94,7 +95,7 @@ def training_step(self, batch, batch_idx):
        x = rearrange(x, "(b f) c h w -> b f c h w", b=b).contiguous()

        if self.args.extras == 78: # text-to-video
-            raise ValueError('T2V training is not supported at this moment!')
+            raise ValueError("T2V training is not supported at this moment!")
        elif self.args.extras == 2:
            if self.args.dataset == "ucf101_img":
                model_kwargs = dict(y=video_name, y_image=image_names, use_image_num=self.args.use_image_num)
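
The `ucf101_img` branch passes both video labels and per-image labels, which is Latte's joint video+image training path. A shape sketch of how such a batch might be packed (an assumption for illustration; the real collation lives in the dataset code, not in this diff):

```python
import torch

B, F, N, C, H, W = 2, 16, 4, 4, 32, 32   # illustrative latent shapes
video_latents = torch.randn(B, F, C, H, W)
image_latents = torch.randn(B, N, C, H, W)

# use_image_num (= N) tells the model where the video frames end and the
# standalone images begin along the frame axis.
x = torch.cat([video_latents, image_latents], dim=1)  # (B, F + N, C, H, W)
```
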
@@ -124,6 +125,17 @@ def training_step(self, batch, batch_idx):
    def on_train_batch_end(self, *args, **kwargs):
        update_ema(self.ema, self.model)

+    def on_save_checkpoint(self, checkpoint):
+        super().on_save_checkpoint(checkpoint)
+        checkpoint_dir = self.trainer.checkpoint_callback.dirpath
+        epoch = self.trainer.current_epoch
+        step = self.trainer.global_step
+        checkpoint = {
+            "model": self.model.state_dict(),
+            "ema": self.ema.state_dict(),
+        }
+        torch.save(checkpoint, f"{checkpoint_dir}/epoch{epoch}-step{step}.ckpt")
+
    def configure_optimizers(self):
        self.lr_scheduler = get_scheduler(
            name="constant",
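
Both scripts refresh the EMA weights after every optimizer step via `update_ema` in `on_train_batch_end`. The helper itself is not shown in this diff; a typical DiT-style implementation (an assumption about its body, not a quote from the repo) looks like:

```python
import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """Blend current weights into the EMA copy: ema = decay * ema + (1 - decay) * w."""
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
```
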
@@ -165,12 +177,20 @@ def main(args):
    seed = args.global_seed
    torch.manual_seed(seed)

-    # Setup an experiment folder and logger
-    experiment_dir, checkpoint_dir = create_experiment_directory(args)
-    logger = create_logger(experiment_dir)
+    # Determine if the current process is the main process (rank 0)
+    is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
+    # Set up the experiment folder and logger only on the main process
+    if is_main_process:
+        experiment_dir, checkpoint_dir = create_experiment_directory(args)
+        logger = create_logger(experiment_dir)
+        OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
+        logger.info(f"Experiment directory created at {experiment_dir}")
+    else:
+        experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path")
+        checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path")
+        logger = logging.getLogger(__name__)
+        logger.addHandler(logging.NullHandler())
    tb_logger = TensorBoardLogger(experiment_dir, name="latte")
-    OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
-    logger.info(f"Experiment directory created at {experiment_dir}")

    # Create the dataset and dataloader
    dataset = get_dataset(args)
@@ -182,11 +202,21 @@ def main(args):
        pin_memory=True,
        drop_last=True
    )
-    logger.info(f"Dataset contains {len(dataset):,} videos ({args.data_path})")
+    if is_main_process:
+        logger.info(f"Dataset contains {len(dataset)} videos ({args.data_path})")

    sample_size = args.image_size // 8
    args.latent_size = sample_size

+    # Recalculate the total training steps, as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(loader))
+    # Afterwards, recalculate the number of training epochs.
+    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    # In multi-GPU mode, the effective batch size is local_batch_size * number of GPUs.
+    if is_main_process:
+        logger.info(f"One epoch contains {num_update_steps_per_epoch} steps")
+        logger.info(f"Num train epochs: {num_train_epochs}")
+
    # Initialize the training module
    pl_module = LatteTrainingModule(args, logger)

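Per the comment above, the per-process `local_batch_size` is multiplied by the number of training processes to get the effective global batch size. A small sketch (`WORLD_SIZE` is set by torchrun; the parameter name follows the comment and is hypothetical):

```python
import os

def effective_batch_size(local_batch_size: int) -> int:
    # WORLD_SIZE defaults to 1, covering single-GPU runs.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return local_batch_size * world_size
```
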
@@ -201,20 +231,21 @@ def main(args):
    # Trainer
    trainer = Trainer(
        accelerator="gpu",
-        # devices=[0,1], # Specify GPU ids
+        # devices=[3], # Specify GPU ids
        strategy="auto",
-        max_steps=args.max_train_steps,
+        max_epochs=num_train_epochs,
        logger=tb_logger,
        callbacks=[checkpoint_callback, LearningRateMonitor()],
        log_every_n_steps=args.log_every,
    )

    trainer.fit(pl_module, train_dataloaders=loader, ckpt_path=args.resume_from_checkpoint if args.resume_from_checkpoint else None)

-    pl_module.model.eval()  # important! This disables randomized embedding dropout
-    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
-    logger.info("Done!")
+    pl_module.model.eval()
+    cleanup()
+    if is_main_process:
+        logger.info("Done!")


if __name__ == "__main__":
