Skip to content

Commit

Permalink
Document how to add custom model providers
Browse files Browse the repository at this point in the history
  • Loading branch information
krassowski committed Oct 29, 2023
1 parent b06e259 commit 642d34a
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 6 deletions.
90 changes: 89 additions & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,94 @@ responsible for all charges they incur when they make API requests. Review your
provider's pricing information before submitting requests via Jupyter AI.
:::

### Custom model providers

You can define a new provider building upon LangChain framework API. The provider
inherit from both `jupyter-ai`'s ``BaseProvider`` and `langchain`'s [``LLM``][LLM].
You can either import a pre-defined model from [LangChain LLM list][langchain_llms],
or define a [custom LLM][custom_llm].
In the example below, we demonstrate defining a provider with two models using
a dummy ``FakeListLLM`` model, which returns responses from the ``responses``
keyword argument.

```python
# my_package/my_provider.py
from jupyter_ai_magics import BaseProvider
from langchain.llms import FakeListLLM


class MyProvider(BaseProvider, FakeListLLM):
id = "my_provider"
name = "My Provider"
model_id_key = "model"
models = [
"model_a",
"model_b"
]
def __init__(self, **kwargs):
model = kwargs.get("model_id")
kwargs["responses"] = (
["This is a response from model 'a'"]
if model == "model_a" else
["This is a response from model 'b'"]
)
super().__init__(**kwargs)
```


The provider will be available for both chat and magic usage if it inherits from
[``BaseChatModel``][BaseChatModel] or otherwise only in the magic.

To plug the new provider you will need declare it via an [entry point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html):

```toml
# my_package/pyproject.toml
[project]
name = "my_package"
version = "0.0.1"

[project.entry-points."jupyter_ai.model_providers"]
my-provider = "my_provider:MyProvider"
```

To test that the above minimal provider package works, install it with:

```sh
# from `my_package` directory
pip install -e .
```

and restart JupyterLab which now should include a log with:

```
[I 2023-10-29 13:56:16.915 AiExtension] Registered model provider `ai21`.
```

[langchain_llms]: https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.llms
[custom_llm]: https://python.langchain.com/docs/modules/model_io/models/llms/custom_llm
[LLM]: https://api.python.langchain.com/en/latest/llms/langchain.llms.base.LLM.html#langchain.llms.base.LLM
[BaseChatModel]: https://api.python.langchain.com/en/latest/chat_models/langchain.chat_models.base.BaseChatModel.html


### Customising prompt templates

To modify the prompt template for a given format, override the implementation of ``get_prompt_template`` method:

```python
from langchain.prompts import PromptTemplate


class MyProvider(BaseProvider, FakeListLLM):
# (... properties as above ...)
def get_prompt_template(self, format) -> PromptTemplate:
if format === "code":
return PromptTemplate.from_template(
"{prompt}\n\nProduce output as source code only, "
"with no text or explanation before or after it."
)
return super().get_prompt_template(format)
```

## The chat interface

The easiest way to get started with Jupyter AI is to use the chat interface.
Expand Down Expand Up @@ -689,7 +777,7 @@ Write a poem about C++.

You can also define a custom LangChain chain:

```
```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
Expand Down
10 changes: 6 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ def get_lm_providers(
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
except Exception as e:
log.error(
f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
f"Unable to load model provider class from entry point `{model_provider_ep.name}`: %s.",
e,
)
continue
if not is_provider_allowed(provider.id, restrictions):
Expand All @@ -58,9 +59,10 @@ def get_em_providers(
for model_provider_ep in model_provider_eps:
try:
provider = model_provider_ep.load()
except:
except Exception as e:
log.error(
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`: %s.",
e,
)
continue
if not is_provider_allowed(provider.id, restrictions):
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def broadcast_message(self, message: Message):
self.chat_history.append(message)

async def on_message(self, message):
self.log.debug("Message recieved: %s", message)
self.log.debug("Message received: %s", message)

try:
message = json.loads(message)
Expand Down

0 comments on commit 642d34a

Please sign in to comment.