Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code Clarity Comments & Style Error Fix In Utils.py #43

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions stable_audio_tools/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,30 @@
from torch import nn
from typing import Tuple


class PadCrop(nn.Module):
    """Crop or zero-pad a signal along its last axis to a fixed length.

    Signals longer than ``n_samples`` are cropped to a contiguous window;
    shorter signals are copied to the start of a zero-filled output.
    """

    def __init__(self, n_samples, randomize=True):
        """Store the target length and the crop policy.

        Args:
            n_samples: Length of the last axis of every output.
            randomize: When true, the crop window start is drawn at random;
                when false, the window always starts at index 0.
        """
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        """Return ``signal`` cropped or zero-padded to ``n_samples``.

        Args:
            signal: Tensor with at least one dimension. Only the last axis
                is cropped/padded; all leading axes are preserved.

        Returns:
            A new tensor created with ``signal.new_zeros`` whose shape is
            ``signal.shape[:-1] + (n_samples,)``.
        """
        s = signal.shape[-1]

        if self.randomize:
            # randint's upper bound is exclusive, hence the +1. max(0, ...)
            # pins the start to 0 when the signal is shorter than the target.
            start = torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
        else:
            start = 0
        end = start + self.n_samples

        output = signal.new_zeros([*signal.shape[:-1], self.n_samples])

        # When s < n_samples the slice on the right has only s columns, so
        # the destination is limited to the same width; the rest stays zero.
        output[..., :min(s, self.n_samples)] = signal[..., start:end]

        return output

class PadCrop_Normalized_T(nn.Module):
Expand Down Expand Up @@ -70,7 +82,7 @@ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, in
)

class PhaseFlipper(nn.Module):
"Randomly invert the phase of a signal"
"""Randomly invert the phase of a signal"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
Expand Down
67 changes: 56 additions & 11 deletions stable_audio_tools/training/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,87 @@
from time import time

class Profiler:
    """Collect wall-clock timestamps and report the time between them.

    Attributes:
        ticks (list): ``[timestamp, message]`` pairs in recording order. The
            first entry is created at construction with message ``None`` and
            serves only as the reference point for the first real tick.
    """

    def __init__(self):
        """Record the construction time as the reference tick."""
        self.ticks = [[time(), None]]

    def tick(self, msg):
        """Record the current time together with ``msg``.

        Args:
            msg: Label shown next to the elapsed time in ``repr()``.
        """
        self.ticks.append([time(), msg])

    def __repr__(self):
        """Return a report with one ``"<msg>: <ms>ms"`` line per tick.

        Each line shows the time elapsed since the previous tick, in
        milliseconds with two decimals, framed by 80-character ``=`` rules.
        """
        rule = 80 * "="
        lines = [rule + "\n"]
        for (prev_time, _), (cur_time, msg) in zip(self.ticks, self.ticks[1:]):
            elapsed = cur_time - prev_time
            # Format msg inside the f-string: concatenating with ``+`` raised
            # TypeError whenever a tick was recorded with a non-str message.
            lines.append(f"{msg}: {elapsed*1000:.2f}ms\n")
        lines.append(rule + "\n\n\n")
        return "".join(lines)


class MusicGenTrainingWrapper(pl.LightningModule):
def __init__(self, musicgen_model, lr = 1e-4, ema_copy=None):
"""A training wrapper for the MusicGen model using PyTorch Lightning.

Attributes:
musicgen_model (MusicGen): The music generation model to be trained.
lr (float): Learning rate for the optimizer.
ema_copy (optional): An optional model for exponential moving average.
lm (torch.nn.Module): The language model component of the music generation model.
lm_ema (EMA): EMA utility for the language model.
cfg_dropout (ClassifierFreeGuidanceDropout): Dropout module for classifier-free guidance.
"""

def __init__(self, musicgen_model, lr=1e-4, ema_copy=None):
    """Set up the wrapper around a MusicGen model.

    Args:
        musicgen_model: Object exposing ``compression_model`` and ``lm``
            attributes (annotated as ``MusicGen``).
        lr: Learning rate stored on ``self.lr``. Defaults to 1e-4.
        ema_copy: Forwarded to ``EMA`` as its ``ema_model`` argument.
            Defaults to None.
    """
    super().__init__()

    self.musicgen_model: MusicGen = musicgen_model

    # Turn off gradients for the compression model.
    self.musicgen_model.compression_model.requires_grad_(False)

    # The language model is cast to float32, switched to train mode and has
    # gradients enabled.
    self.lm = self.musicgen_model.lm
    self.lm.to(torch.float32).train().requires_grad_(True)

    # EMA wrapper around the language model, seeded with ema_copy.
    self.lm_ema = EMA(self.lm, ema_model=ema_copy, beta=0.99, update_every=10)

    # NOTE(review): 0.1 is presumably the drop probability - confirm against
    # ClassifierFreeGuidanceDropout's signature.
    self.cfg_dropout = ClassifierFreeGuidanceDropout(0.1)

    self.lr = lr

def configure_optimizers(self):
    """Build the optimizer over the language model's parameters.

    Returns:
        An ``optim.AdamW`` instance over ``self.lm.parameters()`` using
        ``self.lr``, betas ``(0.9, 0.95)`` and weight decay ``0.1``.
    """
    lm_params = list(self.lm.parameters())
    return optim.AdamW(
        lm_params,
        lr=self.lr,
        betas=(0.9, 0.95),
        weight_decay=0.1,
    )
Expand Down Expand Up @@ -228,4 +273,4 @@ def on_train_batch_end(self, trainer, module: MusicGenTrainingWrapper, outputs,
finally:
gc.collect()
torch.cuda.empty_cache()
module.train()
module.train()