diff --git a/demo/app.py b/demo/app.py index bf61ba9..8f763ee 100644 --- a/demo/app.py +++ b/demo/app.py @@ -1,8 +1,10 @@ import re from pathlib import Path +import numpy as np import streamlit as st +from document_to_podcast.podcast_maker.script_to_audio import save_waveform_as_file from document_to_podcast.preprocessing import DATA_LOADERS, DATA_CLEANERS from document_to_podcast.inference.model_loaders import ( load_llama_cpp_model, @@ -48,6 +50,21 @@ def load_text_to_speech_model_and_tokenizer(): return load_parler_tts_model_and_tokenizer("parler-tts/parler-tts-mini-v1", "cpu") +script = "script" +audio = "audio" +gen_button = "generate podcast button" +if script not in st.session_state: + st.session_state[script] = "" +if audio not in st.session_state: + st.session_state.audio = [] +if gen_button not in st.session_state: + st.session_state[gen_button] = False + + +def gen_button_clicked(): + st.session_state[gen_button] = True + + st.title("Document To Podcast") st.header("Uploading Data") @@ -107,7 +124,7 @@ def load_text_to_speech_model_and_tokenizer(): system_prompt = st.text_area("Podcast generation prompt", value=PODCAST_PROMPT) - if st.button("Generate Podcast"): + if st.button("Generate Podcast", on_click=gen_button_clicked): with st.spinner("Generating Podcast..."): text = "" for chunk in text_to_text_stream( @@ -115,7 +132,9 @@ def load_text_to_speech_model_and_tokenizer(): ): text += chunk if text.endswith("\n") and "Speaker" in text: - st.write(text) + st.session_state.script += text + st.write(st.session_state.script) + speaker_id = re.search(r"Speaker (\d+)", text).group(1) with st.spinner("Generating Audio..."): speech = text_to_speech( @@ -125,4 +144,22 @@ def load_text_to_speech_model_and_tokenizer(): SPEAKER_DESCRIPTIONS[speaker_id], ) st.audio(speech, sample_rate=speech_model.config.sampling_rate) + st.session_state.audio.append(speech) text = "" + + if st.session_state[gen_button]: + if st.button("Save Podcast to audio file"): + st.session_state.audio = np.concatenate(st.session_state.audio) + save_waveform_as_file( + waveform=st.session_state.audio, + sampling_rate=speech_model.config.sampling_rate, + filename="podcast.wav", + ) + st.markdown("Podcast saved to disk!") + + if st.button("Save Podcast script to text file"): + with open("script.txt", "w") as f: + st.session_state.script += "}" + f.write(st.session_state.script) + + st.markdown("Script saved to disk!")