Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

podcast script generation component #15

Merged
merged 32 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
536c98d
Add devcontainer and requirements
daavoo Nov 14, 2024
0ae661e
Add pyproject.toml
daavoo Nov 15, 2024
c4a1ee1
Add data_loaders and tests
daavoo Nov 15, 2024
d2b276c
Add data_cleaners and tests
daavoo Nov 15, 2024
8629481
Update demo
daavoo Nov 15, 2024
cef92b3
Add `LOADERS` and `CLEANERS`
daavoo Nov 19, 2024
acd50a9
Add markdown and docx
daavoo Nov 19, 2024
2a8f005
Add API Reference
daavoo Nov 19, 2024
95c342a
Update tests
daavoo Nov 19, 2024
e8ac586
Update install
daavoo Nov 19, 2024
ee7d299
Add initial scripts
daavoo Nov 19, 2024
fb38207
More tests
daavoo Nov 20, 2024
29df436
Merge remote-tracking branch 'origin/main' into AH-104-Initial-Podcas…
daavoo Nov 21, 2024
d4b6066
fix merge
daavoo Nov 21, 2024
abeb3c0
Add podcast writing to demo/app
daavoo Nov 21, 2024
4bcc57b
Add missing deps
daavoo Nov 21, 2024
06627fa
Add text_to_podcast module
daavoo Nov 21, 2024
4457813
Expose model options and prompt tuning in the app
daavoo Nov 21, 2024
c73d4d3
pre-commit
daavoo Nov 21, 2024
a868093
Strip system_prompt
daavoo Nov 21, 2024
8b2c57b
Rename to inference module. Add docstrings
daavoo Nov 22, 2024
d2c75c9
pre-commit
daavoo Nov 22, 2024
7a7e39c
Add CURATED_REPOS
daavoo Nov 22, 2024
e1c6ccb
JSON prompt
daavoo Nov 22, 2024
72413a1
Update API docs
daavoo Nov 22, 2024
8817ea0
Fix format
daavoo Nov 22, 2024
06a2c3d
Make text cutoff based on `model.n_ctx()`. Consider ~4 characters per…
daavoo Nov 25, 2024
39fb3b3
Add inference tests
daavoo Nov 25, 2024
1968278
Drop __init__ imports
daavoo Nov 25, 2024
f88b713
Fix outdated arg
daavoo Nov 25, 2024
73ac7bf
Drop redundant JSON output in prompt
daavoo Nov 25, 2024
1c11c98
Update default stop
daavoo Nov 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/.devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
"features": {
"ghcr.io/devcontainers/features/python": {
"version": "latest"
},
"packages": ["libgl1-mesa-dev"]
}
},
"postCreateCommand": "pip install -e '.[demo]'"
"postCreateCommand": "pip install -e . --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
}
63 changes: 61 additions & 2 deletions demo/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,32 @@
from pathlib import Path

import streamlit as st
from huggingface_hub import list_repo_files

from opennotebookllm.preprocessing import DATA_LOADERS, DATA_CLEANERS
from opennotebookllm.inference.model_loaders import load_llama_cpp_model
from opennotebookllm.inference.text_to_text import text_to_text_stream

PODCAST_PROMPT = """
daavoo marked this conversation as resolved.
Show resolved Hide resolved
You are a helpful podcast writer.
You will take the input text and generate a conversation between 2 speakers.
Example of response:
{
"Speaker 1": "Welcome to our podcast, where we explore the latest advancements in AI and technology. I'm your host, and today we're going to dive into the exciting world of TrustWorthy AI.",
"Speaker 2": "Hi, I'm excited to be here, so what is TrustWorthy AI?",
"Speaker 1":"Ah, great question! It is a term used by the European High Level Expert Group on AI. Mozilla defines trustworthy AI as AI that is demonstrably worthy of trust, tech that considers accountability, agency, and individual and collective well-being."
}
"""

