diff --git a/docs/autoencoders.md b/docs/autoencoders.md
index ad413e95..2fea72e9 100644
--- a/docs/autoencoders.md
+++ b/docs/autoencoders.md
@@ -56,7 +56,7 @@ The `training` config in the autoencoder model config file should have the follo
 ## Loss configs
 There are few different types of losses that are used for autoencoder training, including spectral losses, time-domain losses, adversarial losses, and bottleneck-specific losses.
 
-Hyperparameters fo these losses as well as loss weighting factors can be configured in the `loss_configs` property in the `training` config.
+Hyperparameters for these losses as well as loss weighting factors can be configured in the `loss_configs` property in the `training` config.
 
 ### Spectral losses
 Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions.
@@ -269,7 +269,7 @@ The Wasserstein bottleneck implements the WAE-MMD regularization method from the
 
 The Wasserstein bottleneck also exposes the `noise_augment_dim` property, which concatenates `noise_augment_dim` channels of Gaussian noise to the latent series before passing into the decoder. This adds some stochasticity to the latents which can be helpful for adversarial training, while keeping the encoder outputs deterministic.
 
-**Note: The MMD calculation is very VRAM-intensive for longer sequence lengths, so training a Wasserstein autoencoder is best done on autoencoders with a decent downsampling factor, or on short sequence lengths. For inference, the MMD calculation is disabled.**
+**Note: The MMD calculation is highly VRAM-intensive for longer sequence lengths, so training a Wasserstein autoencoder is best done on autoencoders with a decent downsampling factor, or on short sequence lengths. For inference, the MMD calculation is disabled.**
 
 ### Example config
 ```json
diff --git a/train.py b/train.py
index 22baee42..e07b7278 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,4 @@
+# Import libraries and dependencies.
 from prefigure.prefigure import get_all_args, push_wandb_config
 import json
 import torch
@@ -8,10 +9,12 @@
 from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config
 from stable_audio_tools.training.utils import copy_state_dict
 
+# This callback prints out exceptions that occur during training. Useful for debugging.
 class ExceptionCallback(pl.Callback):
     def on_exception(self, trainer, module, err):
         print(f'{type(err).__name__}: {err}')
 
+# Embeds the model's configuration into the checkpoint.
 class ModelConfigEmbedderCallback(pl.Callback):
     def __init__(self, model_config):
         self.model_config = model_config
@@ -21,45 +24,54 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
 
 def main():
 
+    # Collect command-line arguments.
     args = get_all_args()
 
+    # Set the random seed for reproducibility.
     torch.manual_seed(args.seed)
 
-    #Get JSON config from args.model_config
+    # Get JSON config from args.model_config.
     with open(args.model_config) as f:
         model_config = json.load(f)
 
     with open(args.dataset_config) as f:
         dataset_config = json.load(f)
 
+    # Create a dataloader from the specified configurations.
     train_dl = create_dataloader_from_configs_and_args(model_config, args, dataset_config)
 
+    # Instantiate the model from the model config.
     model = create_model_from_config(model_config)
 
+    # If a pre-trained checkpoint is provided, load its state dict into the model.
     if args.pretrained_ckpt_path:
         copy_state_dict(model, torch.load(args.pretrained_ckpt_path, map_location="cpu")["state_dict"])
 
+    # If a pretransform checkpoint is provided, update only the model's `pretransform` module.
     if args.pretransform_ckpt_path:
         model.pretransform.load_state_dict(torch.load(args.pretransform_ckpt_path, map_location="cpu")["state_dict"])
-    
+
+    # Create the training wrapper.
     training_wrapper = create_training_wrapper_from_config(model_config, model)
 
-    exc_callback = ExceptionCallback()
-    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)
-    save_model_config_callback = ModelConfigEmbedderCallback(model_config)
+    exc_callback = ExceptionCallback()  # Prints exceptions raised during training.
+    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)  # Save a checkpoint every `checkpoint_every` steps, keeping all checkpoints.
+    save_model_config_callback = ModelConfigEmbedderCallback(model_config)  # Embed model configuration in the checkpoint.
+    # `demo_callback` is used to periodically visualize or listen to the model's output during training.
     demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl)
 
+    # Set up the Weights & Biases logger for experiment tracking.
    wandb_logger = pl.loggers.WandbLogger(project=args.name)
     wandb_logger.watch(training_wrapper)
 
-    #Combine args and config dicts
+    # Combine args and config dicts.
     args_dict = vars(args)
     args_dict.update({"model_config": model_config})
     args_dict.update({"dataset_config": dataset_config})
     push_wandb_config(wandb_logger, args_dict)
 
-    #Set multi-GPU strategy if specified
+    # Set the multi-GPU strategy if specified; otherwise choose a default based on the number of GPUs available.
     if args.strategy:
         if args.strategy == "deepspeed":
             from pytorch_lightning.strategies import DeepSpeedStrategy
@@ -76,6 +88,7 @@ def main():
     else:
         strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else None
 
+    # Configure the trainer.
     trainer = pl.Trainer(
         devices=args.num_gpus,
         accelerator="gpu",
@@ -92,5 +105,6 @@ def main():
 
     trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None)
 
+# Run training when executed as a script.
 if __name__ == '__main__':
     main()
\ No newline at end of file
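
For context on the spectral losses mentioned in the `docs/autoencoders.md` hunk above, the sketch below shows how a multi-resolution STFT loss from auraloss can be evaluated. The FFT sizes, hop sizes, window lengths, and tensor shapes here are illustrative assumptions, not values taken from this repository's `loss_configs`.

```python
import torch
import auraloss

# Illustrative resolutions; real values would come from the `loss_configs`
# section of the model config, not from this sketch.
mrstft = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[2048, 1024, 512],
    hop_sizes=[512, 256, 128],
    win_lengths=[2048, 1024, 512],
)

reconstruction = torch.randn(4, 2, 65536)  # (batch, channels, samples)
target = torch.randn(4, 2, 65536)

# Sum of STFT losses computed at each resolution.
spectral_loss = mrstft(reconstruction, target)
```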
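The VRAM note in the same hunk refers to the pairwise kernel matrices used by the WAE-MMD penalty. The following is a generic RBF-kernel MMD sketch between encoder latents and a Gaussian prior, not the repository's implementation; the shape handling and kernel bandwidth are assumptions made for illustration.

```python
import torch

def rbf_kernel(x, y, sigma=1.0):
    # x: (n, d), y: (m, d) -> (n, m) Gaussian kernel matrix.
    sq_dists = torch.cdist(x, y).pow(2)
    return torch.exp(-sq_dists / (2 * sigma ** 2))

def mmd_penalty(latents, sigma=1.0):
    # latents: (batch, channels, length); treat every timestep as one sample.
    batch, channels, length = latents.shape
    z = latents.permute(0, 2, 1).reshape(-1, channels)  # (batch * length, channels)
    prior = torch.randn_like(z)  # samples from the Gaussian prior

    # Each kernel matrix is (batch * length) x (batch * length), which is why
    # long latent sequences are so VRAM-hungry.
    k_zz = rbf_kernel(z, z, sigma)
    k_pp = rbf_kernel(prior, prior, sigma)
    k_zp = rbf_kernel(z, prior, sigma)
    return k_zz.mean() + k_pp.mean() - 2.0 * k_zp.mean()
```

A larger downsampling factor shrinks `length`, which shrinks these matrices quadratically; that is the reason the note recommends a decent downsampling factor or short sequence lengths for Wasserstein autoencoder training.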