diff --git a/demo/app.py b/demo/app.py
index f3f904c..1684521 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -5,13 +5,13 @@
 import soundfile as sf
 import streamlit as st
 
+from document_to_podcast.inference.text_to_speech import text_to_speech
 from document_to_podcast.preprocessing import DATA_LOADERS, DATA_CLEANERS
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
-    load_outetts_model,
+    load_tts_model,
 )
 from document_to_podcast.config import DEFAULT_PROMPT, DEFAULT_SPEAKERS, Speaker
-from document_to_podcast.inference.text_to_speech import text_to_speech
 from document_to_podcast.inference.text_to_text import text_to_text_stream
 
@@ -24,7 +24,7 @@ def load_text_to_text_model():
 
 @st.cache_resource
 def load_text_to_speech_model():
-    return load_outetts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
+    return load_tts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
 
 
 script = "script"
@@ -153,7 +153,7 @@ def gen_button_clicked():
                 speech_model,
                 voice_profile,
             )
-            st.audio(speech, sample_rate=speech_model.audio_codec.sr)
+            st.audio(speech, sample_rate=speech_model.sample_rate)
 
             st.session_state.audio.append(speech)
             text = ""
@@ -164,7 +164,7 @@ def gen_button_clicked():
         sf.write(
             "podcast.wav",
             st.session_state.audio,
-            samplerate=speech_model.audio_codec.sr,
+            samplerate=speech_model.sample_rate,
         )
         st.markdown("Podcast saved to disk!")
diff --git a/demo/download_models.py b/demo/download_models.py
index ed85845..a557f98 100644
--- a/demo/download_models.py
+++ b/demo/download_models.py
@@ -4,10 +4,10 @@
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
-    load_outetts_model,
+    load_tts_model,
 )
 
 load_llama_cpp_model(
     "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
 )
