Docs improvement: autoencoders.md #13

Open · wants to merge 2 commits into main
docs/autoencoders.md (4 changes: 2 additions & 2 deletions)
@@ -56,7 +56,7 @@ The `training` config in the autoencoder model config file should have the following
## Loss configs
There are a few different types of losses used for autoencoder training, including spectral losses, time-domain losses, adversarial losses, and bottleneck-specific losses.

Hyperparameters fo these losses as well as loss weighting factors can be configured in the `loss_configs` property in the `training` config.
Hyperparameters for these losses as well as loss weighting factors can be configured in the `loss_configs` property in the `training` config.

### Spectral losses
Multi-resolution STFT losses are the main reconstruction losses used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions.
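
As a rough illustration (not the repository's exact code), a weighted multi-resolution STFT loss can be wired up with auraloss as sketched below; the FFT sizes, hop sizes, and weighting factor are placeholder values, not the ones used in the shipped configs, where they would live under `loss_configs`:

```python
import torch
import auraloss

# Placeholder hyperparameters -- in practice these come from the
# `loss_configs` property of the `training` config.
mrstft = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[2048, 1024, 512],
    hop_sizes=[512, 256, 128],
    win_lengths=[2048, 1024, 512],
)

decoded = torch.randn(4, 2, 65536)  # decoder output: (batch, channels, samples)
reals = torch.randn(4, 2, 65536)    # ground-truth audio

mrstft_weight = 1.0  # placeholder loss weighting factor
recon_loss = mrstft_weight * mrstft(decoded, reals)
```
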
@@ -269,7 +269,7 @@ The Wasserstein bottleneck implements the WAE-MMD regularization method from the Wasserstein Auto-Encoders paper

The Wasserstein bottleneck also exposes the `noise_augment_dim` property, which concatenates `noise_augment_dim` channels of Gaussian noise to the latent series before it is passed into the decoder. This adds some stochasticity to the latents, which can be helpful for adversarial training, while keeping the encoder outputs deterministic.
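
As a minimal sketch of what that concatenation amounts to (the helper below is hypothetical and assumes latents shaped `(batch, channels, sequence)`; it is not the repository's implementation):

```python
import torch

def augment_with_noise(latents: torch.Tensor, noise_augment_dim: int) -> torch.Tensor:
    # latents: (batch, latent_channels, sequence_length)
    noise = torch.randn(
        latents.shape[0], noise_augment_dim, latents.shape[2],
        device=latents.device, dtype=latents.dtype,
    )
    # Concatenate the Gaussian noise channels before decoding.
    return torch.cat([latents, noise], dim=1)
```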

**Note: The MMD calculation is very VRAM-intensive for longer sequence lengths, so training a Wasserstein autoencoder is best done on autoencoders with a decent downsampling factor, or on short sequence lengths. For inference, the MMD calculation is disabled.**
**Note: The MMD calculation is highly VRAM-intensive for longer sequence lengths. So, training a Wasserstein autoencoder is best done on autoencoders with a decent downsampling factor, or on short sequence lengths. For inference, the MMD calculation is disabled.**
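
For intuition about the memory cost, here is a generic biased RBF-kernel MMD sketch (an illustrative assumption, not the repository's estimator). Treating each latent frame as a sample produces pairwise kernel matrices whose size grows quadratically with the latent sequence length, which is why a larger downsampling factor or shorter sequences keep VRAM usage manageable:

```python
import torch

def rbf_mmd(latents: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Treat every latent frame as a sample: (batch * sequence, channels).
    z = latents.transpose(1, 2).reshape(-1, latents.shape[1])
    prior = torch.randn_like(z)  # samples from a standard normal prior

    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Pairwise squared distances form an (N, N) matrix, where N includes the
        # sequence length -- this is the quadratic memory cost noted above.
        sq_dists = torch.cdist(a, b) ** 2
        return torch.exp(-sq_dists / (2 * scale ** 2))

    return kernel(z, z).mean() + kernel(prior, prior).mean() - 2 * kernel(z, prior).mean()
```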

### Example config
```json
train.py (28 changes: 21 additions & 7 deletions)
@@ -1,3 +1,4 @@
# Import libraries and dependencies.
from prefigure.prefigure import get_all_args, push_wandb_config
import json
import torch
@@ -8,10 +9,12 @@
from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config
from stable_audio_tools.training.utils import copy_state_dict

# This callback prints out exceptions that occur during training. Useful for debugging.
class ExceptionCallback(pl.Callback):
def on_exception(self, trainer, module, err):
print(f'{type(err).__name__}: {err}')

# Embeds the model's configuration into the checkpoint.
class ModelConfigEmbedderCallback(pl.Callback):
def __init__(self, model_config):
self.model_config = model_config
@@ -21,45 +24,54 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):

def main():

# Collect command-line arguments.
args = get_all_args()

# Set the random seed for reproducibility.
torch.manual_seed(args.seed)

#Get JSON config from args.model_config
# Get JSON config from args.model_config.
with open(args.model_config) as f:
model_config = json.load(f)

with open(args.dataset_config) as f:
dataset_config = json.load(f)

# Create a dataloader from specified configurations.
train_dl = create_dataloader_from_configs_and_args(model_config, args, dataset_config)

# Instantiate the model from configuration.
model = create_model_from_config(model_config)

# If a pre-trained checkpoint is provided, the model's state is updated with it.
if args.pretrained_ckpt_path:
copy_state_dict(model, torch.load(args.pretrained_ckpt_path, map_location="cpu")["state_dict"])

# If a pre-transform checkpoint is provided, only the `pretransform` layer of the model is updated.
if args.pretransform_ckpt_path:
model.pretransform.load_state_dict(torch.load(args.pretransform_ckpt_path, map_location="cpu")["state_dict"])


# Training wrapper creation.
training_wrapper = create_training_wrapper_from_config(model_config, model)

exc_callback = ExceptionCallback()
ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)
save_model_config_callback = ModelConfigEmbedderCallback(model_config)
exc_callback = ExceptionCallback() # Exception handling callback.
ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1) # Model checkpoint callback.
save_model_config_callback = ModelConfigEmbedderCallback(model_config) # Embed model configuration in the checkpoint.

# `demo_callback` is used to periodically visualize or listen to the model's output during training.
demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl)

# Weights & Biases logger setup for experiment tracking.
wandb_logger = pl.loggers.WandbLogger(project=args.name)
wandb_logger.watch(training_wrapper)

#Combine args and config dicts
# Combine args and config dicts.
args_dict = vars(args)
args_dict.update({"model_config": model_config})
args_dict.update({"dataset_config": dataset_config})
push_wandb_config(wandb_logger, args_dict)

#Set multi-GPU strategy if specified
# Set multi-GPU strategy if specified, otherwise a default strategy is chosen based on the number of GPUs available.
if args.strategy:
if args.strategy == "deepspeed":
from pytorch_lightning.strategies import DeepSpeedStrategy
@@ -76,6 +88,7 @@ def main():
else:
strategy = 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else None

# Trainer configuration.
trainer = pl.Trainer(
devices=args.num_gpus,
accelerator="gpu",
@@ -92,5 +105,6 @@

trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None)

# Run training when the script is executed directly.
if __name__ == '__main__':
main()