diff --git a/demo/app.py b/demo/app.py index e4096c6..41ae869 100644 --- a/demo/app.py +++ b/demo/app.py @@ -1,11 +1,21 @@ from pathlib import Path import streamlit as st +from huggingface_hub import list_repo_files from opennotebookllm.preprocessing import DATA_LOADERS, DATA_CLEANERS from opennotebookllm.text_to_podcast import load_model from opennotebookllm.text_to_podcast import text_to_podcast +PODCAST_PROMPT = """ +Convert this text into a podcast script. +The conversation should be between 2 speakers. +Use [SPEAKER1] and [SPEAKER2] to limit sections. +Do not include [INTRO], [OUTRO] or any other [SECTION]. +Text: +""" + +REPO = "allenai/OLMoE-1B-7B-0924-Instruct-GGUF" uploaded_file = st.file_uploader( "Choose a file", type=["pdf", "html", "txt", "docx", "md"] @@ -33,13 +43,26 @@ ) clean_text = clean_text[: 4096 * 3] - with st.spinner("Downloading and Loading Model..."): - model = load_model() + model_name = st.selectbox("Select Model", + [ + x for x in list_repo_files(REPO) + if ".gguf" in x + # The float16 is too big for the 16GB RAM codespace + and "f16" not in x + ], + index=None + ) + if model_name: + with st.spinner("Downloading and Loading Model..."): + model = load_model(model_id=f"{REPO}/{model_name}") + + system_prompt = st.text_area("Podcast generation prompt", value=PODCAST_PROMPT) - with st.spinner("Writing Podcast Script..."): - text = "" - for chunk in text_to_podcast(clean_text, model, stream=True): - text += chunk - if text.endswith("\n"): - st.write(text) - text = "" + if st.button("Generate Podcast Script"): + with st.spinner("Generating Podcast Script..."): + text = "" + for chunk in text_to_podcast(clean_text, model, system_prompt=system_prompt, stream=True): + text += chunk + if text.endswith("\n"): + st.write(text) + text = "" diff --git a/src/opennotebookllm/text_to_podcast/inference.py b/src/opennotebookllm/text_to_podcast/inference.py index 3c3b24b..50942fc 100644 --- a/src/opennotebookllm/text_to_podcast/inference.py +++ b/src/opennotebookllm/text_to_podcast/inference.py @@ -1,17 +1,10 @@ from llama_cpp import Llama -PROMPT = """ -Convert this text into a podcast script. -The conversation should be between 2 speakers. -Use [SPEAKER1] and [SPEAKER2] to limit sections. -Do not include [INTRO], [OUTRO] or any other [SECTION]. -Text: -""" def load_model( model_id: str = "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf", -): +) -> Llama: org, repo, filename = model_id.split("/") model = Llama.from_pretrained( repo_id=f"{org}/{repo}", @@ -23,7 +16,7 @@ def load_model( def text_to_podcast( - input_text: str, model: Llama, system_prompt: str = PROMPT, stream: bool = False + input_text: str, model: Llama, system_prompt: str, stream: bool = False ): response = model.create_chat_completion( messages=[