-load_outetts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
+load_tts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
diff --git a/example_data/config_bark.yaml b/example_data/config_bark.yaml
new file mode 100644
index 0000000..77712c8
--- /dev/null
+++ b/example_data/config_bark.yaml
@@ -0,0 +1,31 @@
+input_file: "example_data/a.md"
+output_folder: "example_data/bark/"
+text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
+text_to_speech_model: "suno/bark"
+text_to_text_prompt: |
+  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format.
+  The script features the following speakers:
+  {SPEAKERS}
+  Instructions:
+  - Write dynamic, easy-to-follow dialogue.
+  - Include natural interruptions and interjections.
+  - Avoid repetitive phrasing between speakers.
+  - Format output as a JSON conversation.
+  Example:
+  {
+    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
+    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
+    "Speaker 1": "Sure! Imagine it like this...",
+    "Speaker 2": "Oh, that's cool! But how does..."
+  }
+speakers:
+  - id: 1
+    name: Laura
+    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
+    voice_profile: "v2/en_speaker_0"
+
+  - id: 2
+    name: Daniel
+    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
+    voice_profile: "v2/en_speaker_1"
+outetts_language: "en" # Only applicable to OuteTTS models; ignored for Bark.
\ No newline at end of file
diff --git a/example_data/config_parler.yaml b/example_data/config_parler.yaml
new file mode 100644
index 0000000..e8240e1
--- /dev/null
+++ b/example_data/config_parler.yaml
@@ -0,0 +1,31 @@
+input_file: "example_data/a.md"
+output_folder: "example_data/parler/"
+text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
+text_to_speech_model: "parler-tts/parler-tts-mini-v1.1"
+text_to_text_prompt: |
+  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format.
+  The script features the following speakers:
+  {SPEAKERS}
+  Instructions:
+  - Write dynamic, easy-to-follow dialogue.
+  - Include natural interruptions and interjections.
+  - Avoid repetitive phrasing between speakers.
+  - Format output as a JSON conversation.
+  Example:
+  {
+    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
+    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
+    "Speaker 1": "Sure! Imagine it like this...",
+    "Speaker 2": "Oh, that's cool! But how does..."
+  }
+speakers:
+  - id: 1
+    name: Laura
+    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
+    voice_profile: Laura's voice is calm and slow in delivery, with no background noise.
+
+  - id: 2
+    name: Daniel
+    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
+    voice_profile: Daniel's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.
+outetts_language: "en" # Only applicable to OuteTTS models; ignored for Parler.
\ No newline at end of file
diff --git a/example_data/config_parler_multi.yaml b/example_data/config_parler_multi.yaml
new file mode 100644
index 0000000..656ca90
--- /dev/null
+++ b/example_data/config_parler_multi.yaml
@@ -0,0 +1,31 @@
+input_file: "example_data/a.md"
+output_folder: "example_data/parler_multi/"
+text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
+text_to_speech_model: "parler-tts/parler-tts-mini-multilingual-v1.1"
+text_to_text_prompt: |
+  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format.
+  The script features the following speakers:
+  {SPEAKERS}
+  Instructions:
+  - Write dynamic, easy-to-follow dialogue.
+  - Include natural interruptions and interjections.
+  - Avoid repetitive phrasing between speakers.
+  - Format output as a JSON conversation.
+  Example:
+  {
+    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
+    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
+    "Speaker 1": "Sure! Imagine it like this...",
+    "Speaker 2": "Oh, that's cool! But how does..."
+  }
+speakers:
+  - id: 1
+    name: Laura
+    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
+    voice_profile: Laura's voice is calm and slow in delivery, with no background noise.
+
+  - id: 2
+    name: Daniel
+    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
+    voice_profile: Daniel's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.
+outetts_language: "en" # Supported languages in version 0.2-500M: en, zh, ja, ko. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a6b8d92..986a172 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "pydantic", "PyPDF2[crypto]", "python-docx", + "transformers>4.31.0", "streamlit", ] diff --git a/src/document_to_podcast/cli.py b/src/document_to_podcast/cli.py index f71e1cb..f75f6e7 100644 --- a/src/document_to_podcast/cli.py +++ b/src/document_to_podcast/cli.py @@ -12,15 +12,14 @@ Speaker, DEFAULT_PROMPT, DEFAULT_SPEAKERS, - SUPPORTED_TTS_MODELS, + TTS_LOADERS, ) from document_to_podcast.inference.model_loaders import ( load_llama_cpp_model, - load_outetts_model, - load_parler_tts_model_and_tokenizer, + load_tts_model, ) -from document_to_podcast.inference.text_to_text import text_to_text_stream from document_to_podcast.inference.text_to_speech import text_to_speech +from document_to_podcast.inference.text_to_text import text_to_text_stream from document_to_podcast.preprocessing import DATA_CLEANERS, DATA_LOADERS @@ -30,8 +29,9 @@ def document_to_podcast( output_folder: str | None = None, text_to_text_model: str = "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf", text_to_text_prompt: str = DEFAULT_PROMPT, - text_to_speech_model: SUPPORTED_TTS_MODELS = "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf", + text_to_speech_model: TTS_LOADERS = "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf", speakers: list[Speaker] | None = None, + outetts_language: str = "en", # Only applicable to OuteTTS models from_config: str | None = None, ): """ @@ -70,8 +70,10 @@ def document_to_podcast( speakers (list[Speaker] | None, optional): The speakers for the podcast. Defaults to DEFAULT_SPEAKERS. - from_config (str, optional): The path to the config file. Defaults to None. + outetts_language (str): For OuteTTS models we need to specify which language to use. + Supported languages in 0.2-500M: en, zh, ja, ko. More info: https://github.com/edwko/OuteTTS + from_config (str, optional): The path to the config file. Defaults to None. If provided, all other arguments will be ignored. """ @@ -86,6 +88,7 @@ def document_to_podcast( text_to_text_prompt=text_to_text_prompt, text_to_speech_model=text_to_speech_model, speakers=[Speaker.model_validate(speaker) for speaker in speakers], + outetts_language=outetts_language, ) output_folder = Path(config.output_folder) @@ -106,15 +109,9 @@ def document_to_podcast( text_model = load_llama_cpp_model(model_id=config.text_to_text_model) logger.info(f"Loading {config.text_to_speech_model}") - if "oute" in config.text_to_speech_model.lower(): - speech_model = load_outetts_model(model_id=config.text_to_speech_model) - speech_tokenizer = None - sample_rate = speech_model.audio_codec.sr - else: - speech_model, speech_tokenizer = load_parler_tts_model_and_tokenizer( - model_id=config.text_to_speech_model - ) - sample_rate = speech_model.config.sampling_rate + speech_model = load_tts_model( + model_id=config.text_to_speech_model, outetts_language=outetts_language + ) # ~4 characters per token is considered a reasonable default. 
     max_characters = text_model.n_ctx() * 4
