Skip to content

Commit

Permalink
Add TTS_MODELS Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Dec 13, 2024
1 parent f8b6b8c commit ca42ec6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
10 changes: 8 additions & 2 deletions src/document_to_podcast/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from loguru import logger


from document_to_podcast.config import Config, Speaker, DEFAULT_PROMPT, DEFAULT_SPEAKERS
from document_to_podcast.config import (
Config,
Speaker,
DEFAULT_PROMPT,
DEFAULT_SPEAKERS,
TTS_MODELS,
)
from document_to_podcast.inference.model_loaders import (
load_llama_cpp_model,
load_parler_tts_model_and_tokenizer,
Expand All @@ -24,7 +30,7 @@ def document_to_podcast(
output_folder: str | None = None,
text_to_text_model: str = "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf",
text_to_text_prompt: str = DEFAULT_PROMPT,
text_to_speech_model: str = "parler-tts/parler-tts-mini-v1",
text_to_speech_model: TTS_MODELS = "parler-tts/parler-tts-mini-v1",
speakers: list[Speaker] | None = None,
from_config: str | None = None,
):
Expand Down
15 changes: 9 additions & 6 deletions src/document_to_podcast/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@
]


TTS_MODELS = Literal[
"parler-tts/parler-tts-large-v1",
"parler-tts/parler-tts-mini-v1",
"parler-tts/parler-tts-mini-v1.1",
]


def validate_input_file(value):
if Path(value).suffix not in DATA_LOADERS:
raise ValueError(
Expand Down Expand Up @@ -79,10 +86,6 @@ class Config(BaseModel):
input_file: Annotated[FilePath, AfterValidator(validate_input_file)]
output_folder: str
text_to_text_model: Annotated[str, AfterValidator(validate_text_to_text_model)]
text_to_text_prompt: str
text_to_speech_model: Literal[
"parler-tts/parler-tts-large-v1",
"parler-tts/parler-tts-mini-v1",
"parler-tts/parler-tts-mini-v1.1",
]
text_to_text_prompt: Annotated[str, AfterValidator(validate_text_to_text_prompt)]
text_to_speech_model: TTS_MODELS
speakers: list[Speaker]

0 comments on commit ca42ec6

Please sign in to comment.