diff --git a/f5_tts_mlx/README.md b/f5_tts_mlx/README.md
index c027c0c..c83f792 100644
--- a/f5_tts_mlx/README.md
+++ b/f5_tts_mlx/README.md
@@ -40,9 +40,9 @@ Provide a caption for the reference audio.
 
 `--output`
 
-string, default: "output.wav"
+string, default: None
 
-Specify the output path where the generated audio will be saved. If not specified, the script will save the output to a default location.
+Specify the output path where the generated audio will be saved. If not specified, audio will play as it's generated.
 
 `--cfg`
 
@@ -52,13 +52,13 @@ Specifies the strength used for classifier free guidance
 
 `--method`
 
-str, default: "euler"
+str, default: "rk4"
 
-Specify the sampling method for the ODE. Options are "euler" and "midpoint".
+Specify the sampling method for the ODE. Options are "euler", "midpoint", and "rk4".
 
 `--steps`
 
-int, default: 32
+int, default: 8
 
 Specify the number of steps used to sample the neural ODE. Lower steps trade off quality for latency.
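For orientation, these new defaults mirror the `generate()` helper changed later in this patch. A minimal sketch of the equivalent Python call, assuming the package is installed:

```python
from f5_tts_mlx.generate import generate

# equivalent to the new CLI defaults: --method rk4 --steps 8, no --output
generate(
    generation_text="Hello world.",
    method="rk4",
    steps=8,
    output_path=None,  # stream playback; pass a .wav path to write a file instead
)
```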
+ """ + ys = [y0] + y_current = y0 + + for i in range(len(t) - 1): + t_current = t[i] + dt = t[i + 1] - t_current + + # rk4 steps + k1 = func(t_current, y_current) + k2 = func(t_current + 0.5 * dt, y_current + 0.5 * dt * k1) + k3 = func(t_current + 0.5 * dt, y_current + 0.5 * dt * k2) + k4 = func(t_current + dt, y_current + dt * k3) + + # compute the next value + y_next = y_current + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4) + + ys.append(y_next) + y_current = y_next + + return mx.stack(ys) + + # conditional flow matching @@ -168,57 +257,16 @@ def __call__( return loss.mean() - def odeint_midpoint(self, func, y0, t): - """ - Solves ODE using the midpoint method. - - Parameters: - - y0: Initial state, an MLX array of any shape. - - t: Array of time steps, an MLX array. - """ - ys = [y0] - y_current = y0 - - for i in range(len(t) - 1): - t_current = t[i] - dt = t[i + 1] - t_current - - # midpoint approximation - k1 = func(t_current, y_current) - mid = y_current + 0.5 * dt * k1 - - # compute the next value - k2 = func(t_current + 0.5 * dt, mid) - y_next = y_current + dt * k2 - - ys.append(y_next) - y_current = y_next - - return mx.stack(ys) - - def odeint_euler(self, func, y0, t): - """ - Solves ODE using the Euler method. - - Parameters: - - y0: Initial state, an MLX array of any shape. - - t: Array of time steps, an MLX array. - """ - ys = [y0] - y_current = y0 - - for i in range(len(t) - 1): - t_current = t[i] - dt = t[i + 1] - t_current - - # compute the next value - k = func(t_current, y_current) - y_next = y_current + dt * k - - ys.append(y_next) - y_current = y_next - - return mx.stack(ys) + def predict_duration( + self, + cond: mx.array["b n d"], + text: mx.array["b nt"], + speed: float = 1.0, + ) -> int: + duration_in_sec = self._duration_predictor(cond, text) + frame_rate = self._mel_spec.sample_rate // self._mel_spec.hop_length + duration = (duration_in_sec * frame_rate / speed).astype(mx.int32) + return duration def sample( self, @@ -227,18 +275,14 @@ def sample( duration: int | mx.array["b"] | None = None, *, lens: mx.array["b"] | None = None, - steps=32, - method: Literal["euler", "midpoint"] = "euler", + steps=8, + method: Literal["euler", "midpoint", "rk4"] = "rk4", cfg_strength=2.0, speed=1.0, sway_sampling_coef=-1.0, seed: int | None = None, max_duration=4096, - no_ref_audio=False, - edit_mask=None, ) -> tuple[mx.array, mx.array]: - start_date = datetime.now() - self.eval() # raw wave @@ -246,7 +290,6 @@ def sample( if cond.ndim == 2: cond = rearrange(cond, "1 n -> n") cond = self._mel_spec(cond) - # cond = rearrange(cond, "b d n -> b n d") assert cond.shape[-1] == self.num_channels batch, cond_seq_len, dtype = *cond.shape[:2], cond.dtype @@ -269,20 +312,13 @@ def sample( # duration if duration is None and self._duration_predictor is not None: - duration_in_sec = self._duration_predictor(cond, text) - frame_rate = self._mel_spec.sample_rate // self._mel_spec.hop_length - duration = (duration_in_sec * frame_rate / speed).astype(mx.int32).item() - print( - f"Got duration of {duration} frames ({duration_in_sec.item()} secs) for generated speech." - ) + duration = self.predict_duration(cond, text, speed) elif duration is None: raise ValueError( "Duration must be provided or a duration predictor must be set." 
@@ -168,57 +257,16 @@ def __call__(
 
         return loss.mean()
 
-    def odeint_midpoint(self, func, y0, t):
-        """
-        Solves ODE using the midpoint method.
-
-        Parameters:
-        - y0: Initial state, an MLX array of any shape.
-        - t: Array of time steps, an MLX array.
-        """
-        ys = [y0]
-        y_current = y0
-
-        for i in range(len(t) - 1):
-            t_current = t[i]
-            dt = t[i + 1] - t_current
-
-            # midpoint approximation
-            k1 = func(t_current, y_current)
-            mid = y_current + 0.5 * dt * k1
-
-            # compute the next value
-            k2 = func(t_current + 0.5 * dt, mid)
-            y_next = y_current + dt * k2
-
-            ys.append(y_next)
-            y_current = y_next
-
-        return mx.stack(ys)
-
-    def odeint_euler(self, func, y0, t):
-        """
-        Solves ODE using the Euler method.
-
-        Parameters:
-        - y0: Initial state, an MLX array of any shape.
-        - t: Array of time steps, an MLX array.
-        """
-        ys = [y0]
-        y_current = y0
-
-        for i in range(len(t) - 1):
-            t_current = t[i]
-            dt = t[i + 1] - t_current
-
-            # compute the next value
-            k = func(t_current, y_current)
-            y_next = y_current + dt * k
-
-            ys.append(y_next)
-            y_current = y_next
-
-        return mx.stack(ys)
+    def predict_duration(
+        self,
+        cond: mx.array["b n d"],
+        text: mx.array["b nt"],
+        speed: float = 1.0,
+    ) -> mx.array:
+        duration_in_sec = self._duration_predictor(cond, text)
+        frame_rate = self._mel_spec.sample_rate // self._mel_spec.hop_length
+        duration = (duration_in_sec * frame_rate / speed).astype(mx.int32)
+        return duration
 
     def sample(
@@ -227,18 +275,14 @@ def sample(
         self,
         cond: mx.array["b n d"] | mx.array["b nw"],
         text: mx.array["b nt"] | list[str],
         duration: int | mx.array["b"] | None = None,
         *,
         lens: mx.array["b"] | None = None,
-        steps=32,
-        method: Literal["euler", "midpoint"] = "euler",
+        steps=8,
+        method: Literal["euler", "midpoint", "rk4"] = "rk4",
         cfg_strength=2.0,
         speed=1.0,
         sway_sampling_coef=-1.0,
         seed: int | None = None,
         max_duration=4096,
-        no_ref_audio=False,
-        edit_mask=None,
     ) -> tuple[mx.array, mx.array]:
-        start_date = datetime.now()
-
         self.eval()
 
         # raw wave
@@ -246,7 +290,6 @@ def sample(
         if cond.ndim == 2:
             cond = rearrange(cond, "1 n -> n")
             cond = self._mel_spec(cond)
-            # cond = rearrange(cond, "b d n -> b n d")
             assert cond.shape[-1] == self.num_channels
 
         batch, cond_seq_len, dtype = *cond.shape[:2], cond.dtype
@@ -269,20 +312,13 @@ def sample(
         # duration
 
         if duration is None and self._duration_predictor is not None:
-            duration_in_sec = self._duration_predictor(cond, text)
-            frame_rate = self._mel_spec.sample_rate // self._mel_spec.hop_length
-            duration = (duration_in_sec * frame_rate / speed).astype(mx.int32).item()
-            print(
-                f"Got duration of {duration} frames ({duration_in_sec.item()} secs) for generated speech."
-            )
+            duration = self.predict_duration(cond, text, speed)
         elif duration is None:
             raise ValueError(
                 "Duration must be provided or a duration predictor must be set."
             )
 
         cond_mask = lens_to_mask(lens)
-        if edit_mask is not None:
-            cond_mask = cond_mask & edit_mask
 
         if isinstance(duration, int):
             duration = mx.full((batch,), duration, dtype=dtype)
@@ -308,10 +344,6 @@ def sample(
         else:
             mask = None
 
-        # test for no ref audio
-        if no_ref_audio:
-            cond = mx.zeros_like(cond)
-
         # neural ode
 
         def fn(t, x):
@@ -325,8 +357,8 @@ def fn(t, x):
                 drop_audio_cond=False,
                 drop_text=False,
             )
+
             if cfg_strength < 1e-5:
-                mx.eval(pred)
                 return pred
 
             null_pred = self.transformer(
@@ -339,17 +371,17 @@ def fn(t, x):
                 drop_text=True,
             )
             output = pred + (pred - null_pred) * cfg_strength
-            mx.eval(output)
             return output
 
         # noise input
-
+
         y0 = []
         for dur in duration:
            if exists(seed):
                 mx.random.seed(seed)
-            y0.append(mx.random.normal((dur, self.num_channels)))
+            y0.append(mx.random.normal((self.num_channels, dur)))
         y0 = pad_sequence(y0, padding_value=0)
+        y0 = rearrange(y0, "b d n -> b n d")
 
         t_start = 0
@@ -358,12 +390,17 @@ def fn(t, x):
             t = t + sway_sampling_coef * (mx.cos(mx.pi / 2 * t) - 1 + t)
 
         if method == "midpoint":
-            trajectory = self.odeint_midpoint(fn, y0, t)
+            ode_step_fn = odeint_midpoint
         elif method == "euler":
-            trajectory = self.odeint_euler(fn, y0, t)
+            ode_step_fn = odeint_euler
+        elif method == "rk4":
+            ode_step_fn = odeint_rk4
         else:
             raise ValueError(f"Unknown method: {method}")
 
+        fn = mx.compile(fn)
+        trajectory = ode_step_fn(fn, y0, t)
+
         sampled = trajectory[-1]
         out = sampled
         out = mx.where(cond_mask, cond, out)
@@ -371,14 +408,12 @@ def fn(t, x):
         if exists(self._vocoder):
             out = self._vocoder(out)
 
-        mx.eval(out)
-
-        print(f"Generated speech in {datetime.now() - start_date}")
-
         return out, trajectory
 
     @classmethod
-    def from_pretrained(cls, hf_model_name_or_path: str, convert_weights = False) -> F5TTS:
+    def from_pretrained(
+        cls, hf_model_name_or_path: str, convert_weights=False
+    ) -> F5TTS:
         path = fetch_from_hub(hf_model_name_or_path)
 
         if path is None:
@@ -436,40 +471,40 @@ def from_pretrained(cls, hf_model_name_or_path: str, convert_weights = False) ->
         )
 
         weights = mx.load(model_path.as_posix(), format="safetensors")
-        
+
         if convert_weights:
             new_weights = {}
             for k, v in weights.items():
-                k = k.replace('ema_model.', '')
-            
+                k = k.replace("ema_model.", "")
+
                 # rename layers
-                if len(k) < 1 or 'mel_spec.' in k or k in ('initted', 'step'):
+                if len(k) < 1 or "mel_spec." in k or k in ("initted", "step"):
                     continue
-                elif '.to_out' in k:
-                    k = k.replace('.to_out', '.to_out.layers')
-                elif '.text_blocks' in k:
-                    k = k.replace('.text_blocks', '.text_blocks.layers')
-                elif '.ff.ff.0.0' in k:
-                    k = k.replace('.ff.ff.0.0', '.ff.ff.layers.0.layers.0')
-                elif '.ff.ff.2' in k:
-                    k = k.replace('.ff.ff.2', '.ff.ff.layers.2')
-                elif '.time_mlp' in k:
-                    k = k.replace('.time_mlp', '.time_mlp.layers')
-                elif '.conv1d' in k:
-                    k = k.replace('.conv1d', '.conv1d.layers')
+                elif ".to_out" in k:
+                    k = k.replace(".to_out", ".to_out.layers")
+                elif ".text_blocks" in k:
+                    k = k.replace(".text_blocks", ".text_blocks.layers")
+                elif ".ff.ff.0.0" in k:
+                    k = k.replace(".ff.ff.0.0", ".ff.ff.layers.0.layers.0")
+                elif ".ff.ff.2" in k:
+                    k = k.replace(".ff.ff.2", ".ff.ff.layers.2")
+                elif ".time_mlp" in k:
+                    k = k.replace(".time_mlp", ".time_mlp.layers")
+                elif ".conv1d" in k:
+                    k = k.replace(".conv1d", ".conv1d.layers")
 
                 # reshape weights
-                if '.dwconv.weight' in k:
+                if ".dwconv.weight" in k:
                     v = v.swapaxes(1, 2)
-                elif '.conv1d.layers.0.weight' in k:
+                elif ".conv1d.layers.0.weight" in k:
                     v = v.swapaxes(1, 2)
-                elif '.conv1d.layers.2.weight' in k:
+                elif ".conv1d.layers.2.weight" in k:
                     v = v.swapaxes(1, 2)
-            
+
                 new_weights[k] = v
-            
+
             weights = new_weights
-        
+
         f5tts.load_weights(list(weights.items()))
         mx.eval(f5tts.parameters())
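With the solvers hoisted to module level and `predict_duration` split out, the sampling API can be driven directly. A minimal sketch, assuming a 24 kHz mono reference clip at an illustrative path; pass an explicit frame count for `duration` if the checkpoint ships without a duration predictor:

```python
import mlx.core as mx
import numpy as np
import soundfile as sf

from f5_tts_mlx.cfm import F5TTS
from f5_tts_mlx.utils import convert_char_to_pinyin

f5tts = F5TTS.from_pretrained("lucasnewman/f5-tts-mlx")

audio, _ = sf.read("ref.wav")  # illustrative path, 24 kHz mono
text = convert_char_to_pinyin(["A caption of the reference audio. And the text to generate."])

wave, _ = f5tts.sample(
    mx.expand_dims(mx.array(audio), axis=0),
    text=text,
    duration=None,  # defer to the model's duration predictor, if one is set
    steps=8,
    method="rk4",
)
wave = wave[audio.shape[0]:]  # trim the reference prefix, as generate.py does
sf.write("out.wav", np.array(wave), 24_000)
```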
diff --git a/f5_tts_mlx/generate.py b/f5_tts_mlx/generate.py
index 0e4612a..a3faadb 100644
--- a/f5_tts_mlx/generate.py
+++ b/f5_tts_mlx/generate.py
@@ -1,40 +1,124 @@
 import argparse
+from collections import deque
 import datetime
 import pkgutil
+import re
+import sys
+from threading import Event, Lock
 from typing import Literal, Optional
 
 import mlx.core as mx
-
 import numpy as np
 
 from f5_tts_mlx.cfm import F5TTS
 from f5_tts_mlx.utils import convert_char_to_pinyin
 
+import sounddevice as sd
 import soundfile as sf
+from tqdm import tqdm
+
 
 SAMPLE_RATE = 24_000
 HOP_LENGTH = 256
 FRAMES_PER_SEC = SAMPLE_RATE / HOP_LENGTH
 TARGET_RMS = 0.1
 
+# utilities
+
+
+def split_sentences(text):
+    sentence_endings = re.compile(r"([.!?;:])")
+    sentences = sentence_endings.split(text)
+    sentences = [
+        sentences[i] + sentences[i + 1] for i in range(0, len(sentences) - 1, 2)
+    ]
+    return [sentence.strip() for sentence in sentences if sentence.strip()]
+
+
+# playback
+
+
+class AudioPlayer:
+    def __init__(self, sample_rate=24000, buffer_size=2048):
+        self.sample_rate = sample_rate
+        self.buffer_size = buffer_size
+        self.audio_buffer = deque()
+        self.buffer_lock = Lock()
+        self.playing = False
+        self.drain_event = Event()
+
+    def callback(self, outdata, frames, time, status):
+        with self.buffer_lock:
+            if len(self.audio_buffer) > 0:
+                available = min(frames, len(self.audio_buffer[0]))
+                chunk = self.audio_buffer[0][:available].copy()
+                self.audio_buffer[0] = self.audio_buffer[0][available:]
+
+                if len(self.audio_buffer[0]) == 0:
+                    self.audio_buffer.popleft()
+                    if len(self.audio_buffer) == 0:
+                        self.drain_event.set()
+
+                outdata[:, 0] = np.zeros(frames)
+                outdata[:available, 0] = chunk
+            else:
+                outdata[:, 0] = np.zeros(frames)
+                self.drain_event.set()
+
+    def play(self):
+        if not self.playing:
+            self.stream = sd.OutputStream(
+                samplerate=self.sample_rate,
+                channels=1,
+                callback=self.callback,
+                blocksize=self.buffer_size,
+            )
+            self.stream.start()
+            self.playing = True
+            self.drain_event.clear()
+
+    def queue_audio(self, samples):
+        with self.buffer_lock:
+            self.audio_buffer.append(np.array(samples))
+
+        if not self.playing:
+            self.play()
+
+    def wait_for_drain(self):
+        return self.drain_event.wait()
+
+    def stop(self):
+        if self.playing:
+            self.wait_for_drain()
+            sd.sleep(100)
+
+            self.stream.stop()
+            self.stream.close()
+            self.playing = False
+
+
+# generation
+
+
 def generate(
     generation_text: str,
     duration: Optional[float] = None,
     model_name: str = "lucasnewman/f5-tts-mlx",
     ref_audio_path: Optional[str] = None,
     ref_audio_text: Optional[str] = None,
-    steps: int = 32,
-    method: Literal["euler", "midpoint"] = "euler",
+    steps: int = 8,
+    method: Literal["euler", "midpoint", "rk4"] = "rk4",
     cfg_strength: float = 2.0,
     sway_sampling_coef: float = -1.0,
     speed: float = 1.0,  # used when duration is None as part of the duration heuristic
     seed: Optional[int] = None,
-    output_path: str = "output.wav",
+    output_path: Optional[str] = None,
 ):
+    player = AudioPlayer(sample_rate=SAMPLE_RATE) if output_path is None else None
+
     # the default model already has converted weights
     convert_weights = model_name != "lucasnewman/f5-tts-mlx"
-    
+
     f5tts = F5TTS.from_pretrained(model_name, convert_weights=convert_weights)
 
     if ref_audio_path is None:
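A quick illustration of the two new helpers, with hypothetical inputs (not part of the patch): `split_sentences` chunks text on terminal punctuation, and `AudioPlayer` accepts float sample buffers, with `stop()` blocking until playback drains:

```python
import numpy as np

from f5_tts_mlx.generate import AudioPlayer, split_sentences

print(split_sentences("Hello there! How are you today? This is a test."))
# ['Hello there!', 'How are you today?', 'This is a test.']

# queue one second of silence and block until the buffer drains
player = AudioPlayer(sample_rate=24_000)
player.queue_audio(np.zeros(24_000, dtype=np.float32))
player.stop()
```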
@@ -61,35 +145,91 @@ def generate(
     rms = mx.sqrt(mx.mean(mx.square(audio)))
     if rms < TARGET_RMS:
         audio = audio * TARGET_RMS / rms
+
+    sentences = split_sentences(generation_text)
+    is_single_generation = len(sentences) == 1 or duration is not None
 
-    # generate the audio for the given text
-    text = convert_char_to_pinyin([ref_audio_text + " " + generation_text])
+    if is_single_generation:
+        generation_text = convert_char_to_pinyin(
+            [ref_audio_text + " " + generation_text]
+        )
 
-    start_date = datetime.datetime.now()
+        if duration is not None:
+            duration = int(duration * FRAMES_PER_SEC)
 
-    if duration is not None:
-        duration = int(duration * FRAMES_PER_SEC)
+        start_date = datetime.datetime.now()
 
-    wave, _ = f5tts.sample(
-        mx.expand_dims(audio, axis=0),
-        text=text,
-        duration=duration,
-        steps=steps,
-        method=method,
-        speed=speed,
-        cfg_strength=cfg_strength,
-        sway_sampling_coef=sway_sampling_coef,
-        seed=seed,
-    )
+        wave, _ = f5tts.sample(
+            mx.expand_dims(audio, axis=0),
+            text=generation_text,
+            duration=duration,
+            steps=steps,
+            method=method,
+            speed=speed,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            seed=seed,
+        )
+
+        wave = wave[audio.shape[0] :]
+        mx.eval(wave)
+
+        generated_duration = wave.shape[0] / SAMPLE_RATE
+        print(
+            f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+        )
+
+        if player is not None:
+            player.queue_audio(wave)
+
+        if output_path is not None:
+            sf.write(output_path, np.array(wave), SAMPLE_RATE)
+
+        if player is not None:
+            player.stop()
+    else:
+        start_date = datetime.datetime.now()
+
+        output = []
+
+        for sentence_text in tqdm(sentences):
+            text = convert_char_to_pinyin([ref_audio_text + " " + sentence_text])
 
-    # trim the reference audio
-    wave = wave[audio.shape[0] :]
-    generated_duration = wave.shape[0] / SAMPLE_RATE
-    elapsed_time = datetime.datetime.now() - start_date
+            if duration is not None:
+                duration = int(duration * FRAMES_PER_SEC)
 
-    print(f"Generated {generated_duration:.2f} seconds of audio in {elapsed_time}.")
+            wave, _ = f5tts.sample(
+                mx.expand_dims(audio, axis=0),
+                text=text,
+                duration=duration,
+                steps=steps,
+                method=method,
+                speed=speed,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+                seed=seed,
+            )
 
-    sf.write(output_path, np.array(wave), SAMPLE_RATE)
+            # trim the reference audio
+            wave = wave[audio.shape[0] :]
+            mx.eval(wave)
+
+            output.append(wave)
+
+            if player is not None:
+                player.queue_audio(wave)
+
+        wave = mx.concatenate(output, axis=0)
+
+        generated_duration = wave.shape[0] / SAMPLE_RATE
+        print(
+            f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+        )
+
+        if output_path is not None:
+            sf.write(output_path, np.array(wave), SAMPLE_RATE)
+
+        if player is not None:
+            player.stop()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
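Both branches convert a requested duration in seconds into a mel-frame count before calling `sample()`. A worked example of that arithmetic, using the constants defined at the top of the file:

```python
SAMPLE_RATE = 24_000
HOP_LENGTH = 256
FRAMES_PER_SEC = SAMPLE_RATE / HOP_LENGTH  # 93.75 mel frames per second

# e.g. --duration 5.0 becomes the frame count handed to f5tts.sample()
duration_frames = int(5.0 * FRAMES_PER_SEC)  # 468
```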
@@ -104,7 +244,10 @@ def generate(
         help="Name of the model to use",
     )
     parser.add_argument(
-        "--text", type=str, required=True, help="Text to generate speech from"
+        "--text",
+        type=str,
+        default=None,
+        help="Text to generate speech from (leave blank to input via stdin)",
     )
     parser.add_argument(
         "--duration",
@@ -127,20 +270,20 @@ def generate(
     parser.add_argument(
         "--output",
         type=str,
-        default="output.wav",
+        default=None,
        help="Path to save the generated audio output",
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=32,
+        default=8,
         help="Number of steps to take when sampling the neural ODE",
     )
     parser.add_argument(
         "--method",
         type=str,
-        default="euler",
-        choices=["euler", "midpoint"],
+        default="rk4",
+        choices=["euler", "midpoint", "rk4"],
         help="Method to use for sampling the neural ODE",
     )
     parser.add_argument(
@@ -170,6 +313,13 @@ def generate(
 
     args = parser.parse_args()
 
+    if args.text is None:
+        if not sys.stdin.isatty():
+            args.text = sys.stdin.read().strip()
+        else:
+            print("Please enter the text to generate:")
+            args.text = input("> ").strip()
+
     generate(
         generation_text=args.text,
         duration=args.duration,
diff --git a/f5_tts_mlx/utils.py b/f5_tts_mlx/utils.py
index 02014e1..50c1154 100644
--- a/f5_tts_mlx/utils.py
+++ b/f5_tts_mlx/utils.py
@@ -20,6 +20,7 @@
 import jieba
 from pypinyin import lazy_pinyin, Style
 
+jieba.setLogLevel(20)  # 20 == logging.INFO; silences jieba's startup debug output
 
 def exists(v):
     return v is not None
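Since `--text` now falls back to stdin, the CLI composes with pipes. A hypothetical invocation from Python, assuming the package is installed in the active environment; stdin is not a tty here, so the script reads the prompt from the pipe automatically:

```python
import subprocess

subprocess.run(
    ["python", "-m", "f5_tts_mlx.generate", "--method", "rk4", "--steps", "8"],
    input="Hello from a pipe.",
    text=True,
    check=True,
)
```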