From 78e6c08ff9c8193a375a3d762999212bcad30765 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Wed, 30 Oct 2024 20:38:05 -0700 Subject: [PATCH] Fix long text generation. --- f5_tts_mlx/duration.py | 4 ++++ f5_tts_mlx/utils.py | 6 +----- pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/f5_tts_mlx/duration.py b/f5_tts_mlx/duration.py index eafddd7..77b2610 100644 --- a/f5_tts_mlx/duration.py +++ b/f5_tts_mlx/duration.py @@ -183,6 +183,10 @@ def __call__( text = list_str_to_tensor(text) assert text.shape[0] == batch + if seq_len < text.shape[1]: + seq_len = text.shape[1] + inp = mx.pad(inp, [(0, 0), (0, seq_len - inp.shape[1]), (0, 0)]) + # lens and mask if not exists(lens): lens = mx.full((batch,), seq_len) diff --git a/f5_tts_mlx/utils.py b/f5_tts_mlx/utils.py index 263e7f0..dbb7554 100644 --- a/f5_tts_mlx/utils.py +++ b/f5_tts_mlx/utils.py @@ -89,7 +89,7 @@ def maybe_masked_mean(t: mx.array, mask: mx.array | None = None) -> mx.array: return einx.divide("b d, b -> b d", num, mx.maximum(den, 1)) -def pad_to_length(t: mx.array, length: int, value=None): +def pad_to_length(t: mx.array, length: int, value=0): ndim = t.ndim seq_len = t.shape[-1] if length > seq_len: @@ -97,10 +97,6 @@ def pad_to_length(t: mx.array, length: int, value=None): t = mx.pad(t, [(0, length - seq_len)], constant_values=value) elif ndim == 2: t = mx.pad(t, [(0, 0), (0, length - seq_len)], constant_values=value) - elif ndim == 3: - t = mx.pad( - t, [(0, 0), (0, length - seq_len), (0, 0)], constant_values=value - ) else: raise ValueError(f"Unsupported padding dims: {ndim}") return t[..., :length] diff --git a/pyproject.toml b/pyproject.toml index 52bd239..163bbb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts-mlx" -version = "0.1.5" +version = "0.1.6" authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}] license = {text = "MIT"} description = "F5-TTS - MLX"