diff --git a/src/document_to_podcast/cli.py b/src/document_to_podcast/cli.py index 730ce5e..fc4090f 100644 --- a/src/document_to_podcast/cli.py +++ b/src/document_to_podcast/cli.py @@ -8,7 +8,13 @@ from loguru import logger -from document_to_podcast.config import Config, Speaker, DEFAULT_PROMPT, DEFAULT_SPEAKERS +from document_to_podcast.config import ( + Config, + Speaker, + DEFAULT_PROMPT, + DEFAULT_SPEAKERS, + TTS_MODELS, +) from document_to_podcast.inference.model_loaders import ( load_llama_cpp_model, load_parler_tts_model_and_tokenizer, @@ -24,7 +30,7 @@ def document_to_podcast( output_folder: str | None = None, text_to_text_model: str = "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf", text_to_text_prompt: str = DEFAULT_PROMPT, - text_to_speech_model: str = "parler-tts/parler-tts-mini-v1", + text_to_speech_model: TTS_MODELS = "parler-tts/parler-tts-mini-v1", speakers: list[Speaker] | None = None, from_config: str | None = None, ): diff --git a/src/document_to_podcast/config.py b/src/document_to_podcast/config.py index 091b7ea..c57344e 100644 --- a/src/document_to_podcast/config.py +++ b/src/document_to_podcast/config.py @@ -42,6 +42,13 @@ ] +TTS_MODELS = Literal[ + "parler-tts/parler-tts-large-v1", + "parler-tts/parler-tts-mini-v1", + "parler-tts/parler-tts-mini-v1.1", +] + + def validate_input_file(value): if Path(value).suffix not in DATA_LOADERS: raise ValueError( @@ -79,10 +86,6 @@ class Config(BaseModel): input_file: Annotated[FilePath, AfterValidator(validate_input_file)] output_folder: str text_to_text_model: Annotated[str, AfterValidator(validate_text_to_text_model)] - text_to_text_prompt: str - text_to_speech_model: Literal[ - "parler-tts/parler-tts-large-v1", - "parler-tts/parler-tts-mini-v1", - "parler-tts/parler-tts-mini-v1.1", - ] + text_to_text_prompt: Annotated[str, AfterValidator(validate_text_to_text_prompt)] + text_to_speech_model: TTS_MODELS speakers: list[Speaker]