From f8b6b8c52f4537a6d1c1fd9cd1c32a482b1bfc6a Mon Sep 17 00:00:00 2001 From: daavoo Date: Fri, 13 Dec 2024 10:29:25 +0100 Subject: [PATCH] Rename tone to voice_profile --- demo/app.py | 6 +++--- src/document_to_podcast/cli.py | 10 ++++++---- src/document_to_podcast/config.py | 6 +++--- src/document_to_podcast/inference/text_to_speech.py | 6 +++--- tests/unit/inference/test_text_to_speech.py | 2 +- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/demo/app.py b/demo/app.py index d9fac7c..aeed6c0 100644 --- a/demo/app.py +++ b/demo/app.py @@ -138,8 +138,8 @@ def gen_button_clicked(): st.write(st.session_state.script) speaker_id = re.search(r"Speaker (\d+)", text).group(1) - tone = next( - speaker["tone"] + voice_profile = next( + speaker["voice_profile"] for speaker in speakers if speaker["id"] == int(speaker_id) ) @@ -148,7 +148,7 @@ def gen_button_clicked(): text.split(f'"Speaker {speaker_id}":')[-1], speech_model, speech_tokenizer, - tone, + voice_profile, ) st.audio(speech, sample_rate=44100) st.session_state.audio.append(speech) diff --git a/src/document_to_podcast/cli.py b/src/document_to_podcast/cli.py index 18de858..730ce5e 100644 --- a/src/document_to_podcast/cli.py +++ b/src/document_to_podcast/cli.py @@ -127,14 +127,16 @@ def document_to_podcast( if text.endswith("\n") and "Speaker" in text: logger.debug(text) speaker_id = re.search(r"Speaker (\d+)", text).group(1) - tone = next( - speaker for speaker in config.speakers if speaker.id == int(speaker_id) - ).tone + voice_profile = next( + speaker.voice_profile + for speaker in config.speakers + if speaker.id == int(speaker_id) + ) speech = text_to_speech( text.split(f'"Speaker {speaker_id}":')[-1], speech_model, speech_tokenizer, - tone, + voice_profile, ) podcast_audio.append(speech) text = "" diff --git a/src/document_to_podcast/config.py b/src/document_to_podcast/config.py index ff892a4..091b7ea 100644 --- a/src/document_to_podcast/config.py +++ b/src/document_to_podcast/config.py @@ -31,13 +31,13 @@ "id": 1, "name": "Laura", "description": "The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.", - "tone": "Laura's voice is exciting and fast in delivery with very clear audio and no background noise.", + "voice_profile": "Laura's voice is exciting and fast in delivery with very clear audio and no background noise.", }, { "id": 2, "name": "Jon", "description": "The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.", - "tone": "Jon's voice is calm with very clear audio and no background noise.", + "voice_profile": "Jon's voice is calm with very clear audio and no background noise.", }, ] @@ -69,7 +69,7 @@ class Speaker(BaseModel): id: int name: str description: str - tone: str + voice_profile: str def __str__(self): return f"Speaker {self.id}. Named {self.name}. {self.description}" diff --git a/src/document_to_podcast/inference/text_to_speech.py b/src/document_to_podcast/inference/text_to_speech.py index 6435cf6..253e56e 100644 --- a/src/document_to_podcast/inference/text_to_speech.py +++ b/src/document_to_podcast/inference/text_to_speech.py @@ -6,7 +6,7 @@ def text_to_speech( input_text: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, - tone: str, + voice_profile: str, ) -> np.ndarray: """ Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model. @@ -15,11 +15,11 @@ def text_to_speech( input_text (str): The text to convert to speech. model (PreTrainedModel): The model used for generating the waveform. tokenizer (PreTrainedTokenizerBase): The tokenizer used for tokenizing the text in order to send to the model. - tone (str): A description used by the ParlerTTS model to configure the speaker profile. + voice_profile (str): A description used by the ParlerTTS model to configure the voice. Returns: numpy array: The waveform of the speech as a 2D numpy array """ - input_ids = tokenizer(tone, return_tensors="pt").input_ids + input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids) diff --git a/tests/unit/inference/test_text_to_speech.py b/tests/unit/inference/test_text_to_speech.py index 1dfc07b..fbb666b 100644 --- a/tests/unit/inference/test_text_to_speech.py +++ b/tests/unit/inference/test_text_to_speech.py @@ -8,7 +8,7 @@ def test_text_to_speech_parler(mocker): "Hello?", model=model, tokenizer=tokenizer, - tone="default", + voice_profile="default", ) tokenizer.assert_has_calls( [