CURATED_REPOS = [
stefanfrench marked this conversation as resolved.
Show resolved Hide resolved
"allenai/OLMoE-1B-7B-0924-Instruct-GGUF",
"MaziyarPanahi/SmolLM2-1.7B-Instruct-GGUF",
# system prompt seems to be ignored for this model.
# "microsoft/Phi-3-mini-4k-instruct-gguf",
"HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
"Qwen/Qwen2.5-1.5B-Instruct-GGUF",
"Qwen/Qwen2.5-3B-Instruct-GGUF",
]

uploaded_file = st.file_uploader(
"Choose a file", type=["pdf", "html", "txt", "docx", "md"]
Expand All @@ -17,9 +40,45 @@
raw_text = DATA_LOADERS[extension](uploaded_file)
with col1:
st.title("Raw Text")
st.write(raw_text[:200])
st.text_area(f"Total Length: {len(raw_text)}", f"{raw_text[:500]} . . .")

clean_text = DATA_CLEANERS[extension](raw_text)
with col2:
st.title("Cleaned Text")
st.write(clean_text[:200])
st.text_area(f"Total Length: {len(clean_text)}", f"{clean_text[:500]} . . .")

repo_name = st.selectbox("Select Repo", CURATED_REPOS)
model_name = st.selectbox(
"Select Model",
[
x
for x in list_repo_files(repo_name)
if ".gguf" in x.lower() and ("q8" in x.lower() or "fp16" in x.lower())
],
index=None,
)
if model_name:
with st.spinner("Downloading and Loading Model..."):
model = load_llama_cpp_model(model_id=f"{repo_name}/{model_name}")

# ~4 characters per token is considered a reasonable default.
max_characters = model.n_ctx() * 4
if len(clean_text) > max_characters:
st.warning(
f"Input text is too big ({len(clean_text)})."
f" Using only a subset of it ({max_characters})."
)
clean_text = clean_text[:max_characters]

system_prompt = st.text_area("Podcast generation prompt", value=PODCAST_PROMPT)

if st.button("Generate Podcast Script"):
with st.spinner("Generating Podcast Script..."):
text = ""
for chunk in text_to_text_stream(
clean_text, model, system_prompt=system_prompt.strip()
):
text += chunk
if text.endswith("\n"):
st.write(text)
text = ""
4 changes: 4 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# API Reference

::: opennotebookllm.preprocessing.data_cleaners

::: opennotebookllm.inference.model_loaders

::: opennotebookllm.inference.text_to_text
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ requires-python = ">=3.10"
dynamic = ["version"]
dependencies = [
"beautifulsoup4",
"huggingface-hub",
"llama-cpp-python",
"loguru",
"PyPDF2[crypto]",
"python-docx"
"python-docx",
"streamlit",
]

[project.optional-dependencies]
Expand All @@ -27,10 +30,6 @@ tests = [
"pytest-sugar>=0.9.6",
]

demo = [
"streamlit"
]

[project.urls]
Documentation = "https://mozilla-ai.github.io/OpenNotebookLLM/"
Issues = "https://github.com/mozilla-ai/OpenNotebookLLM/issues"
Expand Down
Empty file.
28 changes: 28 additions & 0 deletions src/opennotebookllm/inference/model_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from llama_cpp import Llama


def load_llama_cpp_model(
model_id: str,
) -> Llama:
"""
Loads the given model_id using Llama.from_pretrained.

Examples:
>>> model = load_model(
"allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf")

Args:
model_id (str): The model id to load.
Format is expected to be `{org}/{repo}/{filename}`.

Returns:
Llama: The loaded model.
"""
org, repo, filename = model_id.split("/")
model = Llama.from_pretrained(
repo_id=f"{org}/{repo}",
filename=filename,
# 0 means that the model limit will be used, instead of the default (512) or other hardcoded value
n_ctx=0,
)
return model
84 changes: 84 additions & 0 deletions src/opennotebookllm/inference/text_to_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Iterator

from llama_cpp import Llama


def chat_completion(
input_text: str,
model: Llama,
system_prompt: str,
return_json: bool,
stream: bool,
stop: str | list[str] | None = None,
) -> str | Iterator[str]:
# create_chat_completion uses an empty list as default
stop = stop or []
return model.create_chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": input_text},
],
response_format={
"type": "json_object",
}
if return_json
else None,
stream=stream,
stop=stop,
)


