Skip to content

Commit

Permalink
Rename tone to voice_profile
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Dec 13, 2024
1 parent 615dfbc commit f8b6b8c
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 14 deletions.
6 changes: 3 additions & 3 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def gen_button_clicked():
st.write(st.session_state.script)

speaker_id = re.search(r"Speaker (\d+)", text).group(1)
tone = next(
speaker["tone"]
voice_profile = next(
speaker["voice_profile"]
for speaker in speakers
if speaker["id"] == int(speaker_id)
)
Expand All @@ -148,7 +148,7 @@ def gen_button_clicked():
text.split(f'"Speaker {speaker_id}":')[-1],
speech_model,
speech_tokenizer,
tone,
voice_profile,
)
st.audio(speech, sample_rate=44100)
st.session_state.audio.append(speech)
Expand Down
10 changes: 6 additions & 4 deletions src/document_to_podcast/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,16 @@ def document_to_podcast(
if text.endswith("\n") and "Speaker" in text:
logger.debug(text)
speaker_id = re.search(r"Speaker (\d+)", text).group(1)
tone = next(
speaker for speaker in config.speakers if speaker.id == int(speaker_id)
).tone
voice_profile = next(
speaker.voice_profile
for speaker in config.speakers
if speaker.id == int(speaker_id)
)
speech = text_to_speech(
text.split(f'"Speaker {speaker_id}":')[-1],
speech_model,
speech_tokenizer,
tone,
voice_profile,
)
podcast_audio.append(speech)
text = ""
Expand Down
6 changes: 3 additions & 3 deletions src/document_to_podcast/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
"id": 1,
"name": "Laura",
"description": "The main host. She explains topics clearly using anecdotes and analogies, teaching in an engaging and captivating way.",
"tone": "Laura's voice is exciting and fast in delivery with very clear audio and no background noise.",
"voice_profile": "Laura's voice is exciting and fast in delivery with very clear audio and no background noise.",
},
{
"id": 2,
"name": "Jon",
"description": "The co-host. He keeps the conversation on track, asks curious follow-up questions, and reacts with excitement or confusion, often using interjections like hmm or umm.",
"tone": "Jon's voice is calm with very clear audio and no background noise.",
"voice_profile": "Jon's voice is calm with very clear audio and no background noise.",
},
]

Expand Down Expand Up @@ -69,7 +69,7 @@ class Speaker(BaseModel):
id: int
name: str
description: str
tone: str
voice_profile: str

def __str__(self):
return f"Speaker {self.id}. Named {self.name}. {self.description}"
Expand Down
6 changes: 3 additions & 3 deletions src/document_to_podcast/inference/text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def text_to_speech(
input_text: str,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
tone: str,
voice_profile: str,
) -> np.ndarray:
"""
Generates a speech waveform from a text input using a pre-trained text-to-speech (TTS) model.
Expand All @@ -15,11 +15,11 @@ def text_to_speech(
input_text (str): The text to convert to speech.
model (PreTrainedModel): The model used for generating the waveform.
tokenizer (PreTrainedTokenizerBase): The tokenizer used for tokenizing the text in order to send to the model.
tone (str): A description used by the ParlerTTS model to configure the speaker profile.
voice_profile (str): A description used by the ParlerTTS model to configure the voice.
Returns:
numpy array: The waveform of the speech as a 2D numpy array
"""
input_ids = tokenizer(tone, return_tensors="pt").input_ids
input_ids = tokenizer(voice_profile, return_tensors="pt").input_ids
prompt_input_ids = tokenizer(input_text, return_tensors="pt").input_ids

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/inference/test_text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_text_to_speech_parler(mocker):
"Hello?",
model=model,
tokenizer=tokenizer,
tone="default",
voice_profile="default",
)
tokenizer.assert_has_calls(
[
Expand Down

0 comments on commit f8b6b8c

Please sign in to comment.