generated from mozilla-ai/Blueprint-template
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
podcast script generation component (#15)
* Add devcontainer and requirements * Add pyproject.toml * Add data_loaders and tests * Add data_cleaners and tests * Update demo * Add `LOADERS` and `CLEANERS` * Add markdown and docx * Add API Reference * Update tests * Update install * Add initial scripts * More tests * fix merge * Add podcast writing to demo/app * Add missing deps * Add text_to_podcast module * Expose model options and prompt tuning in the app * pre-commit * Strip system_prompt * Rename to inference module. Add docstrings * pre-commit * Add CURATED_REPOS * JSON prompt * Update API docs * Fix format * Make text cutoff based on `model.n_ctx()`. Consider ~4 characters per token as a reasonable default. * Add inference tests * Drop __init__ imports * Fix outdated arg * Drop redundant JSON output in prompt * Update default stop
- Loading branch information
Showing
12 changed files
with
263 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# API Reference | ||
|
||
::: opennotebookllm.preprocessing.data_cleaners | ||
|
||
::: opennotebookllm.inference.model_loaders | ||
|
||
::: opennotebookllm.inference.text_to_text |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from llama_cpp import Llama | ||
|
||
|
||
def load_llama_cpp_model(
    model_id: str,
) -> Llama:
    """
    Loads the given model_id using Llama.from_pretrained.

    Examples:
        >>> model = load_llama_cpp_model(
            "allenai/OLMoE-1B-7B-0924-Instruct-GGUF/olmoe-1b-7b-0924-instruct-q8_0.gguf")

    Args:
        model_id (str): The model id to load.
            Format is expected to be `{org}/{repo}/{filename}`.

    Returns:
        Llama: The loaded model.

    Raises:
        ValueError: If `model_id` does not have the `{org}/{repo}/{filename}` form.
    """
    # maxsplit=2 keeps any remaining "/" inside `filename`, so GGUF files
    # stored in a subfolder of the repo are still supported.
    try:
        org, repo, filename = model_id.split("/", 2)
    except ValueError as err:
        raise ValueError(
            f"model_id must be in the form '{{org}}/{{repo}}/{{filename}}', got: {model_id!r}"
        ) from err
    model = Llama.from_pretrained(
        repo_id=f"{org}/{repo}",
        filename=filename,
        # 0 means that the model limit will be used, instead of the default (512) or other hardcoded value
        n_ctx=0,
    )
    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from typing import Iterator | ||
|
||
from llama_cpp import Llama | ||
|
||
|
||
def chat_completion(
    input_text: str,
    model: Llama,
    system_prompt: str,
    return_json: bool,
    stream: bool,
    stop: str | list[str] | None = None,
) -> str | Iterator[str]:
    """Run a single chat completion against `model`.

    Sends `system_prompt` and `input_text` as the system/user messages and
    forwards the JSON-mode, streaming, and stop options to llama-cpp.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": input_text},
    ]
    # Only request a JSON object response when the caller asked for one.
    response_format = {"type": "json_object"} if return_json else None
    # create_chat_completion uses an empty list as default
    if not stop:
        stop = []
    return model.create_chat_completion(
        messages=messages,
        response_format=response_format,
        stream=stream,
        stop=stop,
    )
|
||
|
||
def text_to_text(
    input_text: str,
    model: Llama,
    system_prompt: str,
    return_json: bool = True,
    stop: str | list[str] | None = None,
) -> str:
    """
    Transforms input_text using the given model and system prompt.

    Args:
        input_text (str): The text to be transformed.
        model (Llama): The model to use for conversion.
        system_prompt (str): The system prompt to use for conversion.
        return_json (bool, optional): Whether to return the response as JSON.
            Defaults to True.
        stop (str | list[str] | None, optional): The stop token(s).

    Returns:
        str: The full transformed text.
    """
    completion = chat_completion(
        input_text,
        model,
        system_prompt,
        return_json,
        stop=stop,
        stream=False,
    )
    # A non-streaming completion carries the full message in the first choice.
    first_choice = completion["choices"][0]
    return first_choice["message"]["content"]
|
||
|
||
def text_to_text_stream(
    input_text: str,
    model: Llama,
    system_prompt: str,
    return_json: bool = True,
    stop: str | list[str] | None = None,
) -> Iterator[str]:
    """
    Transforms input_text using the given model and system prompt.

    Args:
        input_text (str): The text to be transformed.
        model (Llama): The model to use for conversion.
        system_prompt (str): The system prompt to use for conversion.
        return_json (bool, optional): Whether to return the response as JSON.
            Defaults to True.
        stop (str | list[str] | None, optional): The stop token(s).

    Yields:
        str: Chunks of the transformed text as they are available.
    """
    response = chat_completion(
        input_text, model, system_prompt, return_json, stop=stop, stream=True
    )
    for item in response:
        # Streamed chunks put incremental text under `delta.content`; chunks
        # without content (e.g. role-only or final chunks) are skipped. The
        # single lookup avoids evaluating the same `.get(...)` chain twice.
        content = item["choices"][0].get("delta", {}).get("content")
        if content:
            yield content
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import json | ||
from typing import Iterator | ||
|
||
import pytest | ||
|
||
from opennotebookllm.inference.model_loaders import load_llama_cpp_model | ||
from opennotebookllm.inference.text_to_text import text_to_text, text_to_text_stream | ||
|
||
|
||
def test_model_load_and_inference_text_to_text():
    """End-to-end: load a small GGUF model and run a JSON-mode completion."""
    model_id = (
        "HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/"
        "smollm-135m-instruct-add-basics-q8_0.gguf"
    )
    model = load_llama_cpp_model(model_id)
    answer = text_to_text(
        "What is the capital of France?",
        model=model,
        system_prompt="",
    )
    assert isinstance(answer, str)
    # JSON mode is on by default, so the answer must parse as JSON.
    assert json.loads(answer)["Capital"] == "Paris"
|
||
|
||
def test_model_load_and_inference_text_to_text_no_json():
    """End-to-end: plain-text mode must not produce parseable JSON."""
    model_id = (
        "HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/"
        "smollm-135m-instruct-add-basics-q8_0.gguf"
    )
    model = load_llama_cpp_model(model_id)
    answer = text_to_text(
        "What is the capital of France?",
        model=model,
        system_prompt="",
        return_json=False,
        stop=".",
    )
    assert isinstance(answer, str)
    # With return_json=False the output is prose, not a JSON object.
    with pytest.raises(json.JSONDecodeError):
        json.loads(answer)
    assert answer.startswith("The capital of France is Paris")
|
||
|
||
def test_model_load_and_inference_text_to_text_stream():
    """End-to-end: streamed JSON-mode chunks join into a parseable object."""
    model_id = (
        "HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/"
        "smollm-135m-instruct-add-basics-q8_0.gguf"
    )
    model = load_llama_cpp_model(model_id)
    stream = text_to_text_stream(
        "What is the capital of France?",
        model=model,
        system_prompt="",
    )
    assert isinstance(stream, Iterator)
    full_text = "".join(stream)
    assert json.loads(full_text)["Capital"] == "Paris"
|
||
|
||
def test_model_load_and_inference_text_to_text_stream_no_json():
    """End-to-end: streamed plain-text chunks join into prose, not JSON."""
    model_id = (
        "HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/"
        "smollm-135m-instruct-add-basics-q8_0.gguf"
    )
    model = load_llama_cpp_model(model_id)
    stream = text_to_text_stream(
        "What is the capital of France?",
        model=model,
        system_prompt="",
        return_json=False,
        stop=".",
    )
    assert isinstance(stream, Iterator)
    full_text = "".join(stream)
    # With return_json=False the joined output is prose, not a JSON object.
    with pytest.raises(json.JSONDecodeError):
        json.loads(full_text)
    assert full_text.startswith("The capital of France is Paris")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from llama_cpp import Llama | ||
|
||
from opennotebookllm.inference.model_loaders import load_llama_cpp_model | ||
|
||
|
||
def test_load_llama_cpp_model():
    """Loading a GGUF model returns a Llama instance with the model's own context size."""
    model_id = (
        "HuggingFaceTB/smollm-135M-instruct-v0.2-Q8_0-GGUF/"
        "smollm-135m-instruct-add-basics-q8_0.gguf"
    )
    loaded = load_llama_cpp_model(model_id)
    assert isinstance(loaded, Llama)
    # we set n_ctx=0 to indicate that we want to use the model's default context
    assert loaded.n_ctx() == 2048
File renamed without changes.
File renamed without changes.