From 286a93c0ae8cdfa3bc5a441ebe13ff35fdecea0b Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 16:09:07 +0000
Subject: [PATCH 01/26] [WIP] Add bark and parler multi support

---
 .../inference/model_loaders.py                | 65 ++++++++++++++-----
 .../inference/text_to_speech.py               | 44 +++++++++++--
 2 files changed, 84 insertions(+), 25 deletions(-)

diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py
index ade90e0..f5622a1 100644
--- a/src/document_to_podcast/inference/model_loaders.py
+++ b/src/document_to_podcast/inference/model_loaders.py
@@ -1,9 +1,15 @@
 from typing import Tuple
-
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 from outetts import GGUFModelConfig_v1, InterfaceGGUF
-from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
+from transformers import (
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    AutoProcessor,
+    BarkModel,
+    BarkProcessor,
+)
 
 
 def load_llama_cpp_model(
@@ -34,56 +40,79 @@ def load_llama_cpp_model(
     return model
 
 
-def load_outetts_model(
-    model_id: str, language: str = "en", device: str = "cpu"
-) -> InterfaceGGUF:
+def load_outetts_model(model_id: str, language: str = "en") -> InterfaceGGUF:
     """
     Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS
 
     Examples:
-        >>> model = load_outetts_model("OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf", "en", "cpu")
+        >>> model = load_outetts_model("OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf", "en")
 
     Args:
         model_id (str): The model id to load.
             Format is expected to be `{org}/{repo}/{filename}`.
         language (str): Supported languages in 0.2-500M: en, zh, ja, ko.
-        device (str): The device to load the model on, such as "cuda:0" or "cpu".
-
     Returns:
         PreTrainedModel: The loaded model.
     """
-    n_layers_on_gpu = 0 if device == "cpu" else -1
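+    # e.g. "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf" -> model_version "0.1"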
     model_version = model_id.split("-")[1]
 
     org, repo, filename = model_id.split("/")
     local_path = hf_hub_download(repo_id=f"{org}/{repo}", filename=filename)
-    model_config = GGUFModelConfig_v1(
-        model_path=local_path, language=language, n_gpu_layers=n_layers_on_gpu
-    )
+    model_config = GGUFModelConfig_v1(model_path=local_path, language=language)
 
     return InterfaceGGUF(model_version=model_version, cfg=model_config)
 
 
+def load_bark_tts(model_id: str) -> Tuple[BarkModel, BarkProcessor]:
+    """
+    Loads the given model_id and its required processor. For more info: https://github.com/suno-ai/bark
+
+    Examples:
+        >>> model, processor = load_bark_tts("suno/bark")
+
+    Args:
+        model_id (str): The model id to load.
+            Format is expected to be `{org}/{repo}`.
+
+    Returns:
+        BarkModel: The loaded model.
+        BarkProcessor: The loaded processor.
+    """
+
+    processor = AutoProcessor.from_pretrained(model_id)
+    model = BarkModel.from_pretrained(model_id)
+
+    return model, processor
+
+
 def load_parler_tts_model_and_tokenizer(
-    model_id: str, device: str = "cpu"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+    model_id: str,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase, PreTrainedTokenizerBase]:
     """
     Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts
 
     Examples:
-        >>> model, tokenizer = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu")
+        >>> model, tokenizer, _ = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1")
 
     Args:
         model_id (str): The model id to load.
             Format is expected to be `{repo}/{filename}`.
-        device (str): The device to load the model on, such as "cuda:0" or "cpu".
 
     Returns:
         PreTrainedModel: The loaded model.
+        PreTrainedTokenizer: The loaded tokenizer for the input.
+        PreTrainedTokenizer: [Only for the multilingual models] The loaded tokenizer for the description (None otherwise).
     """
     from parler_tts import ParlerTTSForConditionalGeneration
 
-    model = ParlerTTSForConditionalGeneration.from_pretrained(model_id).to(device)
+    model = ParlerTTSForConditionalGeneration.from_pretrained(model_id)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    return model, tokenizer
+    description_tokenizer = (
+        AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
+        if model_id == "parler-tts/parler-tts-mini-multilingual-v1.1"
+        or model_id == "ai4bharat/indic-parler-tts"
+        else None
+    )
+
+    return model, tokenizer, description_tokenizer
diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py
index 2041da8..405735b 100644
--- a/src/document_to_podcast/inference/text_to_speech.py
+++ b/src/document_to_podcast/inference/text_to_speech.py
@@ -2,7 +2,12 @@
 
 import numpy as np
 from outetts.version.v1.interface import InterfaceGGUF
-from transformers import PreTrainedTokenizerBase, PreTrainedModel
+from transformers import (
+    PreTrainedTokenizerBase,
+    PreTrainedModel,
+    BarkModel,
+    BarkProcessor,
+)
 
 
 def _text_to_speech_oute(
@@ -25,13 +30,28 @@ def _text_to_speech_oute(
     return output_as_np
 
 
+def _text_to_speech__bark(
+    input_text: str, model: BarkModel, processor: BarkProcessor, voice_profile: str
+):
+    inputs = processor(input_text, voice_preset=voice_profile)
+
+    generation = model.generate(**inputs)
+    waveform = generation.cpu().numpy().squeeze()
+
+    return waveform
+
+
 def _text_to_speech_parler(
     input_text: str,
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizerBase,
     voice_profile: str,
+    description_tokenizer: PreTrainedTokenizerBase = None,
 ) -> np.ndarray:
-    input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
+    if description_tokenizer:
+        input_ids = description_tokenizer(voice_profile, return_tensors="pt").input_ids
+    else:
+        input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
     prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
 
     generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
@@ -42,9 +62,10 @@ def _text_to_speech_parler(
 
 def text_to_speech(
     input_text: str,
-    model: Union[InterfaceGGUF, PreTrainedModel],
+    model: Union[InterfaceGGUF, PreTrainedModel, BarkModel],
     voice_profile: str,
-    tokenizer: PreTrainedTokenizerBase = None,
+    processor: Union[BarkProcessor, PreTrainedTokenizerBase] = None,
+    description_tokenizer: PreTrainedTokenizerBase = None,
 ) -> np.ndarray:
     """
     Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model.
@@ -58,15 +79,24 @@ def text_to_speech(
         voice_profile (str): Depending on the selected TTS model it should either be
             - a pre-defined ID for the Oute models (e.g. "female_1")
             more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers
+            - a pre-defined ID for the Bark model (e.g. "v2/en_speaker_0")
+            more info here https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
             - a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm)
             more info here https://github.com/huggingface/parler-tts?tab=readme-ov-file#-using-a-specific-speaker
-        tokenizer (PreTrainedTokenizerBase): [Only used for the Parler models!] The tokenizer used for tokenizing the text in order to send to the model.
+            for the multilingual model: https://huggingface.co/parler-tts/parler-tts-mini-multilingual-v1.1
+            for the indic model: https://huggingface.co/ai4bharat/indic-parler-tts
+        processor (BarkProcessor or PreTrainedTokenizerBase): [Only used for the Bark or Parler models!]
+            In Bark models, this is an HF processor. In Parler models, this is a pretrained tokenizer.
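+        description_tokenizer (PreTrainedTokenizerBase): [Only used for the multilingual Parler models!]
+            The tokenizer used for the voice_profile description.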
     Returns:
         numpy array: The waveform of the speech as a 2D numpy array
     """
     if isinstance(model, InterfaceGGUF):
         return _text_to_speech_oute(input_text, model, voice_profile)
+    elif isinstance(model, BarkModel):
+        return _text_to_speech__bark(input_text, model, processor, voice_profile)
     elif isinstance(model, PreTrainedModel):
-        return _text_to_speech_parler(input_text, model, tokenizer, voice_profile)
+        return _text_to_speech_parler(
+            input_text, model, processor, voice_profile, description_tokenizer
+        )
     else:
-        raise NotImplementedError("Model not yet implemented for TTS")
+        raise NotImplementedError

From 14b69bf2569926802d5522fd7d84de0cd0cac8c0 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:33:57 +0000
Subject: [PATCH 02/26] Add config files for other models to easily test across
 models

---
 example_data/config_bark.yaml         | 31 +++++++++++++++++++++++++++
 example_data/config_parler.yaml       | 31 +++++++++++++++++++++++++++
 example_data/config_parler_multi.yaml | 31 +++++++++++++++++++++++++++
 3 files changed, 93 insertions(+)
 create mode 100644 example_data/config_bark.yaml
 create mode 100644 example_data/config_parler.yaml
 create mode 100644 example_data/config_parler_multi.yaml

diff --git a/example_data/config_bark.yaml b/example_data/config_bark.yaml
new file mode 100644
index 0000000..77712c8
--- /dev/null
+++ b/example_data/config_bark.yaml
@@ -0,0 +1,31 @@
+input_file: "example_data/a.md"
+output_folder: "example_data/bark/"
+text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
+text_to_speech_model: "suno/bark"
+text_to_text_prompt: |
+  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format. 
+  The script features the following speakers:
+  {SPEAKERS}
+  Instructions:
+  - Write dynamic, easy-to-follow dialogue.
+  - Include natural interruptions and interjections.
+  - Avoid repetitive phrasing between speakers.
+  - Format output as a JSON conversation.
+  Example:
+  {
+    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
+    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
+    "Speaker 1": "Sure! Imagine it like this...",
+    "Speaker 2": "Oh, that's cool! But how does..."
+  }
+speakers:
+  - id: 1
+    name: Laura
+    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
+    voice_profile: "v2/en_speaker_0"
+
+  - id: 2
+    name: Daniel
+    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
+    voice_profile: "v2/en_speaker_1"
+outetts_language: "en"  # Only used by OuteTTS models; ignored for this model.
\ No newline at end of file
diff --git a/example_data/config_parler.yaml b/example_data/config_parler.yaml
new file mode 100644
index 0000000..e8240e1
--- /dev/null
+++ b/example_data/config_parler.yaml
@@ -0,0 +1,31 @@
+input_file: "example_data/a.md"
+output_folder: "example_data/parler/"
+text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
+text_to_speech_model: "parler-tts/parler-tts-mini-v1.1"
+text_to_text_prompt: |
+  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format. 
+  The script features the following speakers:
+  {SPEAKERS}
+  Instructions:
+  - Write dynamic, easy-to-follow dialogue.
+  - Include natural interruptions and interjections.
+  - Avoid repetitive phrasing between speakers.
+  - Format output as a JSON conversation.
+  Example:
+  {
+    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
+    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
+    "Speaker 1": "Sure! Imagine it like this...",
+    "Speaker 2": "Oh, that's cool! But how does..."
+  }
+speakers:
+  - id: 1
+    name: Laura
+    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
+    voice_profile: Laura's voice is calm and slow in delivery, with no background noise.
+
+  - id: 2
+    name: Daniel
+    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
+    voice_profile: Daniel's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.
+outetts_language: "en"  # Only used by OuteTTS models; ignored for this model.
\ No newline at end of file
diff --git a/example_data/config_parler_multi.yaml b/example_data/config_parler_multi.yaml
new file mode 100644
index 0000000..656ca90
--- /dev/null
+++ b/example_data/config_parler_multi.yaml
@@ -0,0 +1,31 @@
+input_file: "example_data/a.md"
+output_folder: "example_data/parler_multi/"
+text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
+text_to_speech_model: "parler-tts/parler-tts-mini-multilingual-v1.1"
+text_to_text_prompt: |
+  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format. 
+  The script features the following speakers:
+  {SPEAKERS}
+  Instructions:
+  - Write dynamic, easy-to-follow dialogue.
+  - Include natural interruptions and interjections.
+  - Avoid repetitive phrasing between speakers.
+  - Format output as a JSON conversation.
+  Example:
+  {
+    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
+    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
+    "Speaker 1": "Sure! Imagine it like this...",
+    "Speaker 2": "Oh, that's cool! But how does..."
+  }
+speakers:
+  - id: 1
+    name: Laura
+    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
+    voice_profile: Laura's voice is calm and slow in delivery, with no background noise.
+
+  - id: 2
+    name: Daniel
+    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
+    voice_profile: Daniel's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.
+outetts_language: "en"  # Only used by OuteTTS models; ignored for this model.
\ No newline at end of file

From 20ab8e97b9f2dd0fb424dee5a6534d10ab45fd2c Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:46:19 +0000
Subject: [PATCH 03/26] Use model loading wrapper function for
 download_models.py

---
 demo/download_models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/demo/download_models.py b/demo/download_models.py
index ed85845..a557f98 100644
--- a/demo/download_models.py
+++ b/demo/download_models.py
@@ -4,10 +4,10 @@
 
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
-    load_outetts_model,
+    load_tts_model,
 )
 
 load_llama_cpp_model(
     "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
 )
-load_outetts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
+load_tts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")

From ee38e10bd75123c3066b25ac7a32e20b1acf00cb Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:46:54 +0000
Subject: [PATCH 04/26] Make sure transformers>4.31.0 (required for bark model)

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 4cc6fc8..642d553 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
   "pydantic",
   "PyPDF2[crypto]",
   "python-docx",
+  "transformers>4.31.0",
   "streamlit",
 ]
 

From 890c684e6a65d73160869ff62b2d6c8e61e8aea8 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:47:06 +0000
Subject: [PATCH 05/26] Add parler dependency

---
 pyproject.toml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 642d553..8c15697 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,10 @@ tests = [
   "pytest-mock>=3.14.0"
 ]
 
+parler = [
+  "parler_tts @ git+https://github.com/huggingface/parler-tts.git",
+]
+
 [project.urls]
 Documentation = "https://mozilla-ai.github.io/document-to-podcast/"
 Issues = "https://github.com/mozilla-ai/document-to-podcast/issues"

From 8cc7b0d60b27b5fc380204518062b801846142f0 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:56:11 +0000
Subject: [PATCH 06/26] Use TTSModelWrapper for demo code

---
 demo/app.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/demo/app.py b/demo/app.py
index f3f904c..3e1eb49 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -8,10 +8,9 @@
 from document_to_podcast.preprocessing import DATA_LOADERS, DATA_CLEANERS
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
-    load_outetts_model,
+    load_tts_model,
 )
 from document_to_podcast.config import DEFAULT_PROMPT, DEFAULT_SPEAKERS, Speaker
-from document_to_podcast.inference.text_to_speech import text_to_speech
 from document_to_podcast.inference.text_to_text import text_to_text_stream
 
 
@@ -24,7 +23,7 @@ def load_text_to_text_model():
 
 @st.cache_resource
 def load_text_to_speech_model():
-    return load_outetts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
+    return load_tts_model("OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf")
 
 
 script = "script"
@@ -148,12 +147,11 @@ def gen_button_clicked():
                         if speaker["id"] == int(speaker_id)
                     )
                     with st.spinner("Generating Audio..."):
-                        speech = text_to_speech(
+                        speech = speech_model.text_to_speech(
                             text.split(f'"Speaker {speaker_id}":')[-1],
-                            speech_model,
                             voice_profile,
                         )
-                    st.audio(speech, sample_rate=speech_model.audio_codec.sr)
+                    st.audio(speech, sample_rate=speech_model.sample_rate)
 
                     st.session_state.audio.append(speech)
                     text = ""
@@ -164,7 +162,7 @@ def gen_button_clicked():
             sf.write(
                 "podcast.wav",
                 st.session_state.audio,
-                samplerate=speech_model.audio_codec.sr,
+                samplerate=speech_model.sample_rate,
             )
             st.markdown("Podcast saved to disk!")
 

From dcbb25418f04578ce8438c5bd9c81022919ef5ea Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:57:36 +0000
Subject: [PATCH 07/26] Use TTSModelWrapper for cli

---
 src/document_to_podcast/cli.py | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/src/document_to_podcast/cli.py b/src/document_to_podcast/cli.py
index f71e1cb..d79168c 100644
--- a/src/document_to_podcast/cli.py
+++ b/src/document_to_podcast/cli.py
@@ -16,11 +16,9 @@
 )
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
-    load_outetts_model,
-    load_parler_tts_model_and_tokenizer,
+    load_tts_model,
 )
 from document_to_podcast.inference.text_to_text import text_to_text_stream
-from document_to_podcast.inference.text_to_speech import text_to_speech
 from document_to_podcast.preprocessing import DATA_CLEANERS, DATA_LOADERS
 
 
@@ -32,6 +30,7 @@ def document_to_podcast(
     text_to_text_prompt: str = DEFAULT_PROMPT,
     text_to_speech_model: SUPPORTED_TTS_MODELS = "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf",
     speakers: list[Speaker] | None = None,
+    outetts_language: str = "en",  # Only applicable to OuteTTS models
     from_config: str | None = None,
 ):
     """
@@ -70,8 +69,10 @@ def document_to_podcast(
         speakers (list[Speaker] | None, optional): The speakers for the podcast.
             Defaults to DEFAULT_SPEAKERS.
 
-        from_config (str, optional): The path to the config file. Defaults to None.
+        outetts_language (str): For OuteTTS models we need to specify which language to use.
+            Supported languages in 0.2-500M: en, zh, ja, ko. More info: https://github.com/edwko/OuteTTS
 
+        from_config (str, optional): The path to the config file. Defaults to None.
 
             If provided, all other arguments will be ignored.
     """
@@ -86,6 +87,7 @@ def document_to_podcast(
             text_to_text_prompt=text_to_text_prompt,
             text_to_speech_model=text_to_speech_model,
             speakers=[Speaker.model_validate(speaker) for speaker in speakers],
+            outetts_language=outetts_language,
         )
 
     output_folder = Path(config.output_folder)
@@ -106,15 +108,9 @@ def document_to_podcast(
     text_model = load_llama_cpp_model(model_id=config.text_to_text_model)
 
     logger.info(f"Loading {config.text_to_speech_model}")
-    if "oute" in config.text_to_speech_model.lower():
-        speech_model = load_outetts_model(model_id=config.text_to_speech_model)
-        speech_tokenizer = None
-        sample_rate = speech_model.audio_codec.sr
-    else:
-        speech_model, speech_tokenizer = load_parler_tts_model_and_tokenizer(
-            model_id=config.text_to_speech_model
-        )
-        sample_rate = speech_model.config.sampling_rate
+    speech_model = load_tts_model(
+        model_id=config.text_to_speech_model, outetts_language=outetts_language
+    )
 
     # ~4 characters per token is considered a reasonable default.
     max_characters = text_model.n_ctx() * 4
@@ -146,11 +142,9 @@ def document_to_podcast(
                 for speaker in config.speakers
                 if speaker.id == int(speaker_id)
             )
-            speech = text_to_speech(
+            speech = speech_model.text_to_speech(
                 text.split(f'"Speaker {speaker_id}":')[-1],
-                speech_model,
                 voice_profile,
-                tokenizer=speech_tokenizer,  # Applicable only for parler models
             )
             podcast_audio.append(speech)
             text = ""
@@ -159,7 +153,7 @@ def document_to_podcast(
     sf.write(
         str(output_folder / "podcast.wav"),
         np.concatenate(podcast_audio),
-        samplerate=sample_rate,
+        samplerate=speech_model.sample_rate,
     )
     (output_folder / "podcast.txt").write_text(podcast_script)
     logger.success("Done!")

From b0d40bc5c458b7454d5df939e9ba81982a786439 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:58:48 +0000
Subject: [PATCH 08/26] Add outetts_language attribute

---
 src/document_to_podcast/config.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/src/document_to_podcast/config.py b/src/document_to_podcast/config.py
index e64f381..d072509 100644
--- a/src/document_to_podcast/config.py
+++ b/src/document_to_podcast/config.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel, FilePath
 from pydantic.functional_validators import AfterValidator
 
+from document_to_podcast.inference.model_loaders import SUPPORTED_TTS_MODELS
 from document_to_podcast.preprocessing import DATA_LOADERS
 
 
@@ -41,14 +42,6 @@
     },
 ]
 
-SUPPORTED_TTS_MODELS = Literal[
-    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf",
-    "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf",
-    "parler-tts/parler-tts-large-v1",
-    "parler-tts/parler-tts-mini-v1",
-    "parler-tts/parler-tts-mini-v1.1",
-]
-
 
 def validate_input_file(value):
     if Path(value).suffix not in DATA_LOADERS:
@@ -88,5 +81,6 @@ class Config(BaseModel):
     output_folder: str
     text_to_text_model: Annotated[str, AfterValidator(validate_text_to_text_model)]
     text_to_text_prompt: Annotated[str, AfterValidator(validate_text_to_text_prompt)]
-    text_to_speech_model: SUPPORTED_TTS_MODELS
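+    # Allowed ids are derived from the keys registered in SUPPORTED_TTS_MODELS.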
+    text_to_speech_model: Literal[tuple(SUPPORTED_TTS_MODELS.keys())]
     speakers: list[Speaker]
+    outetts_language: str = "en"

From 5e47b1e9bb9bed7bb818273d8ac842c8e3cc7a2d Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 19:59:28 +0000
Subject: [PATCH 09/26] Add TTSModelWrapper

---
 .../inference/model_loaders.py                | 166 +++++++++++++++---
 1 file changed, 137 insertions(+), 29 deletions(-)

diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py
index f5622a1..1e53cdf 100644
--- a/src/document_to_podcast/inference/model_loaders.py
+++ b/src/document_to_podcast/inference/model_loaders.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Union
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 from outetts import GGUFModelConfig_v1, InterfaceGGUF
@@ -10,17 +10,23 @@
     BarkModel,
     BarkProcessor,
 )
+import numpy as np
+
+from document_to_podcast.inference.text_to_speech import (
+    _text_to_speech_oute,
+    _text_to_speech_bark,
+    _text_to_speech_parler,
+    _text_to_speech_parler_multi,
+    _text_to_speech_parler_indic,
+)
 
 
-def load_llama_cpp_model(
-    model_id: str,
-) -> Llama:
+def load_llama_cpp_model(model_id: str) -> Llama:
     """
     Loads the given model_id using Llama.from_pretrained.
 
     Examples:
-        >>> model = load_llama_cpp_model(
-            "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf")
+        >>> model = load_llama_cpp_model("allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf")
 
     Args:
         model_id (str): The model id to load.
@@ -40,68 +46,142 @@ def load_llama_cpp_model(
     return model
 
 
-def load_outetts_model(model_id: str, language: str = "en") -> InterfaceGGUF:
+class TTSModelWrapper:
+    """
+    The purpose of this wrapper is to provide a unified interface for all the TTS models supported.
+    Specifically, different TTS model families have different peculiarities, for example, the bark model needs a
+    BarkProcessor, the parler models need their own tokenizer, etc. This wrapper takes care of this complexity so that
+    the user doesn't have to deal with it.
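+
+    Example (illustrative usage):
+        >>> tts = load_tts_model("suno/bark")
+        >>> waveform = tts.text_to_speech("Welcome to our podcast!", "v2/en_speaker_0")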
     """
-    Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS
 
-    Examples:
-        >>> model = load_outetts_model("OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf", "en")
+    def __init__(
+        self,
+        model: Union[InterfaceGGUF, BarkModel, PreTrainedModel],
+        model_id: str,
+        sample_rate: int,
+        bark_processor: BarkProcessor = None,  # Only required for Bark models
+        parler_tokenizer: PreTrainedTokenizerBase = None,  # Only required for Parler models
+        parler_description_tokenizer: PreTrainedTokenizerBase = None,  # Only required for multilingual Parler models
+    ):
+        self.model = model
+        self.model_id = model_id
+        self.sample_rate = sample_rate
+        self.bark_processor = bark_processor
+        self.parler_tokenizer = parler_tokenizer
+        self.parler_description_tokenizer = parler_description_tokenizer
+
+    def text_to_speech(
+        self,
+        input_text: str,
+        voice_profile: str,
+    ) -> np.ndarray:
+        """
+        Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model.
+
+        Args:
+            input_text (str): The text to convert to speech.
+            voice_profile (str): Depending on the selected TTS model it should either be
+                - a pre-defined ID for the Oute models (e.g. "female_1")
+                more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers
+                - a pre-defined ID for the Bark model (e.g. "v2/en_speaker_0")
+                more info here https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
+                - a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm)
+                more info here https://github.com/huggingface/parler-tts?tab=readme-ov-file#-using-a-specific-speaker
+                for the multilingual model: https://huggingface.co/parler-tts/parler-tts-mini-multilingual-v1.1
+                for the indic model: https://huggingface.co/ai4bharat/indic-parler-tts
+        Returns:
+            numpy array: The waveform of the speech as a 2D numpy array
+        """
+        # TODO: replace this brittle substring dispatch with a registry lookup.
+        if "oute" in self.model_id:
+            return SUPPORTED_TTS_MODELS[self.model_id][1](
+                input_text, self.model, voice_profile
+            )
+        elif "bark" in self.model_id:
+            return SUPPORTED_TTS_MODELS[self.model_id][1](
+                input_text, self.model, self.bark_processor, voice_profile
+            )
+        elif (
+            "parler" in self.model_id and "multi" in self.model_id
+        ) or "indic" in self.model_id:
+            return SUPPORTED_TTS_MODELS[self.model_id][1](
+                input_text,
+                self.model,
+                self.parler_tokenizer,
+                voice_profile,
+                self.parler_description_tokenizer,
+            )
+        elif "parler" in self.model_id:
+            return SUPPORTED_TTS_MODELS[self.model_id][1](
+                input_text, self.model, self.parler_tokenizer, voice_profile
+            )
+        else:
+            raise NotImplementedError
+
+
+def load_tts_model(model_id: str, outetts_language: str = "en") -> TTSModelWrapper:
+    if "oute" in model_id:
+        return SUPPORTED_TTS_MODELS[model_id][0](model_id, outetts_language)
+    else:
+        return SUPPORTED_TTS_MODELS[model_id][0](model_id)
+
+
+def _load_oute_tts(model_id: str, language: str = "en") -> TTSModelWrapper:
+    """
+    Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS
 
     Args:
         model_id (str): The model id to load.
             Format is expected to be `{org}/{repo}/{filename}`.
         language (str): Supported languages in 0.2-500M: en, zh, ja, ko.
     Returns:
-        PreTrainedModel: The loaded model.
+        TTSModelWrapper: The loaded model wrapped in the unified TTS interface.
     """
     model_version = model_id.split("-")[1]
 
     org, repo, filename = model_id.split("/")
     local_path = hf_hub_download(repo_id=f"{org}/{repo}", filename=filename)
     model_config = GGUFModelConfig_v1(model_path=local_path, language=language)
+    model = InterfaceGGUF(model_version=model_version, cfg=model_config)
 
-    return InterfaceGGUF(model_version=model_version, cfg=model_config)
+    return TTSModelWrapper(
+        model=model, model_id=model_id, sample_rate=model.audio_codec.sr
+    )
 
 
-def load_bark_tts(model_id: str) -> Tuple[BarkModel, BarkProcessor]:
+def _load_bark_tts(model_id: str) -> TTSModelWrapper:
     """
     Loads the given model_id and its required processor. For more info: https://github.com/suno-ai/bark
 
-    Examples:
-        >>> model, processor = load_bark_tts("suno/bark")
-
     Args:
         model_id (str): The model id to load.
             Format is expected to be `{org}/{repo}`.
 
     Returns:
-        BarkModel: The loaded model.
-        BarkProcessor: The loaded processor.
+        TTSModelWrapper: The loaded model and its required processor, wrapped in the unified TTS interface.
     """
 
     processor = AutoProcessor.from_pretrained(model_id)
     model = BarkModel.from_pretrained(model_id)
 
-    return model, processor
+    return TTSModelWrapper(
+        model=model, model_id=model_id, sample_rate=24_000, bark_processor=processor
+    )
 
 
-def load_parler_tts_model_and_tokenizer(
-    model_id: str,
-) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase, PreTrainedTokenizerBase]:
+def _load_parler_tts(model_id: str) -> TTSModelWrapper:
     """
     Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts
 
-    Examples:
-        >>> model, tokenizer, _ = load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1")
-
     Args:
         model_id (str): The model id to load.
             Format is expected to be `{repo}/{filename}`.
 
     Returns:
-        PreTrainedModel: The loaded model.
-        PreTrainedTokenizer: The loaded tokenizer for the input.
-        PreTrainedTokenizer: [Only for the multilingual models] The loaded tokenizer for the description (None otherwise).
+        TTSModelWrapper: The loaded model with its required tokenizer for the input. For the multilingual models we also
+        load another tokenizer for the description.
     """
     from parler_tts import ParlerTTSForConditionalGeneration
 
@@ -115,4 +195,32 @@ def load_parler_tts_model_and_tokenizer(
         else None
     )
 
-    return model, tokenizer, description_tokenizer
+    return TTSModelWrapper(
+        model=model,
+        model_id=model_id,
+        sample_rate=model.config.sampling_rate,
+        parler_tokenizer=tokenizer,
+        parler_description_tokenizer=description_tokenizer,
+    )
+
+
+SUPPORTED_TTS_MODELS = {
+    # To add support for your model, add it here in the format {model_id} : [_load_function, _text_to_speech_function]
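+    # e.g. a hypothetical entry: "my-org/my-tts": [_load_my_tts, _text_to_speech_my_tts],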
+    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": [
+        _load_oute_tts,
+        _text_to_speech_oute,
+    ],
+    "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": [
+        _load_oute_tts,
+        _text_to_speech_oute,
+    ],
+    "suno/bark": [_load_bark_tts, _text_to_speech_bark],
+    "parler-tts/parler-tts-large-v1": [_load_parler_tts, _text_to_speech_parler],
+    "parler-tts/parler-tts-mini-v1": [_load_parler_tts, _text_to_speech_parler],
+    "parler-tts/parler-tts-mini-v1.1": [_load_parler_tts, _text_to_speech_parler],
+    "parler-tts/parler-tts-mini-multilingual-v1.1": [
+        _load_parler_tts,
+        _text_to_speech_parler_multi,
+    ],
+    "ai4bharat/indic-parler-tts": [_load_parler_tts, _text_to_speech_parler_indic],
+}

From 945c44f3daf43b9b44d12fc4b59d46cfaaf7feb7 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Tue, 17 Dec 2024 20:00:22 +0000
Subject: [PATCH 10/26] Update text_to_speech.py

---
 .../inference/text_to_speech.py               | 77 ++++++++-----------
 1 file changed, 33 insertions(+), 44 deletions(-)

diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py
index 405735b..a14ed1e 100644
--- a/src/document_to_podcast/inference/text_to_speech.py
+++ b/src/document_to_podcast/inference/text_to_speech.py
@@ -1,5 +1,3 @@
-from typing import Union
-
 import numpy as np
 from outetts.version.v1.interface import InterfaceGGUF
 from transformers import (
@@ -30,9 +28,9 @@ def _text_to_speech_oute(
     return output_as_np
 
 
-def _text_to_speech__bark(
+def _text_to_speech_bark(
     input_text: str, model: BarkModel, processor: BarkProcessor, voice_profile: str
-):
+) -> np.ndarray:
     inputs = processor(input_text, voice_preset=voice_profile)
 
     generation = model.generate(**inputs)
@@ -46,12 +44,24 @@ def _text_to_speech_parler(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizerBase,
     voice_profile: str,
+) -> np.ndarray:
+    input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
+    prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+
+    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
+    waveform = generation.cpu().numpy().squeeze()
+
+    return waveform
+
+
+def _text_to_speech_parler_multi(
+    input_text: str,
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizerBase,
+    voice_profile: str,
     description_tokenizer: PreTrainedTokenizerBase = None,
 ) -> np.ndarray:
-    if description_tokenizer:
-        input_ids = description_tokenizer(voice_profile, return_tensors="pt").input_ids
-    else:
-        input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
+    input_ids = description_tokenizer(voice_profile, return_tensors="pt").input_ids
     prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
 
     generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
@@ -60,43 +70,22 @@ def _text_to_speech_parler(
     return waveform
 
 
-def text_to_speech(
+def _text_to_speech_parler_indic(
     input_text: str,
-    model: Union[InterfaceGGUF, PreTrainedModel, BarkModel],
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizerBase,
     voice_profile: str,
-    processor: Union[BarkProcessor, PreTrainedTokenizerBase] = None,
     description_tokenizer: PreTrainedTokenizerBase = None,
 ) -> np.ndarray:
-    """
-    Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model.
-
-    Examples:
-        >>> waveform = text_to_speech(input_text="Welcome to our amazing podcast", model=model, voice_profile="male_1")
-
-    Args:
-        input_text (str): The text to convert to speech.
-        model (PreTrainedModel): The model used for generating the waveform.
-        voice_profile (str): Depending on the selected TTS model it should either be
-            - a pre-defined ID for the Oute models (e.g. "female_1")
-            more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers
-            - a pre-defined ID for the Bark model (e.g. "v2/en_speaker_0")
-            more info here https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
-            - a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm)
-            more info here https://github.com/huggingface/parler-tts?tab=readme-ov-file#-using-a-specific-speaker
-            for the multilingual model: https://huggingface.co/parler-tts/parler-tts-mini-multilingual-v1.1
-            for the indic model: https://huggingface.co/ai4bharat/indic-parler-tts
-        processor (BarkProcessor or PreTrainedTokenizerBase): [Only used for the Bark or Parler models!]
-            In Bark models, this is an HF processor. In Parler models, this is a pretrained tokenizer.
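-        description_tokenizer (PreTrainedTokenizerBase): [Only used for the multilingual Parler models!]
-            The tokenizer used for the voice_profile description.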
-    Returns:
-        numpy array: The waveform of the speech as a 2D numpy array
-    """
-    if isinstance(model, InterfaceGGUF):
-        return _text_to_speech_oute(input_text, model, voice_profile)
-    elif isinstance(model, BarkModel):
-        return _text_to_speech__bark(input_text, model, processor, voice_profile)
-    elif isinstance(model, PreTrainedModel):
-        return _text_to_speech_parler(
-            input_text, model, processor, voice_profile, description_tokenizer
-        )
-    else:
-        raise NotImplementedError
+    input_ids = description_tokenizer(voice_profile, return_tensors="pt")
+    prompt_input_ids = tokenizer(input_text, return_tensors="pt")
+
+    generation = model.generate(
+        input_ids=input_ids.input_ids,
+        attention_mask=input_ids.attention_mask,
+        prompt_input_ids=prompt_input_ids.input_ids,
+        prompt_attention_mask=prompt_input_ids.attention_mask,
+    )
+    waveform = generation.cpu().numpy().squeeze()
+
+    return waveform

From 4565fb80edc832a93684eb8a388894c9a2e35e2f Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Wed, 18 Dec 2024 12:50:44 +0000
Subject: [PATCH 11/26] Pass model-specific variables as **kwargs

---
 .../inference/model_loaders.py                |  70 ++++-------
 .../inference/text_to_speech.py               | 112 ++++++++++++++----
 2 files changed, 109 insertions(+), 73 deletions(-)

diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py
index 1e53cdf..ba7d542 100644
--- a/src/document_to_podcast/inference/model_loaders.py
+++ b/src/document_to_podcast/inference/model_loaders.py
@@ -5,10 +5,8 @@
 from transformers import (
     AutoTokenizer,
     PreTrainedModel,
-    PreTrainedTokenizerBase,
     AutoProcessor,
     BarkModel,
-    BarkProcessor,
 )
 import numpy as np
 
@@ -48,7 +46,7 @@ def load_llama_cpp_model(model_id: str) -> Llama:
 
 class TTSModelWrapper:
     """
-    The purpose of this wrapper is to provide a unified interface for all the TTS models supported.
+    The purpose of this class is to provide a unified interface for all the TTS models supported.
     Specifically, different TTS model families have different peculiarities, for example, the bark model needs a
     BarkProcessor, the parler models need their own tokenizer, etc. This wrapper takes care of this complexity so that
     the user doesn't have to deal with it.
@@ -59,21 +57,15 @@ def __init__(
         model: Union[InterfaceGGUF, BarkModel, PreTrainedModel],
         model_id: str,
         sample_rate: int,
-        bark_processor: BarkProcessor = None,  # Only required for Bark models
-        parler_tokenizer: PreTrainedTokenizerBase = None,  # Only required for Parler models
-        parler_description_tokenizer: PreTrainedTokenizerBase = None,  # Only required for multilingual Parler models
+        **kwargs,
     ):
         self.model = model
         self.model_id = model_id
         self.sample_rate = sample_rate
-        self.bark_processor = bark_processor
-        self.parler_tokenizer = parler_tokenizer
-        self.parler_description_tokenizer = parler_description_tokenizer
+        self.kwargs = kwargs
 
     def text_to_speech(
-        self,
-        input_text: str,
-        voice_profile: str,
+        self, input_text: str, voice_profile: str, **kwargs
     ) -> np.ndarray:
         """
         Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model.
@@ -92,40 +84,22 @@ def text_to_speech(
         Returns:
             numpy array: The waveform of the speech as a 2D numpy array
         """
-        # TODO: replace this brittle substring dispatch with a registry lookup.
-        if "oute" in self.model_id:
-            return SUPPORTED_TTS_MODELS[self.model_id][1](
-                input_text, self.model, voice_profile
-            )
-        elif "bark" in self.model_id:
-            return SUPPORTED_TTS_MODELS[self.model_id][1](
-                input_text, self.model, self.bark_processor, voice_profile
-            )
-        elif (
-            "parler" in self.model_id and "multi" in self.model_id
-        ) or "indic" in self.model_id:
-            return SUPPORTED_TTS_MODELS[self.model_id][1](
-                input_text,
-                self.model,
-                self.parler_tokenizer,
-                voice_profile,
-                self.parler_description_tokenizer,
-            )
-        elif "parler" in self.model_id:
-            return SUPPORTED_TTS_MODELS[self.model_id][1](
-                input_text, self.model, self.parler_tokenizer, voice_profile
-            )
-        else:
-            raise NotImplementedError
-
-
-def load_tts_model(model_id: str, outetts_language: str = "en") -> TTSModelWrapper:
-    if "oute" in model_id:
-        return SUPPORTED_TTS_MODELS[model_id][0](model_id, outetts_language)
-    else:
-        return SUPPORTED_TTS_MODELS[model_id][0](model_id)
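+        # Dispatch to the speech function registered for this model id;
+        # call-time kwargs override the kwargs stored at load time.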
+        return SUPPORTED_TTS_MODELS[self.model_id][1](
+            input_text, self.model, voice_profile, **self.kwargs | kwargs
+        )
+
+
+def load_tts_model(model_id: str, **kwargs) -> TTSModelWrapper:
+    """
+
+    Args:
+        model_id (str): The model id to load. Must be a key in SUPPORTED_TTS_MODELS.
+        **kwargs: Model-specific options forwarded to the loader,
+            e.g. `outetts_language` for OuteTTS models.
+
+    Returns:
+        The loaded model wrapped with its model-specific components
+        (processor, tokenizer, sample rate, etc.).
+    """
+    return SUPPORTED_TTS_MODELS[model_id][0](model_id, **kwargs)
 
 
 def _load_oute_tts(model_id: str, language: str = "en") -> TTSModelWrapper:
@@ -151,7 +125,7 @@ def _load_oute_tts(model_id: str, language: str = "en") -> TTSModelWrapper:
     )
 
 
-def _load_bark_tts(model_id: str) -> TTSModelWrapper:
+def _load_bark_tts(model_id: str, **kwargs) -> TTSModelWrapper:
     """
     Loads the given model_id and its required processor. For more info: https://github.com/suno-ai/bark
 
@@ -171,7 +145,7 @@ def _load_bark_tts(model_id: str) -> TTSModelWrapper:
     )
 
 
-def _load_parler_tts(model_id: str) -> TTSModelWrapper:
+def _load_parler_tts(model_id: str, **kwargs) -> TTSModelWrapper:
     """
     Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts
 
diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py
index a14ed1e..f4477a5 100644
--- a/src/document_to_podcast/inference/text_to_speech.py
+++ b/src/document_to_podcast/inference/text_to_speech.py
@@ -1,26 +1,35 @@
 import numpy as np
 from outetts.version.v1.interface import InterfaceGGUF
-from transformers import (
-    PreTrainedTokenizerBase,
-    PreTrainedModel,
-    BarkModel,
-    BarkProcessor,
-)
+from transformers import PreTrainedModel, BarkModel
 
 
 def _text_to_speech_oute(
     input_text: str,
     model: InterfaceGGUF,
     voice_profile: str,
-    temperature: float = 0.3,
+    **kwargs,
 ) -> np.ndarray:
+    """
+
+    Args:
+        input_text: The text to convert to speech.
+        model: The OuteTTS interface used for generation.
+        voice_profile: A pre-defined speaker ID (e.g. "female_1").
+
+        Optional kwargs and their defaults:
+        temperature: float = 0.3
+        repetition_penalty: float = 1.1
+        max_length: int = 4096
+
+    Returns:
+        numpy array: The waveform of the speech.
+    """
     speaker = model.load_default_speaker(name=voice_profile)
 
     output = model.generate(
         text=input_text,
-        temperature=temperature,
-        repetition_penalty=1.1,
-        max_length=4096,
+        temperature=kwargs.pop("temperature", 0.3),
+        repetition_penalty=kwargs.pop("repetition_penalty", 1.1),
+        max_length=kwargs.pop("max_length", 4096),
         speaker=speaker,
     )
 
@@ -29,8 +38,22 @@ def _text_to_speech_oute(
 
 
 def _text_to_speech_bark(
-    input_text: str, model: BarkModel, processor: BarkProcessor, voice_profile: str
+    input_text: str, model: BarkModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
+    """
+
+    Args:
+        input_text: The text to convert to speech.
+        model: The Bark model used for generation.
+        voice_profile: A pre-defined Bark speaker ID (e.g. "v2/en_speaker_0").
+        bark_processor: BarkProcessor (required, passed via kwargs)
+
+    Returns:
+        numpy array: The waveform of the speech.
+    """
+    if not (processor := kwargs.pop("bark_processor", None)):
+        raise ValueError("Bark model requires a processor")
+
     inputs = processor(input_text, voice_preset=voice_profile)
 
     generation = model.generate(**inputs)
@@ -40,11 +63,22 @@ def _text_to_speech_bark(
 
 
 def _text_to_speech_parler(
-    input_text: str,
-    model: PreTrainedModel,
-    tokenizer: PreTrainedTokenizerBase,
-    voice_profile: str,
+    input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
+    """
+
+    Args:
+        input_text: The text to convert to speech.
+        model: The Parler model used for generation.
+        voice_profile: A natural-language description of the voice to use.
+        parler_tokenizer: PreTrainedTokenizer (required, passed via kwargs)
+
+    Returns:
+        numpy array: The waveform of the speech.
+    """
+
+    if not (tokenizer := kwargs.pop("parler_tokenizer", None)):
+        raise ValueError("Parler model requires a tokenizer")
     input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
     prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
 
@@ -55,12 +89,27 @@ def _text_to_speech_parler(
 
 
 def _text_to_speech_parler_multi(
-    input_text: str,
-    model: PreTrainedModel,
-    tokenizer: PreTrainedTokenizerBase,
-    voice_profile: str,
-    description_tokenizer: PreTrainedTokenizerBase = None,
+    input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
+    """
+
+    Args:
+        input_text: The text to convert to speech.
+        model: The Parler model used for generation.
+        voice_profile: A natural-language description of the voice to use.
+        parler_tokenizer: PreTrainedTokenizer (required, passed via kwargs)
+        parler_description_tokenizer: PreTrainedTokenizer (required, passed via kwargs)
+
+    Returns:
+        numpy array: The waveform of the speech.
+    """
+
+    if not (tokenizer := kwargs.pop("parler_tokenizer", None)):
+        raise ValueError("Parler model requires a tokenizer")
+
+    if not (description_tokenizer := kwargs.pop("parler_description_tokenizer", None)):
+        raise ValueError("Parler multilingual model requires a description tokenizer")
+
     input_ids = description_tokenizer(voice_profile, return_tensors="pt").input_ids
     prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
 
@@ -71,12 +120,25 @@ def _text_to_speech_parler_multi(
 
 
 def _text_to_speech_parler_indic(
-    input_text: str,
-    model: PreTrainedModel,
-    tokenizer: PreTrainedTokenizerBase,
-    voice_profile: str,
-    description_tokenizer: PreTrainedTokenizerBase = None,
+    input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
+    """
+
+    Args:
+        input_text: The text to convert to speech.
+        model: The Parler model used for generation.
+        voice_profile: A natural-language description of the voice to use.
+        parler_tokenizer: PreTrainedTokenizer (required, passed via kwargs)
+        parler_description_tokenizer: PreTrainedTokenizer (required, passed via kwargs)
+
+    Returns:
+        numpy array: The waveform of the speech.
+    """
+    if not (tokenizer := kwargs.pop("parler_tokenizer", None)):
+        raise ValueError("Parler model requires a tokenizer")
+    if not (description_tokenizer := kwargs.pop("parler_description_tokenizer", None)):
+        raise ValueError("Parler multilingual model requires a description tokenizer")
+
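+    # Unlike the other Parler variants, the indic model also passes attention masks to generate().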
     input_ids = description_tokenizer(voice_profile, return_tensors="pt")
     prompt_input_ids = tokenizer(input_text, return_tensors="pt")
 

From 01d0e7a6a825ca5b0958f3049cd738e260e34d58 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Wed, 18 Dec 2024 12:51:35 +0000
Subject: [PATCH 12/26] Rename TTSModelWrapper to TTSInterface

---
 .../inference/model_loaders.py                | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py
index ba7d542..6d6ff41 100644
--- a/src/document_to_podcast/inference/model_loaders.py
+++ b/src/document_to_podcast/inference/model_loaders.py
@@ -44,7 +44,7 @@ def load_llama_cpp_model(model_id: str) -> Llama:
     return model
 
 
-class TTSModelWrapper:
+class TTSInterface:
     """
     The purpose of this class is to provide a unified interface for all the TTS models supported.
     Specifically, different TTS model families have different peculiarities, for example, the bark model needs a
@@ -89,7 +89,7 @@ def text_to_speech(
         )
 
 
-def load_tts_model(model_id: str, **kwargs) -> TTSModelWrapper:
+def load_tts_model(model_id: str, **kwargs) -> TTSInterface:
     """
 
     Args:
@@ -102,7 +102,7 @@ def load_tts_model(model_id: str, **kwargs) -> TTSModelWrapper:
     return SUPPORTED_TTS_MODELS[model_id][0](model_id, **kwargs)
 
 
-def _load_oute_tts(model_id: str, language: str = "en") -> TTSModelWrapper:
+def _load_oute_tts(model_id: str, language: str = "en") -> TTSInterface:
     """
     Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS
 
@@ -111,7 +111,7 @@ def _load_oute_tts(model_id: str, language: str = "en") -> TTSModelWrapper:
             Format is expected to be `{org}/{repo}/{filename}`.
         language (str): Supported languages in 0.2-500M: en, zh, ja, ko.
     Returns:
-        TTSModelWrapper: The loaded model wrapped in the unified TTS interface.
+        TTSInterface: The loaded model wrapped in the unified TTS interface.
     """
     model_version = model_id.split("-")[1]
 
@@ -120,12 +120,12 @@ def _load_oute_tts(model_id: str, language: str = "en") -> TTSModelWrapper:
     model_config = GGUFModelConfig_v1(model_path=local_path, language=language)
     model = InterfaceGGUF(model_version=model_version, cfg=model_config)
 
-    return TTSModelWrapper(
+    return TTSInterface(
         model=model, model_id=model_id, sample_rate=model.audio_codec.sr
     )
 
 
-def _load_bark_tts(model_id: str, **kwargs) -> TTSModelWrapper:
+def _load_bark_tts(model_id: str, **kwargs) -> TTSInterface:
     """
     Loads the given model_id and its required processor. For more info: https://github.com/suno-ai/bark
 
@@ -134,18 +134,18 @@ def _load_bark_tts(model_id: str, **kwargs) -> TTSModelWrapper:
             Format is expected to be `{org}/{repo}`.
 
     Returns:
-        TTSModelWrapper: The loaded model and its required processor, wrapped in the unified TTS interface.
+        TTSInterface: The loaded model and its required processor, wrapped in the unified TTS interface.
     """
 
     processor = AutoProcessor.from_pretrained(model_id)
     model = BarkModel.from_pretrained(model_id)
 
-    return TTSModelWrapper(
+    return TTSInterface(
         model=model, model_id=model_id, sample_rate=24_000, bark_processor=processor
     )
 
 
-def _load_parler_tts(model_id: str, **kwargs) -> TTSModelWrapper:
+def _load_parler_tts(model_id: str, **kwargs) -> TTSInterface:
     """
     Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts
 
@@ -154,7 +154,7 @@ def _load_parler_tts(model_id: str, **kwargs) -> TTSModelWrapper:
             Format is expected to be `{repo}/{filename}`.
 
     Returns:
-        TTSModelWrapper: The loaded model with its required tokenizer for the input. For the multilingual models we also
+        TTSInterface: The loaded model with its required tokenizer for the input. For the multilingual models we also
         load another tokenizer for the description.
     """
     from parler_tts import ParlerTTSForConditionalGeneration
@@ -169,7 +169,7 @@ def _load_parler_tts(model_id: str, **kwargs) -> TTSModelWrapper:
         else None
     )
 
-    return TTSModelWrapper(
+    return TTSInterface(
         model=model,
         model_id=model_id,
         sample_rate=model.config.sampling_rate,

From 5af3e72563e535c4affe05fac9f84a426df27823 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Wed, 18 Dec 2024 12:54:50 +0000
Subject: [PATCH 13/26] Update language argument to kwargs

---
 src/document_to_podcast/inference/model_loaders.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py
index 6d6ff41..4560932 100644
--- a/src/document_to_podcast/inference/model_loaders.py
+++ b/src/document_to_podcast/inference/model_loaders.py
@@ -102,7 +102,7 @@ def load_tts_model(model_id: str, **kwargs) -> TTSInterface:
     return SUPPORTED_TTS_MODELS[model_id][0](model_id, **kwargs)
 
 
-def _load_oute_tts(model_id: str, language: str = "en") -> TTSInterface:
+def _load_oute_tts(model_id: str, **kwargs) -> TTSInterface:
     """
     Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS
 
@@ -117,7 +117,9 @@ def _load_oute_tts(model_id: str, language: str = "en") -> TTSInterface:
 
     org, repo, filename = model_id.split("/")
     local_path = hf_hub_download(repo_id=f"{org}/{repo}", filename=filename)
-    model_config = GGUFModelConfig_v1(model_path=local_path, language=language)
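+    # `outetts_language` is forwarded through load_tts_model's **kwargs (see cli.py).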
+    model_config = GGUFModelConfig_v1(
+        model_path=local_path, language=kwargs.pop("outetts_language", "en")
+    )
     model = InterfaceGGUF(model_version=model_version, cfg=model_config)
 
     return TTSInterface(

From e3a3f17d0bae0fd44a6738623bd280be086855bb Mon Sep 17 00:00:00 2001
From: Kostis <Kostis-S-Z@users.noreply.github.com>
Date: Wed, 18 Dec 2024 14:33:58 +0000
Subject: [PATCH 14/26] Remove parler from dependencies

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>
---
 pyproject.toml | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8c15697..642d553 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,10 +34,6 @@ tests = [
   "pytest-mock>=3.14.0"
 ]
 
-parler = [
-  "parler_tts @ git+https://github.com/huggingface/parler-tts.git",
-]
-
 [project.urls]
 Documentation = "https://mozilla-ai.github.io/document-to-podcast/"
 Issues = "https://github.com/mozilla-ai/document-to-podcast/issues"

From fb814faf75c2b06273fb55bc8a5f32308f26e0c2 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 11:55:33 +0000
Subject: [PATCH 15/26] Separate inference from TTSModel

---
 demo/app.py                    | 4 +++-
 src/document_to_podcast/cli.py | 8 +++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/demo/app.py b/demo/app.py
index 3e1eb49..1684521 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -5,6 +5,7 @@
 import soundfile as sf
 import streamlit as st
 
+from document_to_podcast.inference.text_to_speech import text_to_speech
 from document_to_podcast.preprocessing import DATA_LOADERS, DATA_CLEANERS
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
@@ -147,8 +148,9 @@ def gen_button_clicked():
                         if speaker["id"] == int(speaker_id)
                     )
                     with st.spinner("Generating Audio..."):
-                        speech = speech_model.text_to_speech(
+                        speech = text_to_speech(
                             text.split(f'"Speaker {speaker_id}":')[-1],
+                            speech_model,
                             voice_profile,
                         )
                     st.audio(speech, sample_rate=speech_model.sample_rate)
diff --git a/src/document_to_podcast/cli.py b/src/document_to_podcast/cli.py
index d79168c..9532265 100644
--- a/src/document_to_podcast/cli.py
+++ b/src/document_to_podcast/cli.py
@@ -12,12 +12,13 @@
     Speaker,
     DEFAULT_PROMPT,
     DEFAULT_SPEAKERS,
-    SUPPORTED_TTS_MODELS,
+    TTS_LOADERS,
 )
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
     load_tts_model,
 )
+from document_to_podcast.inference.text_to_speech import text_to_speech
 from document_to_podcast.inference.text_to_text import text_to_text_stream
 from document_to_podcast.preprocessing import DATA_CLEANERS, DATA_LOADERS
 
@@ -28,7 +29,7 @@ def document_to_podcast(
     output_folder: str | None = None,
     text_to_text_model: str = "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf",
     text_to_text_prompt: str = DEFAULT_PROMPT,
-    text_to_speech_model: SUPPORTED_TTS_MODELS = "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf",
+    text_to_speech_model: TTS_LOADERS = "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf",
     speakers: list[Speaker] | None = None,
     outetts_language: str = "en",  # Only applicable to OuteTTS models
     from_config: str | None = None,
@@ -142,8 +143,9 @@ def document_to_podcast(
                 for speaker in config.speakers
                 if speaker.id == int(speaker_id)
             )
-            speech = speech_model.text_to_speech(
+            speech = text_to_speech(
                 text.split(f'"Speaker {speaker_id}":')[-1],
+                speech_model,
                 voice_profile,
             )
             podcast_audio.append(speech)

From 672c0e0cf01b48ecb7ca26132e5fbc803a7e2200 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 11:56:18 +0000
Subject: [PATCH 16/26] Make sure config model is properly registered

---
 src/document_to_podcast/config.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/document_to_podcast/config.py b/src/document_to_podcast/config.py
index d072509..10f8b6a 100644
--- a/src/document_to_podcast/config.py
+++ b/src/document_to_podcast/config.py
@@ -1,11 +1,11 @@
 from pathlib import Path
-from typing import Literal
 from typing_extensions import Annotated
 
 from pydantic import BaseModel, FilePath
 from pydantic.functional_validators import AfterValidator
 
-from document_to_podcast.inference.model_loaders import SUPPORTED_TTS_MODELS
+from document_to_podcast.inference.model_loaders import TTS_LOADERS
+from document_to_podcast.inference.text_to_speech import TTS_INFERENCE
 from document_to_podcast.preprocessing import DATA_LOADERS
 
 
@@ -66,6 +66,18 @@ def validate_text_to_text_prompt(value):
     return value
 
 
+def validate_text_to_speech_model(value):
+    if value not in TTS_LOADERS:
+        raise ValueError(
+            f"Model {value} is missing a loading function. Please define it under model_loaders.py"
+        )
+    if value not in TTS_INFERENCE:
+        raise ValueError(
+            f"Model {value} is missing an inference function. Please define it under text_to_speech.py"
+        )
+    return value
+
+
 class Speaker(BaseModel):
     id: int
     name: str
@@ -81,6 +93,6 @@ class Config(BaseModel):
     output_folder: str
     text_to_text_model: Annotated[str, AfterValidator(validate_text_to_text_model)]
     text_to_text_prompt: Annotated[str, AfterValidator(validate_text_to_text_prompt)]
-    text_to_speech_model: Literal[tuple(list(SUPPORTED_TTS_MODELS.keys()))]
+    text_to_speech_model: Annotated[str, AfterValidator(validate_text_to_speech_model)]
     speakers: list[Speaker]
     outetts_language: str = "en"
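
A minimal sketch of what the validator catches, assuming "suno/bark" is present in both registries and using a placeholder id for the failure case:

    from document_to_podcast.config import validate_text_to_speech_model

    validate_text_to_speech_model("suno/bark")  # registered in both dicts: returns the id unchanged

    # A hypothetical id missing from TTS_LOADERS or TTS_INFERENCE raises a
    # ValueError naming the file where the missing function should be defined.
    validate_text_to_speech_model("my-org/unsupported-tts")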

From 28b02b86e9a046b6b286bc17607cf266e4ba2428 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 11:57:26 +0000
Subject: [PATCH 17/26] Decouple loading & inference of TTS model

---
 .../inference/model_loaders.py                | 173 ++++++++----------
 1 file changed, 75 insertions(+), 98 deletions(-)

diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py
index 4560932..5f47fb1 100644
--- a/src/document_to_podcast/inference/model_loaders.py
+++ b/src/document_to_podcast/inference/model_loaders.py
@@ -2,21 +2,13 @@
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 from outetts import GGUFModelConfig_v1, InterfaceGGUF
+from dataclasses import dataclass, field
 from transformers import (
     AutoTokenizer,
     PreTrainedModel,
     AutoProcessor,
     BarkModel,
 )
-import numpy as np
-
-from document_to_podcast.inference.text_to_speech import (
-    _text_to_speech_oute,
-    _text_to_speech_bark,
-    _text_to_speech_parler,
-    _text_to_speech_parler_multi,
-    _text_to_speech_parler_indic,
-)
 
 
 def load_llama_cpp_model(model_id: str) -> Llama:
@@ -44,65 +36,29 @@ def load_llama_cpp_model(model_id: str) -> Llama:
     return model
 
 
-class TTSInterface:
+@dataclass
+class TTSModel:
     """
     The purpose of this class is to provide a unified interface for all the TTS models supported.
-    Specifically, different TTS model families have different peculiarities, for example, the bark model needs a
+    Specifically, different TTS model families have different peculiarities, for example, the bark models need a
     BarkProcessor, the parler models need their own tokenizer, etc. This wrapper takes care of this complexity so that
     the user doesn't have to deal with it.
-    """
-
-    def __init__(
-        self,
-        model: Union[InterfaceGGUF, BarkModel, PreTrainedModel],
-        model_id: str,
-        sample_rate: int,
-        **kwargs,
-    ):
-        self.model = model
-        self.model_id = model_id
-        self.sample_rate = sample_rate
-        self.kwargs = kwargs
-
-    def text_to_speech(
-        self, input_text: str, voice_profile: str, **kwargs
-    ) -> np.ndarray:
-        """
-        Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model.
-
-        Args:
-            input_text (str): The text to convert to speech.
-            voice_profile (str): Depending on the selected TTS model it should either be
-                - a pre-defined ID for the Oute models (e.g. "female_1")
-                more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers
-                - a pre-defined ID for the Bark model (e.g. "v2/en_speaker_0")
-                more info here https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
-                - a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm)
-                more info here https://github.com/huggingface/parler-tts?tab=readme-ov-file#-using-a-specific-speaker
-                for the multilingual model: https://huggingface.co/parler-tts/parler-tts-mini-multilingual-v1.1
-                for the indic model: https://huggingface.co/ai4bharat/indic-parler-tts
-        Returns:
-            numpy array: The waveform of the speech as a 2D numpy array
-        """
-        return SUPPORTED_TTS_MODELS[self.model_id][1](
-            input_text, self.model, voice_profile, **self.kwargs | kwargs
-        )
-
-
-def load_tts_model(model_id: str, **kwargs) -> TTSInterface:
-    """
 
     Args:
-        model_id:
-        outetts_language:
-
-    Returns:
-
+        model (Union[InterfaceGGUF, BarkModel, PreTrainedModel]): A TTS model that has a .generate() method or similar
+            that takes text as input and returns audio as a numpy array.
+        model_id (str): The model's identifier string.
+        sample_rate (int): The sample rate of the audio, required for properly saving the audio to a file.
+        custom_args (dict): Any model-specific arguments that a TTS model might require, e.g. tokenizer.
     """
-    return SUPPORTED_TTS_MODELS[model_id][0](model_id, **kwargs)
 
+    model: Union[InterfaceGGUF, BarkModel, PreTrainedModel]
+    model_id: str
+    sample_rate: int
+    custom_args: dict = field(default_factory=dict)
 
-def _load_oute_tts(model_id: str, **kwargs) -> TTSInterface:
+
+def _load_oute_tts(model_id: str, **kwargs) -> TTSModel:
     """
     Loads the given model_id using the OuteTTS interface. For more info: https://github.com/edwko/OuteTTS
 
@@ -111,7 +67,7 @@ def _load_oute_tts(model_id: str, **kwargs) -> TTSInterface:
             Format is expected to be `{org}/{repo}/{filename}`.
         language (str): Supported languages in 0.2-500M: en, zh, ja, ko.
     Returns:
-        TTSInterface: The loaded model using the TTSModelWrapper.
+        TTSModel: The loaded model using the TTSModel wrapper.
     """
     model_version = model_id.split("-")[1]
 
@@ -122,32 +78,32 @@ def _load_oute_tts(model_id: str, **kwargs) -> TTSInterface:
     )
     model = InterfaceGGUF(model_version=model_version, cfg=model_config)
 
-    return TTSInterface(
-        model=model, model_id=model_id, sample_rate=model.audio_codec.sr
-    )
+    return TTSModel(model=model, model_id=model_id, sample_rate=model.audio_codec.sr)
 
 
-def _load_bark_tts(model_id: str, **kwargs) -> TTSInterface:
+def _load_bark_tts(model_id: str, **kwargs) -> TTSModel:
     """
     Loads the given model_id and its required processor. For more info: https://github.com/suno-ai/bark
 
     Args:
         model_id (str): The model id to load.
             Format is expected to be `{repo}/{filename}`.
-
     Returns:
-        TTSInterface: The loaded model with its required processor using the TTSModelWrapper.
+        TTSModel: The loaded model with its required processor, wrapped in a TTSModel.
     """
 
     processor = AutoProcessor.from_pretrained(model_id)
     model = BarkModel.from_pretrained(model_id)
 
-    return TTSInterface(
-        model=model, model_id=model_id, sample_rate=24_000, bark_processor=processor
+    return TTSModel(
+        model=model,
+        model_id=model_id,
+        sample_rate=24_000,
+        custom_args={"processor": processor},
     )
 
 
-def _load_parler_tts(model_id: str, **kwargs) -> TTSInterface:
+def _load_parler_tts(model_id: str, **kwargs) -> TTSModel:
     """
     Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts
 
@@ -156,47 +112,68 @@ def _load_parler_tts(model_id: str, **kwargs) -> TTSInterface:
             Format is expected to be `{repo}/{filename}`.
 
     Returns:
-        TTSInterface: The loaded model with its required tokenizer for the input. For the multilingual models we also
-        load another tokenizer for the description
+        TTSModel: The loaded model with its required tokenizer for the input.
     """
     from parler_tts import ParlerTTSForConditionalGeneration
 
     model = ParlerTTSForConditionalGeneration.from_pretrained(model_id)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    description_tokenizer = (
-        AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)
-        if model_id == "parler-tts/parler-tts-mini-multilingual-v1.1"
-        or model_id == "ai4bharat/indic-parler-tts"
-        else None
+    return TTSModel(
+        model=model,
+        model_id=model_id,
+        sample_rate=model.config.sampling_rate,
+        custom_args={
+            "tokenizer": tokenizer,
+        },
+    )
+
+
+def _load_parler_tts_multi(model_id: str, **kwargs) -> TTSModel:
+    """
+    Loads the given model_id using parler_tts.from_pretrained. For more info: https://github.com/huggingface/parler-tts
+
+    Args:
+        model_id (str): The model id to load.
+            Format is expected to be `{repo}/{filename}`.
+
+    Returns:
+        TTSModel: The loaded model with its required tokenizer for the input text and
+            another tokenizer for the description.
+    """
+
+    from parler_tts import ParlerTTSForConditionalGeneration
+
+    model = ParlerTTSForConditionalGeneration.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    description_tokenizer = AutoTokenizer.from_pretrained(
+        model.config.text_encoder._name_or_path
     )
 
-    return TTSInterface(
+    return TTSModel(
         model=model,
         model_id=model_id,
         sample_rate=model.config.sampling_rate,
-        parler_tokenizer=tokenizer,
-        parler_description_tokenizer=description_tokenizer,
+        custom_args={
+            "tokenizer": tokenizer,
+            "description_tokenizer": description_tokenizer,
+        },
     )
 
 
-SUPPORTED_TTS_MODELS = {
-    # To add support for your model, add it here in the format {model_id} : [_load_function, _text_to_speech_function]
-    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": [
-        _load_oute_tts,
-        _text_to_speech_oute,
-    ],
-    "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": [
-        _load_oute_tts,
-        _text_to_speech_oute,
-    ],
-    "suno/bark": [_load_bark_tts, _text_to_speech_bark],
-    "parler-tts/parler-tts-large-v1": [_load_parler_tts, _text_to_speech_parler],
-    "parler-tts/parler-tts-mini-v1": [_load_parler_tts, _text_to_speech_parler],
-    "parler-tts/parler-tts-mini-v1.1": [_load_parler_tts, _text_to_speech_parler],
-    "parler-tts/parler-tts-mini-multilingual-v1.1": [
-        _load_parler_tts,
-        _text_to_speech_parler_multi,
-    ],
-    "ai4bharat/indic-parler-tts": [_load_parler_tts, _text_to_speech_parler_indic],
+TTS_LOADERS = {
+    # To add support for your model, add it here in the format {model_id} : _load_function
+    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": _load_oute_tts,
+    "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": _load_oute_tts,
+    "suno/bark": _load_bark_tts,
+    "suno/bark-small": _load_bark_tts,
+    "parler-tts/parler-tts-large-v1": _load_parler_tts,
+    "parler-tts/parler-tts-mini-v1": _load_parler_tts,
+    "parler-tts/parler-tts-mini-v1.1": _load_parler_tts,
+    "parler-tts/parler-tts-mini-multilingual-v1.1": _load_parler_tts_multi,
+    "ai4bharat/indic-parler-tts": _load_parler_tts_multi,
 }
+
+
+def load_tts_model(model_id: str, **kwargs) -> TTSModel:
+    """Look up the loader registered for model_id in TTS_LOADERS and return the wrapped model."""
+    return TTS_LOADERS[model_id](model_id, **kwargs)
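
A minimal sketch of the loader side after the split, assuming the bark entries registered above:

    from document_to_podcast.inference.model_loaders import load_tts_model

    tts_model = load_tts_model("suno/bark-small")
    tts_model.model_id     # "suno/bark-small"
    tts_model.sample_rate  # 24000
    tts_model.custom_args  # {"processor": <BarkProcessor>}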

From b489e0db1fb00e6d2424c5a628a6c7483fc59257 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 12:24:38 +0000
Subject: [PATCH 18/26] Decouple loading & inference of TTS model

---
 .../inference/text_to_speech.py               | 123 +++++++++---------
 1 file changed, 59 insertions(+), 64 deletions(-)

diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py
index f4477a5..81689e7 100644
--- a/src/document_to_podcast/inference/text_to_speech.py
+++ b/src/document_to_podcast/inference/text_to_speech.py
@@ -2,6 +2,8 @@
 from outetts.version.v1.interface import InterfaceGGUF
 from transformers import PreTrainedModel, BarkModel
 
+from document_to_podcast.inference.model_loaders import TTSModel
+
 
 def _text_to_speech_oute(
     input_text: str,
@@ -10,18 +12,20 @@ def _text_to_speech_oute(
     **kwargs,
 ) -> np.ndarray:
     """
-
+    TTS generation function for the Oute TTS model family.
     Args:
-        input_text:
-        model:
-        voice_profile:
-
-        temperature: float = 0.3:
-        repetition_penalty: float = 1.1
-        max_length: int = 4096
+        input_text (str): The text to convert to speech.
+        model: A model from the Oute TTS family.
+        voice_profile: a pre-defined ID for the Oute models (e.g. "female_1")
+            more info here https://github.com/edwko/OuteTTS/tree/main/outetts/version/v1/default_speakers
+        temperature (float, default = 0.3): Controls the randomness of predictions by scaling the logits.
+            Lower values make the output more focused and deterministic, higher values produce more diverse results.
+        repetition_penalty (float, default = 1.1): Applies a penalty to tokens that have already been generated,
+            reducing the likelihood of repetition and enhancing text variety.
+        max_length (int, default = 4096): Defines the maximum number of tokens for the generated text sequence.
 
     Returns:
-
+        numpy array: The waveform of the speech as a 2D numpy array
     """
     speaker = model.load_default_speaker(name=voice_profile)
 
@@ -41,17 +45,19 @@ def _text_to_speech_bark(
     input_text: str, model: BarkModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
     """
-
+    TTS generation function for the Bark model family.
     Args:
-        input_test:
-        model:
-        voice_profile:
-        processor: BarkProcessor
+        input_text (str): The text to convert to speech.
+        model: A model from the Bark family.
+        voice_profile: a pre-defined ID for the Bark model (e.g. "v2/en_speaker_0")
+            more info here https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
+        processor: Required BarkProcessor to prepare the input text for the Bark model
 
     Returns:
-
+        numpy array: The waveform of the speech as a 2D numpy array
     """
-    if not (processor := kwargs.pop("processor", None)):
+    processor = kwargs.get("processor")
+    if processor is None:
         raise ValueError("Bark model requires a processor")
 
     inputs = processor(input_text, voice_preset=voice_profile)
@@ -66,19 +72,21 @@ def _text_to_speech_parler(
     input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
     """
-
+    TTS generation function for the Parler TTS model family.
     Args:
-        input_text:
-        model:
-        voice_profile:
-        tokenizer: PreTrainedTokenizer
+        input_text (str): The text to convert to speech.
+        model: A model from the Parler TTS family.
+        voice_profile: a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm)
+            more info here https://github.com/huggingface/parler-tts?tab=readme-ov-file#-using-a-specific-speaker
+        tokenizer (PreTrainedTokenizer): Required PreTrainedTokenizer to tokenize the input text.
 
     Returns:
-
+        numpy array: The waveform of the speech as a 2D numpy array
     """
-
-    if not (tokenizer := kwargs.pop("parler_tokenizer", None)):
+    tokenizer = kwargs.get("tokenizer")
+    if tokenizer is None:
         raise ValueError("Parler model requires a tokenizer")
+
     input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
     prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
 
@@ -92,22 +100,25 @@ def _text_to_speech_parler_multi(
     input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
 ) -> np.ndarray:
     """
-
+    TTS generation function for the Parler TTS multilingual model family.
     Args:
-        input_text:
-        model:
-        voice_profile:
-        tokenizer:
-        description_tokenizer:
+        input_text (str): The text to convert to speech.
+        model: A model from the Parler TTS multilingual family.
+        voice_profile: a natural description of the voice profile using a pre-defined name for the Parler model (e.g. Laura's voice is calm)
+            more info here https://huggingface.co/parler-tts/parler-tts-mini-multilingual-v1.1
+            for the indic version https://huggingface.co/ai4bharat/indic-parler-tts
+        tokenizer (PreTrainedTokenizer): Required PreTrainedTokenizer to tokenize the input text.
+        description_tokenizer (PreTrainedTokenizer): Required PreTrainedTokenizer to tokenize the description text.
 
     Returns:
-
+        numpy array: The waveform of the speech as a 2D numpy array
     """
-
-    if not (tokenizer := kwargs.pop("parler_tokenizer", None)):
+    tokenizer = kwargs.get("tokenizer")
+    if tokenizer is None:
         raise ValueError("Parler model requires a tokenizer")
 
-    if not (description_tokenizer := kwargs.pop("parler_description_tokenizer", None)):
+    description_tokenizer = kwargs.get("description_tokenizer")
+    if description_tokenizer is None:
         raise ValueError("Parler multilingual model requires a description tokenizer")
 
     input_ids = description_tokenizer(voice_profile, return_tensors="pt").input_ids
@@ -115,39 +126,23 @@ def _text_to_speech_parler_multi(
 
     generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
     waveform = generation.cpu().numpy().squeeze()
-
     return waveform
 
 
-def _text_to_speech_parler_indic(
-    input_text: str, model: PreTrainedModel, voice_profile: str, **kwargs
-) -> np.ndarray:
-    """
-
-    Args:
-        input_text:
-        model:
-        voice_profile:
-        tokenizer:
-        description_tokenizer:
+TTS_INFERENCE = {
+    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf": _text_to_speech_oute,
+    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": _text_to_speech_oute,
+    "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": _text_to_speech_oute,
+    "suno/bark": _text_to_speech_bark,
+    "parler-tts/parler-tts-large-v1": _text_to_speech_parler,
+    "parler-tts/parler-tts-mini-v1": _text_to_speech_parler,
+    "parler-tts/parler-tts-mini-v1.1": _text_to_speech_parler,
+    "parler-tts/parler-tts-mini-multilingual-v1.1": _text_to_speech_parler_multi,
+    "ai4bharat/indic-parler-tts": _text_to_speech_parler_multi,
+}
 
-    Returns:
 
-    """
-    if not (tokenizer := kwargs.pop("parler_tokenizer", None)):
-        raise ValueError("Parler model requires a tokenizer")
-    if not (description_tokenizer := kwargs.pop("parler_description_tokenizer", None)):
-        raise ValueError("Parler multilingual model requires a description tokenizer")
-
-    input_ids = description_tokenizer(voice_profile, return_tensors="pt").input_ids
-    prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids
-
-    generation = model.generate(
-        input_ids=input_ids.input_ids,
-        attention_mask=input_ids.attention_mask,
-        prompt_input_ids=prompt_input_ids.input_ids,
-        prompt_attention_mask=prompt_input_ids.attention_mask,
+def text_to_speech(input_text: str, model: TTSModel, voice_profile: str) -> np.ndarray:
+    """Generate a speech waveform by dispatching to the inference function registered for the model."""
+    return TTS_INFERENCE[model.model_id](
+        input_text, model.model, voice_profile, **model.custom_args
     )
-    waveform = generation.cpu().numpy().squeeze()
-
-    return waveform
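
Together with the loaders, the dispatch table gives a compact end-to-end path; a minimal sketch, assuming soundfile for writing the waveform, as the demo and CLI do:

    import soundfile as sf

    from document_to_podcast.inference.model_loaders import load_tts_model
    from document_to_podcast.inference.text_to_speech import text_to_speech

    model = load_tts_model("suno/bark")
    waveform = text_to_speech("Welcome to our podcast!", model, "v2/en_speaker_0")
    sf.write("speech.wav", waveform, model.sample_rate)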

From dc89668afb3bb4e80ebfd6815b3dc1ea05cc2187 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 12:25:37 +0000
Subject: [PATCH 19/26] Enable user to exit podcast generation gracefully

---
 src/document_to_podcast/cli.py | 44 ++++++++++++++++++----------------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/src/document_to_podcast/cli.py b/src/document_to_podcast/cli.py
index 9532265..f75f6e7 100644
--- a/src/document_to_podcast/cli.py
+++ b/src/document_to_podcast/cli.py
@@ -130,27 +130,29 @@ def document_to_podcast(
     system_prompt = system_prompt.replace(
         "{SPEAKERS}", "\n".join(str(speaker) for speaker in config.speakers)
     )
-    for chunk in text_to_text_stream(
-        clean_text, text_model, system_prompt=system_prompt
-    ):
-        text += chunk
-        podcast_script += chunk
-        if text.endswith("\n") and "Speaker" in text:
-            logger.debug(text)
-            speaker_id = re.search(r"Speaker (\d+)", text).group(1)
-            voice_profile = next(
-                speaker.voice_profile
-                for speaker in config.speakers
-                if speaker.id == int(speaker_id)
-            )
-            speech = text_to_speech(
-                text.split(f'"Speaker {speaker_id}":')[-1],
-                speech_model,
-                voice_profile,
-            )
-            podcast_audio.append(speech)
-            text = ""
-
+    try:
+        for chunk in text_to_text_stream(
+            clean_text, text_model, system_prompt=system_prompt
+        ):
+            text += chunk
+            podcast_script += chunk
+            if text.endswith("\n") and "Speaker" in text:
+                logger.debug(text)
+                speaker_id = re.search(r"Speaker (\d+)", text).group(1)
+                voice_profile = next(
+                    speaker.voice_profile
+                    for speaker in config.speakers
+                    if speaker.id == int(speaker_id)
+                )
+                speech = text_to_speech(
+                    text.split(f'"Speaker {speaker_id}":')[-1],
+                    speech_model,
+                    voice_profile,
+                )
+                podcast_audio.append(speech)
+                text = ""
+    except KeyboardInterrupt:
+        logger.warning("Podcast generation stopped by user.")
     logger.info("Saving Podcast...")
     sf.write(
         str(output_folder / "podcast.wav"),

From 0d143eb89d884421198afee1c26fd8adb0f1e2f1 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 13:11:52 +0000
Subject: [PATCH 20/26] Add Q2 Oute version to TTS_LOADERS

---
 src/document_to_podcast/inference/model_loaders.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/document_to_podcast/inference/model_loaders.py b/src/document_to_podcast/inference/model_loaders.py
index 5f47fb1..33b0b73 100644
--- a/src/document_to_podcast/inference/model_loaders.py
+++ b/src/document_to_podcast/inference/model_loaders.py
@@ -78,7 +78,9 @@ def _load_oute_tts(model_id: str, **kwargs) -> TTSModel:
     )
     model = InterfaceGGUF(model_version=model_version, cfg=model_config)
 
-    return TTSModel(model=model, model_id=model_id, sample_rate=model.audio_codec.sr)
+    return TTSModel(
+        model=model, model_id=model_id, sample_rate=model.audio_codec.sr, custom_args={}
+    )
 
 
 def _load_bark_tts(model_id: str, **kwargs) -> TTSModel:
@@ -163,6 +165,7 @@ def _load_parler_tts_multi(model_id: str, **kwargs) -> TTSModel:
 
 TTS_LOADERS = {
     # To add support for your model, add it here in the format {model_id} : _load_function
+    "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf": _load_oute_tts,
     "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": _load_oute_tts,
     "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": _load_oute_tts,
     "suno/bark": _load_bark_tts,

From e9ca49857dc881730b2225bb88eb490111e47000 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 13:12:10 +0000
Subject: [PATCH 21/26] Add comment for support in TTS_INFERENCE

---
 src/document_to_podcast/inference/text_to_speech.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py
index 81689e7..de5e006 100644
--- a/src/document_to_podcast/inference/text_to_speech.py
+++ b/src/document_to_podcast/inference/text_to_speech.py
@@ -130,6 +130,7 @@ def _text_to_speech_parler_multi(
 
 
 TTS_INFERENCE = {
+    # To add support for your model, add it here in the format {model_id} : _inference_function
     "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf": _text_to_speech_oute,
     "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-FP16.gguf": _text_to_speech_oute,
     "OuteAI/OuteTTS-0.2-500M-GGUF/OuteTTS-0.2-500M-FP16.gguf": _text_to_speech_oute,

From 47112a0df323bc998b1e3b78723644e8184ad86e Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 13:12:27 +0000
Subject: [PATCH 22/26] Update test_model_loaders.py

---
 tests/unit/inference/test_model_loaders.py | 57 ++++++++++++++--------
 1 file changed, 38 insertions(+), 19 deletions(-)

diff --git a/tests/unit/inference/test_model_loaders.py b/tests/unit/inference/test_model_loaders.py
index 2c11629..e337ff5 100644
--- a/tests/unit/inference/test_model_loaders.py
+++ b/tests/unit/inference/test_model_loaders.py
@@ -1,15 +1,19 @@
+from typing import Dict, Any, Union
+
+import pytest
 from llama_cpp import Llama
 
 from document_to_podcast.inference.model_loaders import (
     load_llama_cpp_model,
-    load_outetts_model,
+    load_tts_model,
 )
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
-from outetts.version.v1.interface import InterfaceGGUF
-
-from document_to_podcast.inference.model_loaders import (
-    load_parler_tts_model_and_tokenizer,
+from transformers import (
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    BarkModel,
+    BarkProcessor,
 )
+from outetts.version.v1.interface import InterfaceGGUF
 
 
 def test_load_llama_cpp_model():
@@ -21,16 +25,31 @@ def test_load_llama_cpp_model():
     assert model.n_ctx() == 2048
 
 
-def test_load_outetts_model():
-    model = load_outetts_model(
-        "OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf"
-    )
-    assert isinstance(model, InterfaceGGUF)
-
-
-def test_load_parler_tts_model_and_tokenizer():
-    model, tokenizer = load_parler_tts_model_and_tokenizer(
-        "parler-tts/parler-tts-mini-v1"
-    )
-    assert isinstance(model, PreTrainedModel)
-    assert isinstance(tokenizer, PreTrainedTokenizerBase)
+@pytest.mark.parametrize(
+    "model_id, expected_model_type, expected_custom_args",
+    [
+        ["OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf", InterfaceGGUF, {}],
+        ["suno/bark-small", BarkModel, {"processor": BarkProcessor}],
+        [
+            "parler-tts/parler-tts-mini-v1.1",
+            PreTrainedModel,
+            {
+                "tokenizer": PreTrainedTokenizerBase,
+                "description_tokenizer": PreTrainedTokenizerBase,
+            },
+        ],
+    ],
+)
+def test_load_tts_model(
+    model_id: str,
+    expected_model_type: Union[InterfaceGGUF, BarkModel, PreTrainedModel],
+    expected_custom_args: Dict[str, Any],
+) -> None:
+    model = load_tts_model(model_id)
+    assert isinstance(model.model, expected_model_type)
+    assert model.model_id == model_id
+    assert model.custom_args.keys() == expected_custom_args.keys()
+    for key, value in model.custom_args.items():
+        assert isinstance(value, expected_custom_args[key])

From ec0fe5abbe01243594d5adc31f7c714571e53548 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Thu, 19 Dec 2024 13:12:41 +0000
Subject: [PATCH 23/26] Update test_text_to_speech.py

---
 tests/unit/inference/test_text_to_speech.py | 81 +++++++++++++++++++--
 1 file changed, 75 insertions(+), 6 deletions(-)

diff --git a/tests/unit/inference/test_text_to_speech.py b/tests/unit/inference/test_text_to_speech.py
index 7ae239c..7d3a1b4 100644
--- a/tests/unit/inference/test_text_to_speech.py
+++ b/tests/unit/inference/test_text_to_speech.py
@@ -1,16 +1,24 @@
 from outetts.version.v1.interface import InterfaceGGUF
 from transformers import PreTrainedModel
 
+from document_to_podcast.inference.model_loaders import TTSModel
 from document_to_podcast.inference.text_to_speech import text_to_speech
 
 
 def test_text_to_speech_oute(mocker):
     model = mocker.MagicMock(spec_set=InterfaceGGUF)
-    text_to_speech(
-        "Hello?",
+    tts_model = TTSModel(
         model=model,
+        model_id="OuteAI/OuteTTS-0.1-350M-GGUF/OuteTTS-0.1-350M-Q2_K.gguf",
+        sample_rate=0,
+        custom_args={},
+    )
+    text_to_speech(
+        input_text="Hello?",
+        model=tts_model,
         voice_profile="female_1",
     )
+
     model.load_default_speaker.assert_called_with(name=mocker.ANY)
     model.generate.assert_called_with(
         text=mocker.ANY,
@@ -21,18 +29,79 @@ def test_text_to_speech_oute(mocker):
     )
 
 
+def test_text_to_speech_bark(mocker):
+    model = mocker.MagicMock(spec_set=PreTrainedModel)
+    processor = mocker.MagicMock()
+
+    tts_model = TTSModel(
+        model=model,
+        model_id="suno/bark",
+        sample_rate=0,
+        custom_args={"processor": processor},
+    )
+    text_to_speech(
+        input_text="Hello?",
+        model=tts_model,
+        voice_profile="v2/en_speaker_0",
+    )
+    processor.assert_has_calls(
+        [
+            mocker.call("Hello?", voice_preset="v2/en_speaker_0"),
+        ]
+    )
+    model.generate.assert_called_with()
+
+
 def test_text_to_speech_parler(mocker):
     model = mocker.MagicMock(spec_set=PreTrainedModel)
     tokenizer = mocker.MagicMock()
+
+    tts_model = TTSModel(
+        model=model,
+        model_id="parler-tts/parler-tts-mini-v1.1",
+        sample_rate=0,
+        custom_args={"tokenizer": tokenizer},
+    )
     text_to_speech(
-        "Hello?",
+        input_text="Hello?",
+        model=tts_model,
+        voice_profile="Laura's voice is calm",
+    )
+    tokenizer.assert_has_calls(
+        [
+            mocker.call("Laura's voice is calm", return_tensors="pt"),
+            mocker.call("Hello?", return_tensors="pt"),
+        ]
+    )
+    model.generate.assert_called_with(input_ids=mocker.ANY, prompt_input_ids=mocker.ANY)
+
+
+def test_text_to_speech_parler_multi(mocker):
+    model = mocker.MagicMock(spec_set=PreTrainedModel)
+    tokenizer = mocker.MagicMock()
+    description_tokenizer = mocker.MagicMock()
+
+    tts_model = TTSModel(
         model=model,
-        tokenizer=tokenizer,
-        voice_profile="default",
+        model_id="ai4bharat/indic-parler-tts",
+        sample_rate=0,
+        custom_args={
+            "tokenizer": tokenizer,
+            "description_tokenizer": description_tokenizer,
+        },
+    )
+    text_to_speech(
+        input_text="Hello?",
+        model=tts_model,
+        voice_profile="Laura's voice is calm",
+    )
+    description_tokenizer.assert_has_calls(
+        [
+            mocker.call("Laura's voice is calm", return_tensors="pt"),
+        ]
     )
     tokenizer.assert_has_calls(
         [
-            mocker.call("default", return_tensors="pt"),
             mocker.call("Hello?", return_tensors="pt"),
         ]
     )

From 827dd1d8793a39933fee2e9f31e594a2d5ba3822 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Mon, 30 Dec 2024 13:56:40 +0200
Subject: [PATCH 24/26] Remove extra "use case" examples

---
 example_data/config_bark.yaml         | 31 ---------------------------
 example_data/config_parler.yaml       | 31 ---------------------------
 example_data/config_parler_multi.yaml | 31 ---------------------------
 3 files changed, 93 deletions(-)
 delete mode 100644 example_data/config_bark.yaml
 delete mode 100644 example_data/config_parler.yaml
 delete mode 100644 example_data/config_parler_multi.yaml

diff --git a/example_data/config_bark.yaml b/example_data/config_bark.yaml
deleted file mode 100644
index 77712c8..0000000
--- a/example_data/config_bark.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-input_file: "example_data/a.md"
-output_folder: "example_data/bark/"
-text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
-text_to_speech_model: "suno/bark"
-text_to_text_prompt: |
-  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format. 
-  The script features the following speakers:
-  {SPEAKERS}
-  Instructions:
-  - Write dynamic, easy-to-follow dialogue.
-  - Include natural interruptions and interjections.
-  - Avoid repetitive phrasing between speakers.
-  - Format output as a JSON conversation.
-  Example:
-  {
-    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
-    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
-    "Speaker 1": "Sure! Imagine it like this...",
-    "Speaker 2": "Oh, that's cool! But how does..."
-  }
-speakers:
-  - id: 1
-    name: Laura
-    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
-    voice_profile: "v2/en_speaker_0"
-
-  - id: 2
-    name: Daniel
-    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
-    voice_profile: "v2/en_speaker_1"
-outetts_language: "en"  # Supported languages in version 0.2-500M: en, zh, ja, ko.
\ No newline at end of file
diff --git a/example_data/config_parler.yaml b/example_data/config_parler.yaml
deleted file mode 100644
index e8240e1..0000000
--- a/example_data/config_parler.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-input_file: "example_data/a.md"
-output_folder: "example_data/parler/"
-text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
-text_to_speech_model: "parler-tts/parler-tts-mini-v1.1"
-text_to_text_prompt: |
-  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format. 
-  The script features the following speakers:
-  {SPEAKERS}
-  Instructions:
-  - Write dynamic, easy-to-follow dialogue.
-  - Include natural interruptions and interjections.
-  - Avoid repetitive phrasing between speakers.
-  - Format output as a JSON conversation.
-  Example:
-  {
-    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
-    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
-    "Speaker 1": "Sure! Imagine it like this...",
-    "Speaker 2": "Oh, that's cool! But how does..."
-  }
-speakers:
-  - id: 1
-    name: Laura
-    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
-    voice_profile: Laura's voice is calm and slow in delivery, with no background noise.
-
-  - id: 2
-    name: Daniel
-    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
-    voice_profile: Daniel's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.
-outetts_language: "en"  # Supported languages in version 0.2-500M: en, zh, ja, ko.
\ No newline at end of file
diff --git a/example_data/config_parler_multi.yaml b/example_data/config_parler_multi.yaml
deleted file mode 100644
index 656ca90..0000000
--- a/example_data/config_parler_multi.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-input_file: "example_data/a.md"
-output_folder: "example_data/parler_multi/"
-text_to_text_model: "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf"
-text_to_speech_model: "parler-tts/parler-tts-mini-multilingual-v1.1"
-text_to_text_prompt: |
-  You are a podcast scriptwriter generating engaging and natural-sounding conversations in JSON format. 
-  The script features the following speakers:
-  {SPEAKERS}
-  Instructions:
-  - Write dynamic, easy-to-follow dialogue.
-  - Include natural interruptions and interjections.
-  - Avoid repetitive phrasing between speakers.
-  - Format output as a JSON conversation.
-  Example:
-  {
-    "Speaker 1": "Welcome to our podcast! Today, we're exploring...",
-    "Speaker 2": "Hi! I'm excited to hear about this. Can you explain...",
-    "Speaker 1": "Sure! Imagine it like this...",
-    "Speaker 2": "Oh, that's cool! But how does..."
-  }
-speakers:
-  - id: 1
-    name: Laura
-    description: The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.
-    voice_profile: Laura's voice is calm and slow in delivery, with no background noise.
-
-  - id: 2
-    name: Daniel
-    description: The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.
-    voice_profile: Daniel's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.
-outetts_language: "en"  # Supported languages in version 0.2-500M: en, zh, ja, ko.
\ No newline at end of file

From 0abbbdb03ebe895c5f854489bcc8e85f825e4686 Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Mon, 30 Dec 2024 14:23:45 +0200
Subject: [PATCH 25/26] Add bark to readme & note about multilingual support

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 0ea7348..8501428 100644
--- a/README.md
+++ b/README.md
@@ -8,14 +8,14 @@
 # Document-to-podcast: a Blueprint by Mozilla.ai for generating podcasts from documents using local AI
 
 This blueprint demonstrates how you can use open-source models & tools to convert input documents into a podcast featuring two speakers.
-It is designed to work on most local setups or with [GitHub Codespaces](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=888426876&skip_quickstart=true&machine=standardLinux32gb), meaning no external API calls or GPU access is required. This makes it more accessible and privacy-friendly by keeping everything local.
+It is designed to work on most local setups or with [GitHub Codespaces](https://github.com/codespaces/new?hide_repo_select=true&ref=main&repo=888426876&skip_quickstart=true&machine=standardLinux32gb), meaning no external API calls or GPU access is required. This makes it more accessible and privacy-friendly by keeping everything local. Moreover, we have added support for models that can generate podcasts (both script and audio) in multiple languages!
 
 ### 👉 📖 For more detailed guidance on using this project, please visit our [Docs here](https://mozilla-ai.github.io/document-to-podcast/).
 
 ### Built with
 - Python 3.10+ (use Python 3.12 for Apple M1/2/3 chips)
 - [Llama-cpp](https://github.com/abetlen/llama-cpp-python) (text-to-text, i.e script generation)
-- [OuteAI](https://github.com/edwko/OuteTTS) / [Parler_tts](https://github.com/huggingface/parler-tts) (text-to-speech, i.e audio generation)
+- [OuteAI](https://github.com/edwko/OuteTTS) / [Bark](https://github.com/suno-ai/bark) / [Parler_tts](https://github.com/huggingface/parler-tts) (text-to-speech, i.e audio generation)
 - [Streamlit](https://streamlit.io/) (UI demo)
 
 

From d81d933a26128fea91999433b7c1d546d0e6eacf Mon Sep 17 00:00:00 2001
From: Kostis-S-Z <kostissz@pm.me>
Date: Mon, 30 Dec 2024 14:24:04 +0200
Subject: [PATCH 26/26] Reference a repo that showcases multilingual use cases

---
 docs/customization.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/customization.md b/docs/customization.md
index 0461977..b798abc 100644
--- a/docs/customization.md
+++ b/docs/customization.md
@@ -74,6 +74,7 @@ Looking for inspiration? Check out these examples of how others have customized
 
 - **[Radio Drama Generator](https://github.com/stefanfrench/radio-drama-generator)**: A creative adaptation that generates radio dramas by customizing the Blueprint parameters.
 - **[Readme-to-Podcast](https://github.com/alexmeckes/readme-to-podcast)**: This project transforms GitHub README files into podcast-style audio, showcasing the Blueprint’s ability to handle diverse text inputs.
+- **[Multilingual Podcast](https://github.com/Kostis-S-Z/document-to-podcast/)**: A repo that showcases how to use this package in other languages, such as Hindi, Polish, Korean, and many more.
 
 ## 🤝 **Contributing to the Blueprint**