def text_to_text(
input_text: str,
model: Llama,
system_prompt: str,
return_json: bool = True,
stop: str | list[str] | None = None,
) -> str:
"""
Transforms input_text using the given model and system prompt.

Args:
input_text (str): The text to be transformed.
model (Llama): The model to use for conversion.
system_prompt (str): The system prompt to use for conversion.
return_json (bool, optional): Whether to return the response as JSON.
Defaults to True.
stop (str | list[str] | None, optional): The stop token(s).

Returns:
str: The full transformed text.
"""
response = chat_completion(
input_text, model, system_prompt, return_json, stop=stop, stream=False
)
return response["choices"][0]["message"]["content"]


def text_to_text_stream(
input_text: str,
model: Llama,
system_prompt: str,
return_json: bool = True,
stop: str | list[str] | None = None,
) -> Iterator[str]:
"""
Transforms input_text using the given model and system prompt.

Args:
input_text (str): The text to be transformed.
model (Llama): The model to use for conversion.
system_prompt (str): The system prompt to use for conversion.
return_json (bool, optional): Whether to return the response as JSON.
Defaults to True.
stop (str | list[str] | None, optional): The stop token(s).

Yields:
str: Chunks of the transformed text as they are available.
"""
response = chat_completion(
input_text, model, system_prompt, return_json, stop=stop, stream=True
)
for item in response:
if item["choices"][0].get("delta", {}).get("content", None):
yield item["choices"][0].get("delta", {}).get("content", None)
68 changes: 68 additions & 0 deletions tests/integration/test_model_load_and_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import json
from typing import Iterator

import pytest

from opennotebookllm.inference.model_loaders import load_llama_cpp_model
from opennotebookllm.inference.text_to_text import text_to_text, text_to_text_stream


def test_model_load_and_inference_text_to_text():
model = load_llama_cpp_model(
"HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/smollm-135m-instruct-add-basics-q8_0.gguf"
)
result = text_to_text(
"What is the capital of France?",
model=model,
system_prompt="",
)
assert isinstance(result, str)
assert json.loads(result)["Capital"] == "Paris"


def test_model_load_and_inference_text_to_text_no_json():
model = load_llama_cpp_model(
"HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/smollm-135m-instruct-add-basics-q8_0.gguf"
)
result = text_to_text(
"What is the capital of France?",
model=model,
system_prompt="",
return_json=False,
stop=".",
)
assert isinstance(result, str)
with pytest.raises(json.JSONDecodeError):
json.loads(result)
assert result.startswith("The capital of France is Paris")


def test_model_load_and_inference_text_to_text_stream():
model = load_llama_cpp_model(
"HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/smollm-135m-instruct-add-basics-q8_0.gguf"
)
result = text_to_text_stream(
"What is the capital of France?",
model=model,
system_prompt="",
)
assert isinstance(result, Iterator)
assert json.loads("".join(result))["Capital"] == "Paris"


def test_model_load_and_inference_text_to_text_stream_no_json():
model = load_llama_cpp_model(
"HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/smollm-135m-instruct-add-basics-q8_0.gguf"
)
result = text_to_text_stream(
"What is the capital of France?",
model=model,
system_prompt="",
return_json=False,
stop=".",
)
assert isinstance(result, Iterator)
result = "".join(result)
with pytest.raises(json.JSONDecodeError):
json.loads(result)
assert result.startswith("The capital of France is Paris")
12 changes: 12 additions & 0 deletions tests/unit/inference/test_model_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from llama_cpp import Llama

from opennotebookllm.inference.model_loaders import load_llama_cpp_model


def test_load_llama_cpp_model():
model = load_llama_cpp_model(
"HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/smollm-135m-instruct-add-basics-q8_0.gguf"
)
assert isinstance(model, Llama)
# we set n_ctx=0 to indicate that we want to use the model's default context
assert model.n_ctx() == 2048