From 00b9af14a978da47e6ce2ffab8e2636517fb4bcb Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Sun, 6 Oct 2024 14:06:29 -0700 Subject: [PATCH] Fix up sampling. --- README.md | 27 +++++++++++++++++++++++---- e2_tts_mlx/model.py | 22 ++++++++-------------- e2_tts_mlx/trainer.py | 2 +- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 3b0bdb6..88e927a 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,6 @@ pip install mlx-e2-tts ## Usage ```python - import mlx.core as mx from e2_tts_mlx.model import E2TTS @@ -21,7 +20,7 @@ from e2_tts_mlx.trainer import E2Trainer from e2_tts_mlx.data import load_libritts_r e2tts = E2TTS( - tokenizer="char-utf8", # or "phoneme_en" for phoneme-based tokenization + tokenizer="char-utf8", # or "phoneme_en" cond_drop_prob = 0.25, frac_lengths_mask = (0.7, 0.9), transformer = dict( @@ -40,11 +39,31 @@ mx.eval(e2tts.parameters()) batch_size = 128 max_duration = 30 -dataset = load_libritts_r(split="dev-clean", max_duration = max_duration) # or any other audio/caption data set +dataset = load_libritts_r(split="dev-clean") # or any audio/caption dataset trainer = E2Trainer(model = e2tts, num_warmup_steps = 1000) -trainer.train(train_dataset = dataset, learning_rate = 7.5e-5, batch_size = batch_size) +trainer.train( + train_dataset = ..., + learning_rate = 7.5e-5, + batch_size = batch_size +) +``` + +... after much training ... + +```python +cond = ... +text = ... +duration = ... # from a trained DurationPredictor or otherwise + +generated_mel_spec = e2tts.sample( + cond = cond, + text = text, + duration = duration, + steps = 32, + cfg_strength = 1.0, # if trained for cfg +) ``` Note the model size specified above (from the paper) is very large. See `train_example.py` for a more practical-sized model you can train on your local device. diff --git a/e2_tts_mlx/model.py b/e2_tts_mlx/model.py index a4c01cf..0a93394 100644 --- a/e2_tts_mlx/model.py +++ b/e2_tts_mlx/model.py @@ -35,7 +35,7 @@ def lens_to_mask( length: int | None = None, ) -> mx.array: # Bool['b n'] if not exists(length): - length = t.amax() + length = t.max() seq = mx.arange(length) return einx.less("n, b -> b n", seq, t) @@ -844,7 +844,6 @@ def __init__( self, transformer: dict | Transformer = None, duration_predictor: dict | DurationPredictor | None = None, - odeint_kwargs: dict = dict(atol=1e-5, rtol=1e-5, method="midpoint"), cond_drop_prob=0.25, num_channels=None, mel_spec_module: nn.Module | None = None, @@ -877,10 +876,6 @@ def __init__( self.duration_predictor = duration_predictor - # sampling - - self.odeint_kwargs = odeint_kwargs - # mel spec self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) @@ -1016,7 +1011,6 @@ def odeint(self, func, y0, t): return mx.stack(ys) - @mx.compile def sample( self, cond: mx.array, @@ -1049,7 +1043,7 @@ def sample( assert text.shape[0] == batch if exists(text): - text_lens = (text != -1).sum(dim=-1) + text_lens = (text != -1).sum(axis=-1) lens = mx.maximum( text_lens, lens ) # make sure lengths are at least those of the text characters @@ -1070,15 +1064,15 @@ def sample( duration = mx.maximum( lens + 1, duration ) # just add one token so something is generated - duration = duration.clamp(max=max_duration) + duration = mx.minimum(duration, max_duration) assert duration.shape[0] == batch - max_duration = duration.amax() + max_duration = duration.max().item() - cond = mx.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + cond = mx.pad(cond, [(0, 0), (0, max_duration - cond_seq_len), (0, 0)], constant_values=0) cond_mask = mx.pad( - cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False + cond_mask, [(0, 0), (0, max_duration - cond_mask.shape[-1])], constant_values=False ) cond_mask = rearrange(cond_mask, "... -> ... 1") @@ -1088,7 +1082,7 @@ def sample( def fn(t, x): # at each step, conditioning is fixed - + step_cond = mx.where(cond_mask, cond, mx.zeros_like(cond)) # predict flow @@ -1100,7 +1094,7 @@ def fn(t, x): y0 = mx.random.normal(cond.shape) t = mx.linspace(0, 1, steps) - trajectory = self.odeint(fn, y0, t, **self.odeint_kwargs) + trajectory = self.odeint(fn, y0, t) sampled = trajectory[-1] out = sampled diff --git a/e2_tts_mlx/trainer.py b/e2_tts_mlx/trainer.py index b0cc6ee..f26980c 100644 --- a/e2_tts_mlx/trainer.py +++ b/e2_tts_mlx/trainer.py @@ -159,7 +159,7 @@ def train_step(mel_spec, text_inputs, mel_lens): log_start_date = datetime.datetime.now() print( - f"step {global_step}: loss = {loss.item():.4f}, sec per step = {elapsed_time.seconds / log_every}" + f"step {global_step}: loss = {loss.item():.4f}, sec per step = {(elapsed_time.seconds / log_every):.2f}" ) if exists(self.duration_predictor):