Fix up sampling.
lucasnewman committed Oct 6, 2024
1 parent 3b30ee0 commit 00b9af1
Showing 3 changed files with 32 additions and 19 deletions.
README.md (27 changes: 23 additions & 4 deletions)
@@ -13,15 +13,14 @@ pip install mlx-e2-tts
## Usage

```python

import mlx.core as mx

from e2_tts_mlx.model import E2TTS
from e2_tts_mlx.trainer import E2Trainer
from e2_tts_mlx.data import load_libritts_r

e2tts = E2TTS(
tokenizer="char-utf8", # or "phoneme_en" for phoneme-based tokenization
tokenizer="char-utf8", # or "phoneme_en"
cond_drop_prob = 0.25,
frac_lengths_mask = (0.7, 0.9),
transformer = dict(
@@ -40,11 +39,31 @@ mx.eval(e2tts.parameters())
batch_size = 128
max_duration = 30

dataset = load_libritts_r(split="dev-clean", max_duration = max_duration) # or any other audio/caption data set
dataset = load_libritts_r(split="dev-clean") # or any audio/caption dataset

trainer = E2Trainer(model = e2tts, num_warmup_steps = 1000)
trainer.train(train_dataset = dataset, learning_rate = 7.5e-5, batch_size = batch_size)

trainer.train(
train_dataset = ...,
learning_rate = 7.5e-5,
batch_size = batch_size
)
```

... after much training ...

```python
cond = ...
text = ...
duration = ... # from a trained DurationPredictor or otherwise

generated_mel_spec = e2tts.sample(
cond = cond,
text = text,
duration = duration,
steps = 32,
cfg_strength = 1.0, # if trained for cfg
)
```
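
The `cond`, `text`, and `duration` values are intentionally elided above. When no trained `DurationPredictor` is available, a rough duration can be estimated from the text length; the sketch below is purely illustrative (the frames-per-character ratio is an arbitrary assumption, not from the repository):

```python
# hypothetical heuristic: a fixed number of mel frames per character,
# on top of the frames already supplied by the conditioning audio
frames_per_char = 10         # arbitrary assumption; tune for your data
cond_frames = cond.shape[1]  # cond assumed to be (batch, frames, n_mels)
duration = cond_frames + frames_per_char * len(text)  # text assumed to be a string
```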

Note that the model size specified above (from the paper) is very large. See `train_example.py` for a more practically sized model you can train on your local device.
e2_tts_mlx/model.py (22 changes: 8 additions & 14 deletions)
@@ -35,7 +35,7 @@ def lens_to_mask(
length: int | None = None,
) -> mx.array: # Bool['b n']
if not exists(length):
length = t.amax()
length = t.max()

seq = mx.arange(length)
return einx.less("n, b -> b n", seq, t)
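
For readers unfamiliar with `einx`, here is a minimal pure-MLX sketch of what `lens_to_mask` computes, under the hypothetical name `lens_to_mask_plain`: a boolean mask per batch row, `True` for positions below each length.

```python
import mlx.core as mx

def lens_to_mask_plain(t: mx.array, length: int | None = None) -> mx.array:
    # t holds one sequence length per batch element
    if length is None:
        length = int(t.max().item())
    seq = mx.arange(length)
    # broadcast so that mask[b, n] is True where n < t[b]
    return seq[None, :] < t[:, None]

# lens_to_mask_plain(mx.array([2, 4]))
# -> [[True, True, False, False],
#     [True, True, True,  True]]
```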
@@ -844,7 +844,6 @@ def __init__(
self,
transformer: dict | Transformer = None,
duration_predictor: dict | DurationPredictor | None = None,
odeint_kwargs: dict = dict(atol=1e-5, rtol=1e-5, method="midpoint"),
cond_drop_prob=0.25,
num_channels=None,
mel_spec_module: nn.Module | None = None,
@@ -877,10 +876,6 @@ def __init__(

self.duration_predictor = duration_predictor

# sampling

self.odeint_kwargs = odeint_kwargs

# mel spec

self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
@@ -1016,7 +1011,6 @@ def odeint(self, func, y0, t):

return mx.stack(ys)
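
With the `odeint_kwargs` plumbing removed, the integrator is now fixed in code. The `odeint` body is mostly truncated in this diff; below is a minimal fixed-step midpoint (RK2) sketch consistent with the `mx.stack(ys)` return shown above, offered as an assumption rather than the repository's exact implementation:

```python
import mlx.core as mx

def odeint_midpoint(func, y0, t):
    # integrate dy/dt = func(t, y) over the time grid t, keeping the trajectory
    ys = [y0]
    y = y0
    for i in range(len(t) - 1):
        dt = t[i + 1] - t[i]
        k1 = func(t[i], y)                         # slope at the interval start
        k2 = func(t[i] + dt / 2, y + dt / 2 * k1)  # slope at the midpoint
        y = y + dt * k2
        ys.append(y)
    return mx.stack(ys)
```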

@mx.compile
def sample(
self,
cond: mx.array,
@@ -1049,7 +1043,7 @@ def sample(
assert text.shape[0] == batch

if exists(text):
text_lens = (text != -1).sum(dim=-1)
text_lens = (text != -1).sum(axis=-1)
lens = mx.maximum(
text_lens, lens
) # make sure lengths are at least those of the text characters
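
The `dim=-1` → `axis=-1` change reflects MLX's NumPy-style reduction API; a quick standalone illustration with made-up values:

```python
import mlx.core as mx

text = mx.array([[5, 9, 2, -1, -1],
                 [7, -1, -1, -1, -1]])  # -1 marks padding
text_lens = (text != -1).sum(axis=-1)   # MLX uses axis=, not PyTorch's dim=
# text_lens -> [3, 1]
```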
@@ -1070,15 +1064,15 @@
duration = mx.maximum(
lens + 1, duration
) # just add one token so something is generated
duration = duration.clamp(max=max_duration)
duration = mx.minimum(duration, max_duration)

assert duration.shape[0] == batch

max_duration = duration.amax()
max_duration = duration.max().item()

cond = mx.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
cond = mx.pad(cond, [(0, 0), (0, max_duration - cond_seq_len), (0, 0)], constant_values=0)
cond_mask = mx.pad(
cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
cond_mask, [(0, 0), (0, max_duration - cond_mask.shape[-1])], constant_values=False
)
cond_mask = rearrange(cond_mask, "... -> ... 1")
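
The padding fix swaps `torch.nn.functional.pad`-style flat tuples for MLX's NumPy-style per-axis pairs; a small standalone illustration of the convention:

```python
import mlx.core as mx

x = mx.zeros((2, 3, 4))  # (batch, frames, n_mels)
# one (before, after) pair per axis; this pads only the frame axis on the right
y = mx.pad(x, [(0, 0), (0, 5), (0, 0)], constant_values=0.0)
# y.shape -> (2, 8, 4)
```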

@@ -1088,7 +1082,7 @@

def fn(t, x):
# at each step, conditioning is fixed

step_cond = mx.where(cond_mask, cond, mx.zeros_like(cond))

# predict flow
@@ -1100,7 +1094,7 @@ def fn(t, x):
y0 = mx.random.normal(cond.shape)
t = mx.linspace(0, 1, steps)

trajectory = self.odeint(fn, y0, t, **self.odeint_kwargs)
trajectory = self.odeint(fn, y0, t)
sampled = trajectory[-1]

out = sampled
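
The README's `cfg_strength` parameter implies classifier-free guidance at sampling time. Here is a generic sketch of how `fn` might blend conditional and unconditional predictions, written as a drop-in closure like `fn` above; `predict_flow` is a hypothetical stand-in for the model's transformer call, not an API from this repository:

```python
def fn_cfg(t, x, cfg_strength=1.0):
    # conditioning is fixed at every step, as in fn above
    step_cond = mx.where(cond_mask, cond, mx.zeros_like(cond))

    pred = predict_flow(x, step_cond, text, t)  # conditional pass (hypothetical helper)
    if cfg_strength <= 0:
        return pred
    # unconditional pass: drop the text, mirroring cond_drop_prob at train time
    null_pred = predict_flow(x, step_cond, None, t)
    return pred + (pred - null_pred) * cfg_strength
```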
e2_tts_mlx/trainer.py (2 changes: 1 addition & 1 deletion)
@@ -159,7 +159,7 @@ def train_step(mel_spec, text_inputs, mel_lens):
log_start_date = datetime.datetime.now()

print(
f"step {global_step}: loss = {loss.item():.4f}, sec per step = {elapsed_time.seconds / log_every}"
f"step {global_step}: loss = {loss.item():.4f}, sec per step = {(elapsed_time.seconds / log_every):.2f}"
)

if exists(self.duration_predictor):
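
For reference, the added `:.2f` spec rounds the computed per-step time to two decimals; a standalone illustration with made-up values:

```python
import datetime

log_every = 10
elapsed_time = datetime.timedelta(seconds=23)
print(f"sec per step = {(elapsed_time.seconds / log_every):.2f}")
# -> sec per step = 2.30
```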