@@ -133,33 +130,34 @@
     system_prompt = system_prompt.replace(
         "{SPEAKERS}", "\n".join(str(speaker) for speaker in config.speakers)
     )
-    for chunk in text_to_text_stream(
-        clean_text, text_model, system_prompt=system_prompt
-    ):
-        text += chunk
-        podcast_script += chunk
-        if text.endswith("\n") and "Speaker" in text:
-            logger.debug(text)
-            speaker_id = re.search(r"Speaker (\d+)", text).group(1)
-            voice_profile = next(
-                speaker.voice_profile
-                for speaker in config.speakers
-                if speaker.id == int(speaker_id)
-            )
-            speech = text_to_speech(
-                text.split(f'"Speaker {speaker_id}":')[-1],
-                speech_model,
-                voice_profile,
-                tokenizer=speech_tokenizer,  # Applicable only for parler models
-            )
-            podcast_audio.append(speech)
-            text = ""
-
+    try:
+        for chunk in text_to_text_stream(
+            clean_text, text_model, system_prompt=system_prompt
+        ):
+            text += chunk
+            podcast_script += chunk
+            if text.endswith("\n") and "Speaker" in text:
+                logger.debug(text)
+                speaker_id = re.search(r"Speaker (\d+)", text).group(1)
+                voice_profile = next(
+                    speaker.voice_profile
+                    for speaker in config.speakers
+                    if speaker.id == int(speaker_id)
+                )
+                speech = text_to_speech(
+                    text.split(f'"Speaker {speaker_id}":')[-1],
+                    speech_model,
+                    voice_profile,
+                )
+                podcast_audio.append(speech)
+                text = ""
+    except KeyboardInterrupt:
+        logger.warning("Podcast generation stopped by user.")
     logger.info("Saving Podcast...")
     sf.write(
         str(output_folder / "podcast.wav"),
         np.concatenate(podcast_audio),
-        samplerate=sample_rate,
+        samplerate=speech_model.sample_rate,
     )
     (output_folder / "podcast.txt").write_text(podcast_script)
     logger.success("Done!")
diff --git a/src/document_to_podcast/config.py b/src/document_to_podcast/config.py
index e64f381..10f8b6a 100644
--- a/src/document_to_podcast/config.py
+++ b/src/document_to_podcast/config.py
@@ -1,10 +1,11 @@
 from pathlib import Path
-from typing import Literal
 
 from typing_extensions import Annotated
 from pydantic import BaseModel, FilePath
 from pydantic.functional_validators import AfterValidator
 
+from document_to_podcast.inference.model_loaders import TTS_LOADERS
+from document_to_podcast.inference.text_to_speech import TTS_INFERENCE
 from document_to_podcast.preprocessing import DATA_LOADERS
 
 
@@ -41,14 +42,6 @@
     },
 ]
 
-SUPPORTED_TTS_MODELS = Literal[
-    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf",
-    "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf",
-    "parler-tts/parler-tts-large-v1",
-    "parler-tts/parler-tts-mini-v1",
-    "parler-tts/parler-tts-mini-v1.1",
-]
-
 
 def validate_input_file(value):
     if Path(value).suffix not in DATA_LOADERS:
@@ -73,6 +66,18 @@ def validate_text_to_text_prompt(value):
     return value
 
 
