From aaf25b05612d41f440993990ac3c9d4ec3fa286b Mon Sep 17 00:00:00 2001
From: Sabian Hibbs
Date: Sun, 18 Feb 2024 21:23:52 +0000
Subject: [PATCH 1/2] Update utils.py

Added comments for code clarity and fixed style errors.
---
 stable_audio_tools/data/utils.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/stable_audio_tools/data/utils.py b/stable_audio_tools/data/utils.py
index 848012e4..8a4b3514 100644
--- a/stable_audio_tools/data/utils.py
+++ b/stable_audio_tools/data/utils.py
@@ -5,18 +5,30 @@
 from torch import nn
 from typing import Tuple
 
+
 class PadCrop(nn.Module):
+    # Initialize the PadCrop module
     def __init__(self, n_samples, randomize=True):
         super().__init__()
-        self.n_samples = n_samples
-        self.randomize = randomize
+        self.n_samples = n_samples  # Target number of samples for each signal
+        self.randomize = randomize  # Whether to randomly select the crop start position
 
+    # Process the input signal
     def __call__(self, signal):
-        n, s = signal.shape
+        n, s = signal.shape  # n: number of signals, s: original number of samples per signal
+
+        # If not randomizing, start from 0; otherwise, pick a random start position within a valid range
         start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
-        end = start + self.n_samples
+
+        end = start + self.n_samples  # Calculate the end position for cropping
+
+        # Create a zero tensor with the desired output shape
         output = signal.new_zeros([n, self.n_samples])
+
+        # Fill the output tensor with values from the input signal starting from 'start' to 'end'
+        # If the original signal is shorter than n_samples, fill as much as possible
         output[:, :min(s, self.n_samples)] = signal[:, start:end]
+
         return output
 
 class PadCrop_Normalized_T(nn.Module):
@@ -70,7 +82,7 @@ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, in
         )
 
 class PhaseFlipper(nn.Module):
-    "Randomly invert the phase of a signal"
+    """Randomly invert the phase of a signal"""
     def __init__(self, p=0.5):
         super().__init__()
         self.p = p

From b9e9bc218622f24b657a7dc942577f7e9243ce04 Mon Sep 17 00:00:00 2001
From: Sabian Hibbs
Date: Sun, 18 Feb 2024 21:39:03 +0000
Subject: [PATCH 2/2] Update code comments in musicgen.py

Added comments for code clarity following the Google Python style guide.
---
 stable_audio_tools/training/musicgen.py | 67 +++++++++++++++++++++----
 1 file changed, 56 insertions(+), 11 deletions(-)

diff --git a/stable_audio_tools/training/musicgen.py b/stable_audio_tools/training/musicgen.py
index 9893a747..79448946 100644
--- a/stable_audio_tools/training/musicgen.py
+++ b/stable_audio_tools/training/musicgen.py
@@ -19,42 +19,87 @@
 from time import time
 
 class Profiler:
+    """A simple profiler to track execution time between ticks.
+
+    Attributes:
+        ticks (list): A list where each element is a [timestamp, message] pair.
+    """
 
     def __init__(self):
-        self.ticks = [[time(), None]]
+        """Initializes the profiler with a starting tick."""
+        self.ticks = [[time(), None]]  # Initialize with current time and no message
 
     def tick(self, msg):
-        self.ticks.append([time(), msg])
+        """Records a new tick with an associated message.
+
+        Args:
+            msg (str): A message to associate with the tick.
+        """
+        self.ticks.append([time(), msg])  # Append current time and message to ticks
 
     def __repr__(self):
-        rep = 80 * "=" + "\n"
+        """Generates a string representation of the profiler's tick data.
+
+        Returns:
+            str: A formatted string detailing the elapsed time between ticks.
+ """ + rep = 80 * "=" + "\n" # Start with a separator line for i in range(1, len(self.ticks)): - msg = self.ticks[i][1] - ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + msg = self.ticks[i][1] # Message for the current tick + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] # Time elapsed since last tick + # Append message and elapsed time to representation rep += msg + f": {ellapsed*1000:.2f}ms\n" - rep += 80 * "=" + "\n\n\n" + rep += 80 * "=" + "\n\n\n" # End with a separator line return rep class MusicGenTrainingWrapper(pl.LightningModule): - def __init__(self, musicgen_model, lr = 1e-4, ema_copy=None): + """A training wrapper for the MusicGen model using PyTorch Lightning. + + Attributes: + musicgen_model (MusicGen): The music generation model to be trained. + lr (float): Learning rate for the optimizer. + ema_copy (optional): An optional model for exponential moving average. + lm (torch.nn.Module): The language model component of the music generation model. + lm_ema (EMA): EMA utility for the language model. + cfg_dropout (ClassifierFreeGuidanceDropout): Dropout module for classifier-free guidance. + """ + + def __init__(self, musicgen_model, lr=1e-4, ema_copy=None): + """Initializes the training wrapper with a MusicGen model, learning rate, and EMA copy if provided. + + Args: + musicgen_model (MusicGen): Instance of the MusicGen model. + lr (float, optional): Learning rate for the optimizer. Defaults to 1e-4. + ema_copy (optional): Initial model for creating an EMA version. Defaults to None. + """ super().__init__() - self.musicgen_model: MusicGen = musicgen_model + self.musicgen_model: MusicGen = musicgen_model # Assign the music generation model + # Disable gradients for the compression model to prevent its weights from being updated self.musicgen_model.compression_model.requires_grad_(False) - self.lm = self.musicgen_model.lm + self.lm = self.musicgen_model.lm # Short reference to the language model component + # Ensure the language model is in float32, set to training mode, and enable gradients self.lm.to(torch.float32).train().requires_grad_(True) + # Initialize the EMA for the language model with specified decay rate and update frequency self.lm_ema = EMA(self.lm, ema_model=ema_copy, beta=0.99, update_every=10) + # Setup classifier-free guidance dropout module with a 10% dropout rate self.cfg_dropout = ClassifierFreeGuidanceDropout(0.1) - self.lr = lr + self.lr = lr # Set the learning rate def configure_optimizers(self): + """Configures the optimizer for training. + + Returns: + torch.optim.Optimizer: The AdamW optimizer configured with the language model's parameters and learning rate. + """ + # Initialize the optimizer with the language model parameters, learning rate, and other hyperparameters optimizer = optim.AdamW([*self.lm.parameters()], lr=self.lr, betas=(0.9, 0.95), weight_decay=0.1) return optimizer @@ -228,4 +273,4 @@ def on_train_batch_end(self, trainer, module: MusicGenTrainingWrapper, outputs, finally: gc.collect() torch.cuda.empty_cache() - module.train() \ No newline at end of file + module.train()