+def validate_text_to_speech_model(value):
+    if value not in TTS_LOADERS:
+        raise ValueError(
+            f"Model {value} is missing a loading function. Please define it under model_loaders.py"
+        )
+    if value not in TTS_INFERENCE:
+        raise ValueError(
+            f"Model {value} is missing an inference function. Please define it under text_to_speech.py"
+        )
+    return value
+
+
Please define it under text_to_speech.py" + ) + return value + + class Speaker(BaseModel): id: int name: str @@ -88,5 +93,6 @@ class Config(BaseModel): output_folder: str text_to_text_model: Annotated[str, AfterValidator(validate_text_to_text_model)] text_to_text_prompt: Annotated[str, AfterValidator(validate_text_to_text_prompt)] - text_to_speech_model: SUPPORTED_TTS_MODELS + text_to_speech_model: Annotated[str, AfterValidator(validate_text_to_speech_model)] speakers: list[Speaker] + outetts_language: str = "en" diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py index ade90e0..33b0b73 100644 --- a/src/document_to_podcast/inference/model_loaders.py +++ b/src/document_to_podcast/inference/model_loaders.py @@ -1,20 +1,22 @@ -from typing import Tuple - +from typing import Union from huggingface_hub import hf_hub_download from llama_cpp import Llama from outetts import GGUFModelConfig_v1, InterfaceGGUF -from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase +from dataclasses import dataclass, field +from transformers import ( + AutoTokenizer, + PreTrainedModel, + AutoProcessor, + BarkModel, +) -def load_llama_cpp_model( - model_id: str, -) -> Llama: +def load_llama_cpp_model(model_id: str) -> Llama: """ Loads the given model_id using Llama.from_pretrained. Examples: - >>> model = load_llama_cpp_model( - "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf") + >>> model = load_llama_cpp_model("allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf") Args: model_id (str): The model id to load. @@ -34,56 +36,147 @@ def load_llama_cpp_model( return model -def load_outetts_model( - model_id: str, language: str = "en", device: str = "cpu" -) -> InterfaceGGUF: +@dataclass +class TTSModel: """ - Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS + The purpose of this class is to provide a unified interface for all the TTS models supported. + Specifically, different TTS model families have different peculiarities, for example, the bark models need a + BarkProcessor, the parler models need their own tokenizer, etc. This wrapper takes care of this complexity so that + the user doesn't have to deal with it. - Examples: - >>> model = load_outetts_model("OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf", "en", "cpu") + Args: + model (Union[InterfaceGGUF, BarkModel, PreTrainedModel]): A TTS model that has a .generate() method or similar + that takes text as input, and returns an audio in the form of a numpy array. + model_id (str): The model's identifier string. + sample_rate (int): The sample rate of the audio, required for properly saving the audio to a file. + custom_args (dict): Any model-specific arguments that a TTS model might require, e.g. tokenizer. + """ + + model: Union[InterfaceGGUF, BarkModel, PreTrainedModel] + model_id: str + sample_rate: int + custom_args: field(default_factory=dict) + + +def _load_oute_tts(model_id: str, **kwargs) -> TTSModel: + """ + Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS Args: model_id (str): The model id to load. Format is expected to be `{org}/{repo}/{filename}`. language (str): Supported languages in 0.2-500M: en, zh, ja, ko. - device (str): The device to load the model on, such as "cuda:0" or "cpu". - Returns: - PreTrainedModel: The loaded model. + TTSModel: The loaded model using the TTSModel wrapper. 
""" - n_layers_on_gpu = 0 if device == "cpu" else -1 model_version = model_id.split("-")[1] org, repo, filename = model_id.split("/") local_path = hf_hub_download(repo_id=f"{org}/{repo}", filename=filename) model_config = GGUFModelConfig_v1( - model_path=local_path, language=language, n_gpu_layers=n_layers_on_gpu + model_path=local_path, language=kwargs.pop("language", "en") + ) + model = InterfaceGGUF(model_version=model_version, cfg=model_config) + + return TTSModel( + model=model, model_id=model_id, sample_rate=model.audio_codec.sr, custom_args={} ) - return InterfaceGGUF(model_version=model_version, cfg=model_config) + +def _load_bark_tts(model_id: str, **kwargs) -> TTSModel: + """ + Loads the given model_id and its required processor. For more info: https://github.com/suno-ai/bark + + Args: + model_id (str): The model id to load. + Format is expected to be `{repo}/{filename}`. + Returns: + TTSModel: The loaded model with its required processor using the TTSModel. + """ + + processor = AutoProcessor.from_pretrained(model_id) + model = BarkModel.from_pretrained(model_id) + + return TTSModel( + model=model, + model_id=model_id, + sample_rate=24_000, + custom_args={"processor": processor}, + ) -def load_parler_tts_model_and_tokenizer( - model_id: str, device: str = "cpu" -) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: +def _load_parler_tts(model_id: str, **kwargs) -> TTSModel: """ Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts - Examples: - >>> model, tokenizer = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu") + Args: + model_id (str): The model id to load. + Format is expected to be `{repo}/{filename}`. + + Returns: + TTSModel: The loaded model with its required tokenizer for the input. + """ + from parler_tts import ParlerTTSForConditionalGeneration + + model = ParlerTTSForConditionalGeneration.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + return TTSModel( + model=model, + model_id=model_id, + sample_rate=model.config.sampling_rate, + custom_args={ + "tokenizer": tokenizer, + }, + ) + + +def _load_parler_tts_multi(model_id: str, **kwargs) -> TTSModel: + """ + Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts Args: model_id (str): The model id to load. Format is expected to be `{repo}/{filename}`. - device (str): The device to load the model on, such as "cuda:0" or "cpu". Returns: - PreTrainedModel: The loaded model. + TTSModel: The loaded model with its required tokenizer for the input text and + another tokenizer for the description. 
""" + from parler_tts import ParlerTTSForConditionalGeneration - model = ParlerTTSForConditionalGeneration.from_pretrained(model_id).to(device) + model = ParlerTTSForConditionalGeneration.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) + description_tokenizer = AutoTokenizer.from_pretrained( + model.config.text_encoder._name_or_path + ) + + return TTSModel( + model=model, + model_id=model_id, + sample_rate=model.config.sampling_rate, + custom_args={ + "tokenizer": tokenizer, + "description_tokenizer": description_tokenizer, + }, + ) + + +TTS_LOADERS = { + # To add support for your model, add it here in the format {model_id} : _load_function + "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf": _load_oute_tts, + "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": _load_oute_tts, + "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": _load_oute_tts, + "suno/bark": _load_bark_tts, + "suno/bark-small": _load_bark_tts, + "parler-tts/parler-tts-large-v1": _load_parler_tts, + "parler-tts/parler-tts-mini-v1": _load_parler_tts, + "parler-tts/parler-tts-mini-v1.1": _load_parler_tts, + "parler-tts/parler-tts-mini-multilingual-v1.1": _load_parler_tts_multi, + "ai4bharat/indic-parler-tts": _load_parler_tts_multi, +} + - return model, tokenizer +def load_tts_model(model_id: str, **kwargs) -> TTSModel: + return TTS_LOADERS[model_id](model_id, **kwargs) diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py index 2041da8..de5e006 100644 --- a/src/document_to_podcast/inference/text_to_speech.py +++ b/src/document_to_podcast/inference/text_to_speech.py @@ -1,23 +1,39 @@ -from typing import Union - import numpy as np from outetts.version.v1.interface import InterfaceGGUF -from transformers import PreTrainedTokenizerBase, PreTrainedModel +from transformers import PreTrainedModel, BarkModel + +from document_to_podcast.inference.model_loaders import TTSModel def _text_to_speech_oute( input_text: str, model: InterfaceGGUF, voice_profile: str, - temperature: float = 0.3, + **kwargs, ) -> np.ndarray: + """ + TTS generation function for the Oute TTS model family. + Args: + input_text (str): The text to convert to speech. + model: A model from the Oute TTS family. + voice_profile: a pre-defined ID for the Oute models (e.g. "female_1") + more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers + temperature (float, default = 0.3): Controls the randomness of predictions by scaling the logits. + Lower values make the output more focused and deterministic, higher values produce more diverse results. + repetition_penalty (float, default = 1.1): Applies a penalty to tokens that have already been generated, + reducing the likelihood of repetition and enhancing text variety. + max_length (int, default = 4096): Defines the maximum number of tokens for the generated text sequence. 
+    return TTS_LOADERS[model_id](model_id, **kwargs)
diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py
index 2041da8..de5e006 100644
--- a/src/document_to_podcast/inference/text_to_speech.py
+++ b/src/document_to_podcast/inference/text_to_speech.py
@@ -1,23 +1,39 @@
-from typing import Union
-
 import numpy as np
 from outetts.version.v1.interface import InterfaceGGUF
-from transformers import PreTrainedTokenizerBase, PreTrainedModel
+from transformers import PreTrainedModel, BarkModel
+
+from document_to_podcast.inference.model_loaders import TTSModel
 
 
 def _text_to_speech_oute(
     input_text: str,
     model: InterfaceGGUF,
     voice_profile: str,
-    temperature: float = 0.3,
+    **kwargs,
 ) -> np.ndarray:
+    """
+    TTS generation function for the Oute TTS model family.
+
+    Args:
+        input_text (str): The text to convert to speech.
+        model: A model from the Oute TTS family.
+        voice_profile: a pre-defined ID for the Oute models (e.g. "female_1")
+            more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers
+        temperature (float, default = 0.3): Controls the randomness of predictions by scaling the logits.
+            Lower values make the output more focused and deterministic, higher values produce more diverse results.
+        repetition_penalty (float, default = 1.1): Applies a penalty to tokens that have already been generated,
+            reducing the likelihood of repetition and enhancing text variety.
+        max_length (int, default = 4096): Defines the maximum number of tokens for the generated text sequence.
+
+    Returns:
+        numpy array: The waveform of the speech as a 2D numpy array
+    """
     speaker = model.load_default_speaker(name=voice_profile)
 
     output = model.generate(
         text=input_text,
-        temperature=temperature,
-        repetition_penalty=1.1,
-        max_length=4096,
+        temperature=kwargs.pop("temperature", 0.3),
+        repetition_penalty=kwargs.pop("repetition_penalty", 1.1),
+        max_length=kwargs.pop("max_length", 4096),
         speaker=speaker,
     )
 
@@ -25,12 +41,52 @@
     return output_as_np
 
 
+def _text_to_speech_bark(
+    input_text: str, model: BarkModel, voice_profile: str, **kwargs
+) -> np.ndarray:
+    """
+    TTS generation function for the Bark model family.
+
+    Args:
+        input_text (str): The text to convert to speech.
+        model: A model from the Bark family.
+        voice_profile: a pre-defined ID for the Bark model (e.g. "v2/en_speaker_0")
+            more info here https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
+        processor: Required BarkProcessor to prepare the input text for the Bark model
+
+    Returns:
+        numpy array: The waveform of the speech as a 2D numpy array
+    """
+    processor = kwargs.get("processor")
+    if processor is None:
+        raise ValueError("Bark model requires a processor")
+
+    inputs = processor(input_text, voice_preset=voice_profile)
+
+    generation = model.generate(**inputs)
+    waveform = generation.cpu().numpy().squeeze()
+
+    return waveform
+
+
 def _text_to_speech_parler(
-    input_text: str,
-    model: PreTrainedModel,
-    tokenizer: PreTrainedTokenizerBase,
-    voice_profile: str,
+    input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
+    """
+    TTS generation function for the Parler TTS model family.
+
+    Args:
+        input_text (str): The text to convert to speech.
+        model: A model from the Parler TTS family.
+        voice_profile: a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm)
+            more info here https://github.com/huggingface/parler-tts?tab=readme-ov-file#-using-a-specific-speaker
+        tokenizer (PreTrainedTokenizer): Required PreTrainedTokenizer to tokenize the input text.
+
+    Returns:
+        numpy array: The waveform of the speech as a 2D numpy array
+    """
+    tokenizer = kwargs.get("tokenizer")
+    if tokenizer is None:
+        raise ValueError("Parler model requires a tokenizer")
+
     input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
     prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
 
@@ -40,33 +96,54 @@
     return waveform
 
 
-def text_to_speech(
-    input_text: str,
-    model: Union[InterfaceGGUF, PreTrainedModel],
-    voice_profile: str,
-    tokenizer: PreTrainedTokenizerBase = None,
+def _text_to_speech_parler_multi(
+    input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
     """
-    Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model.
-
-    Examples:
-        >>> waveform = text_to_speech(input_text="Welcome to our amazing podcast", model=model, voice_profile="male_1")
-
+    TTS generation function for the Parler TTS multilingual model family.
+
     Args:
         input_text (str): The text to convert to speech.
-        model (PreTrainedModel): The model used for generating the waveform.
-        voice_profile (str): Depending on the selected TTS model it should either be
-            - a pre-defined ID for the Oute models (e.g. "female_1")
"female_1") - more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers - - a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm) - more info here https://github.com/huggingface/parler-tts?tab=readme-ov-file#-using-a-specific-speaker - tokenizer (PreTrainedTokenizerBase): [Only used for the Parler models!] The tokenizer used for tokenizing the text in order to send to the model. + model: A model from the Parler TTS multilingual family. + voice_profile: a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm) + more info here https://huggingface.co/parler-tts/parler-tts-mini-multilingual-v1.1 + for the indic version https://huggingface.co/ai4bharat/indic-parler-tts + tokenizer (PreTrainedTokenizer): Required PreTrainedTokenizer to tokenize the input text. + description_tokenizer (PreTrainedTokenizer): Required PreTrainedTokenizer to tokenize the description text. + Returns: numpy array: The waveform of the speech as a 2D numpy array """ - if isinstance(model, InterfaceGGUF): - return _text_to_speech_oute(input_text, model, voice_profile) - elif isinstance(model, PreTrainedModel): - return _text_to_speech_parler(input_text, model, tokenizer, voice_profile) - else: - raise NotImplementedError("Model not yet implemented for TTS") + tokenizer = kwargs.get("tokenizer") + if tokenizer is None: + raise ValueError("Parler model requires a tokenizer") + + description_tokenizer = kwargs.get("description_tokenizer") + if description_tokenizer is None: + raise ValueError("Parler multilingual model requires a description tokenizer") + + input_ids = description_tokenizer(voice_profile, return_tensors="pt").input_ids + prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids + + generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids) + waveform = generation.cpu().numpy().squeeze() + return waveform + + +TTS_INFERENCE = { + # To add support for your model, add it here in the format {model_id} : _inference_function + "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf": _text_to_speech_oute, + "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": _text_to_speech_oute, + "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": _text_to_speech_oute, + "suno/bark": _text_to_speech_bark, + "parler-tts/parler-tts-large-v1": _text_to_speech_parler, + "parler-tts/parler-tts-mini-v1": _text_to_speech_parler, + "parler-tts/parler-tts-mini-v1.1": _text_to_speech_parler, + "parler-tts/parler-tts-mini-multilingual-v1.1": _text_to_speech_parler_multi, + "ai4bharat/indic-parler-tts": _text_to_speech_parler_multi, +} + + +def text_to_speech(input_text: str, model: TTSModel, voice_profile: str) -> np.ndarray: + return TTS_INFERENCE[model.model_id]( + input_text, model.model, voice_profile, **model.custom_args + ) diff --git a/tests/unit/inference/test_model_loaders.py b/tests/unit/inference/test_model_loaders.py index 2c11629..e337ff5 100644 --- a/tests/unit/inference/test_model_loaders.py +++ b/tests/unit/inference/test_model_loaders.py @@ -1,15 +1,19 @@ +from typing import Dict, Any, Union + +import pytest from llama_cpp import Llama from document_to_podcast.inference.model_loaders import ( load_llama_cpp_model, - load_outetts_model, + load_tts_model, ) -from transformers import PreTrainedModel, PreTrainedTokenizerBase -from outetts.version.v1.interface import InterfaceGGUF - -from 
+    return TTS_INFERENCE[model.model_id](
+        input_text, model.model, voice_profile, **model.custom_args
+    )
diff --git a/tests/unit/inference/test_model_loaders.py b/tests/unit/inference/test_model_loaders.py
index 2c11629..e337ff5 100644
--- a/tests/unit/inference/test_model_loaders.py
+++ b/tests/unit/inference/test_model_loaders.py
@@ -1,15 +1,19 @@
+from typing import Dict, Any
+
+import pytest
 from llama_cpp import Llama
 
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
-    load_outetts_model,
+    load_tts_model,
 )
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
-from outetts.version.v1.interface import InterfaceGGUF
-
-from document_to_podcast.inference.model_loaders import (
-    load_parler_tts_model_and_tokenizer,
-)
+from transformers import (
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    BarkModel,
+    BarkProcessor,
+)
+from outetts.version.v1.interface import InterfaceGGUF
 
 
 def test_load_llama_cpp_model():
@@ -21,16 +25,31 @@
     assert model.n_ctx() == 2048
 
 
-def test_load_outetts_model():
-    model = load_outetts_model(
-        "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf"
-    )
-    assert isinstance(model, InterfaceGGUF)
-
-
-def test_load_parler_tts_model_and_tokenizer():
-    model, tokenizer = load_parler_tts_model_and_tokenizer(
-        "parler-tts/parler-tts-mini-v1"
-    )
-    assert isinstance(model, PreTrainedModel)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
+@pytest.mark.parametrize(
+    "model_id, expected_model_type, expected_custom_args",
+    [
+        ["OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf", InterfaceGGUF, {}],
+        ["suno/bark-small", BarkModel, {"processor": BarkProcessor}],
+        [
+            "parler-tts/parler-tts-mini-v1.1",
+            PreTrainedModel,
+            {"tokenizer": PreTrainedTokenizerBase},
+        ],
+    ],
+)
+def test_load_tts_model(
+    model_id: str,
+    expected_model_type: type,
+    expected_custom_args: Dict[str, Any],
+) -> None:
+    model = load_tts_model(model_id)
+    assert isinstance(model.model, expected_model_type)
+    assert model.model_id == model_id
+    # Compare the key sets first so that missing or unexpected custom args fail the test
+    # instead of being silently skipped by zipping two dicts of different sizes.
+    assert model.custom_args.keys() == expected_custom_args.keys()
+    for key, value in model.custom_args.items():
+        assert isinstance(value, expected_custom_args[key])
diff --git a/tests/unit/inference/test_text_to_speech.py b/tests/unit/inference/test_text_to_speech.py
index 7ae239c..7d3a1b4 100644
--- a/tests/unit/inference/test_text_to_speech.py
+++ b/tests/unit/inference/test_text_to_speech.py
@@ -1,16 +1,24 @@
 from outetts.version.v1.interface import InterfaceGGUF
 from transformers import PreTrainedModel
 
+from document_to_podcast.inference.model_loaders import TTSModel
 from document_to_podcast.inference.text_to_speech import text_to_speech
 
 
 def test_text_to_speech_oute(mocker):
     model = mocker.MagicMock(spec_set=InterfaceGGUF)
-    text_to_speech(
-        "Hello?",
+    tts_model = TTSModel(
         model=model,
+        model_id="OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf",
+        sample_rate=0,
+        custom_args={},
+    )
+    text_to_speech(
+        input_text="Hello?",
+        model=tts_model,
         voice_profile="female_1",
     )
+
     model.load_default_speaker.assert_called_with(name=mocker.ANY)
     model.generate.assert_called_with(
         text=mocker.ANY,
@@ -21,18 +29,79 @@
     )
 
 
+def test_text_to_speech_bark(mocker):
+    model = mocker.MagicMock(spec_set=PreTrainedModel)
+    processor = mocker.MagicMock()
+
+    tts_model = TTSModel(
+        model=model,
+        model_id="suno/bark",
+        sample_rate=0,
+        custom_args={"processor": processor},
+    )
+    text_to_speech(
+        input_text="Hello?",
+        model=tts_model,
+        voice_profile="v2/en_speaker_0",
+    )
+    processor.assert_has_calls(
+        [
+            mocker.call("Hello?", voice_preset="v2/en_speaker_0"),
+        ]
+    )
+    model.generate.assert_called_with()
+
+
 def test_text_to_speech_parler(mocker):
     model = mocker.MagicMock(spec_set=PreTrainedModel)
     tokenizer = mocker.MagicMock()
+
+    tts_model = TTSModel(
+        model=model,
+        model_id="parler-tts/parler-tts-mini-v1.1",
+        sample_rate=0,
+        custom_args={"tokenizer": tokenizer},
+    )
     text_to_speech(
-        "Hello?",
+        input_text="Hello?",
+        model=tts_model,
+        voice_profile="Laura's voice is calm",
+    )
+    tokenizer.assert_has_calls(
+        [
mocker.call("Laura's voice is calm", return_tensors="pt"), + mocker.call("Hello?", return_tensors="pt"), + ] + ) + model.generate.assert_called_with(input_ids=mocker.ANY, prompt_input_ids=mocker.ANY) + + +def test_text_to_speech_parler_multi(mocker): + model = mocker.MagicMock(spec_set=PreTrainedModel) + tokenizer = mocker.MagicMock() + description_tokenizer = mocker.MagicMock() + + tts_model = TTSModel( model=model, - tokenizer=tokenizer, - voice_profile="default", + model_id="ai4bharat/indic-parler-tts", + sample_rate=0, + custom_args={ + "tokenizer": tokenizer, + "description_tokenizer": description_tokenizer, + }, + ) + text_to_speech( + input_text="Hello?", + model=tts_model, + voice_profile="Laura's voice is calm", + ) + description_tokenizer.assert_has_calls( + [ + mocker.call("Laura's voice is calm", return_tensors="pt"), + ] ) tokenizer.assert_has_calls( [ - mocker.call("default", return_tensors="pt"), mocker.call("Hello?", return_tensors="pt"), ] )