diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md
index 123315f7a..ba3c969d7 100644
--- a/docs/source/developers/index.md
+++ b/docs/source/developers/index.md
@@ -16,3 +16,149 @@ Jupyter AI classes.
 For more details about using `langchain.pydantic_v1` in an
 environment with Pydantic v2 installed, see the
 [LangChain documentation on Pydantic compatibility](https://python.langchain.com/docs/guides/pydantic_compatibility).
+
+## Custom model providers
+
+You can define new providers using the LangChain framework API. Custom providers
+inherit from both `jupyter-ai`'s `BaseProvider` and `langchain`'s [`LLM`][LLM].
+You can either import a pre-defined model from the [LangChain LLM list][langchain_llms],
+or define a [custom LLM][custom_llm].
+In the example below, we define a provider with two models using
+a dummy `FakeListLLM` model, which returns responses from the `responses`
+keyword argument.
+
+```python
+# my_package/my_provider.py
+from jupyter_ai_magics import BaseProvider
+from langchain.llms import FakeListLLM
+
+
+class MyProvider(BaseProvider, FakeListLLM):
+    id = "my_provider"
+    name = "My Provider"
+    model_id_key = "model"
+    models = [
+        "model_a",
+        "model_b"
+    ]
+
+    def __init__(self, **kwargs):
+        model = kwargs.get("model_id")
+        kwargs["responses"] = (
+            ["This is a response from model 'a'"]
+            if model == "model_a" else
+            ["This is a response from model 'b'"]
+        )
+        super().__init__(**kwargs)
+```
+
+If the new provider inherits from [`BaseChatModel`][BaseChatModel], it will be available
+both in the chat UI and with magic commands. Otherwise, users can only use the new provider
+with magic commands.
+
+To make the new provider available, you need to declare it as an [entry point](https://setuptools.pypa.io/en/latest/userguide/entry_point.html):
+
+```toml
+# my_package/pyproject.toml
+[project]
+name = "my_package"
+version = "0.0.1"
+
+[project.entry-points."jupyter_ai.model_providers"]
+my-provider = "my_provider:MyProvider"
+```
+
+To test that the above minimal provider package works, install it with:
+
+```sh
+# from `my_package` directory
+pip install -e .
+```
+
+Then, restart JupyterLab. You should now see an info message in the log that mentions
+your new provider's `id`:
+
+```
+[I 2023-10-29 13:56:16.915 AiExtension] Registered model provider `my_provider`.
+```
+
+[langchain_llms]: https://api.python.langchain.com/en/v0.0.339/api_reference.html#module-langchain.llms
+[custom_llm]: https://python.langchain.com/docs/modules/model_io/models/llms/custom_llm
+[LLM]: https://api.python.langchain.com/en/v0.0.339/llms/langchain.llms.base.LLM.html#langchain.llms.base.LLM
+[BaseChatModel]: https://api.python.langchain.com/en/v0.0.339/chat_models/langchain.chat_models.base.BaseChatModel.html
+
+## Prompt templates
+
+Each provider can define **prompt templates** for each supported format. A prompt
+template guides the language model to produce output in a particular
+format. The default prompt templates are a
+[Python dictionary mapping formats to templates](https://github.com/jupyterlab/jupyter-ai/blob/57a758fa5cdd5a87da5519987895aa688b3766a8/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py#L138-L166).
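+For illustration, such a mapping pairs a format name with a LangChain
+`PromptTemplate`. The sketch below is hypothetical: the format names and
+template strings shown here are not the actual defaults, which live in the
+file linked above.
+
+```python
+from langchain.prompts import PromptTemplate
+
+# Hypothetical sketch of the shape of the default mapping; see
+# providers.py (linked above) for the real format names and templates.
+SAMPLE_PROMPT_TEMPLATES = {
+    "code": PromptTemplate.from_template(
+        "{prompt}\n\nProduce output as source code only."
+    ),
+    "text": PromptTemplate.from_template("{prompt}"),
+}
+```
+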
+Developers who write subclasses of `BaseProvider` can override templates per
+output format, per model, and based on the prompt being submitted, by
+implementing their own
+[`get_prompt_template` function](https://github.com/jupyterlab/jupyter-ai/blob/57a758fa5cdd5a87da5519987895aa688b3766a8/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py#L186-L195).
+Each prompt template includes the string `{prompt}`, which is replaced with
+the user-provided prompt when the user runs a magic command.
+
+### Customizing prompt templates
+
+To modify the prompt template for a given format, override the `get_prompt_template` method:
+
+```python
+from langchain.prompts import PromptTemplate
+
+
+class MyProvider(BaseProvider, FakeListLLM):
+    # (... properties as above ...)
+    def get_prompt_template(self, format) -> PromptTemplate:
+        if format == "code":
+            return PromptTemplate.from_template(
+                "{prompt}\n\nProduce output as source code only, "
+                "with no text or explanation before or after it."
+            )
+        return super().get_prompt_template(format)
+```
+
+Note that custom prompt templates currently work only with Jupyter AI magics
+(the `%ai` and `%%ai` magic commands); they are not yet used in the chat interface.
+
+## Custom slash commands in the chat UI
+
+You can add a custom slash command to the chat interface by
+creating a new class that inherits from `BaseChatHandler`. Set
+its `id`, `name`, `help` message (displayed in the user interface),
+and `routing_type`. Each slash command must be unique and may contain
+only ASCII letters, numerals, and underscores; custom slash commands
+cannot replace built-in slash commands.
+
+Add your custom handler in Python code:
+
+```python
+from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
+from jupyter_ai.models import HumanChatMessage
+
+
+class CustomChatHandler(BaseChatHandler):
+    id = "custom"
+    name = "Custom"
+    help = "A chat handler that does something custom"
+    routing_type = SlashCommandRoutingType(slash_id="custom")
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def process_message(self, message: HumanChatMessage):
+        # Put your custom logic here, then send a reply
+        self.reply("<your response here>", message)
+```
+
+Jupyter AI uses entry points to support custom slash commands.
+In the `pyproject.toml` file, add your custom handler to the
+`[project.entry-points."jupyter_ai.chat_handlers"]` section:
+
+```toml
+[project.entry-points."jupyter_ai.chat_handlers"]
+custom = "custom_package:CustomChatHandler"
+```
+
+Then, install your package so that Jupyter AI registers your custom chat
+handler alongside the built-in ones.
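+
+After the package is installed and JupyterLab is restarted, the server log
+should confirm the registration with a message along these lines (the handler
+ID and command name come from the example `CustomChatHandler` above):
+
+```
+[I 2023-10-29 13:56:16.915 AiExtension] Registered chat handler `custom` with command `/custom`.
+```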
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py index c3c64b789..e4c69f012 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py @@ -1,5 +1,5 @@ from .ask import AskChatHandler -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType from .clear import ClearChatHandler from .default import DefaultChatHandler from .generate import GenerateChatHandler diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index e5c852051..bfb55ce21 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -7,7 +7,7 @@ from langchain.memory import ConversationBufferWindowMemory from langchain.prompts import PromptTemplate -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. @@ -26,6 +26,11 @@ class AskChatHandler(BaseChatHandler): to the LLM to generate the final reply. """ + id = "ask" + name = "Ask with Local Data" + help = "Asks a question with retrieval augmented generation (RAG)" + routing_type = SlashCommandRoutingType(slash_id="ask") + def __init__(self, retriever, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index c7ca70f97..15aef6788 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -1,35 +1,81 @@ import argparse +import os import time import traceback - -# necessary to prevent circular import -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import ( + TYPE_CHECKING, + Awaitable, + ClassVar, + Dict, + List, + Literal, + Optional, + Type, +) from uuid import uuid4 +from dask.distributed import Client as DaskClient from jupyter_ai.config_manager import ConfigManager, Logger -from jupyter_ai.models import AgentChatMessage, HumanChatMessage +from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage from jupyter_ai_magics.providers import BaseProvider +# necessary to prevent circular import +from pydantic import BaseModel + if TYPE_CHECKING: from jupyter_ai.handlers import RootChatHandler +# Chat handler type, with specific attributes for each +class HandlerRoutingType(BaseModel): + routing_method: ClassVar[str] = Literal["slash_command"] + """The routing method that sends commands to this handler.""" + + +class SlashCommandRoutingType(HandlerRoutingType): + routing_method = "slash_command" + + slash_id: Optional[str] + """Slash ID for routing a chat command to this handler. Only one handler + may declare a particular slash ID. Must contain only alphanumerics and + underscores.""" + + class BaseChatHandler: """Base ChatHandler class containing shared methods and attributes used by multiple chat handler classes.""" + # Class attributes + id: ClassVar[str] = ... + """ID for this chat handler; should be unique""" + + name: ClassVar[str] = ... + """User-facing name of this handler""" + + help: ClassVar[str] = ... 
+ """What this chat handler does, which third-party models it contacts, + the data it returns to the user, and so on, for display in the UI.""" + + routing_type: HandlerRoutingType = ... + def __init__( self, log: Logger, config_manager: ConfigManager, root_chat_handlers: Dict[str, "RootChatHandler"], model_parameters: Dict[str, Dict], + chat_history: List[ChatMessage], + root_dir: str, + dask_client_future: Awaitable[DaskClient], ): self.log = log self.config_manager = config_manager self._root_chat_handlers = root_chat_handlers self.model_parameters = model_parameters + self._chat_history = chat_history self.parser = argparse.ArgumentParser() + self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) + self.dask_client_future = dask_client_future self.llm = None self.llm_params = None self.llm_chain = None diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py index a2a39bb00..7042c4632 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py @@ -2,13 +2,17 @@ from jupyter_ai.models import ChatMessage, ClearMessage -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType class ClearChatHandler(BaseChatHandler): - def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): + id = "clear" + name = "Clear chat messages" + help = "Clears the displayed chat message history only; does not clear the context sent to chat providers" + routing_type = SlashCommandRoutingType(slash_id="clear") + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._chat_history = chat_history async def process_message(self, _): self._chat_history.clear() diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index d329e05e2..5bd839ca5 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -12,7 +12,7 @@ SystemMessagePromptTemplate, ) -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType SYSTEM_PROMPT = """ You are Jupyternaut, a conversational assistant living in JupyterLab to help users. 
@@ -32,10 +32,14 @@ class DefaultChatHandler(BaseChatHandler): - def __init__(self, chat_history: List[ChatMessage], *args, **kwargs): + id = "default" + name = "Default" + help = "Responds to prompts that are not otherwise handled by a chat handler" + routing_type = SlashCommandRoutingType(slash_id=None) + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.memory = ConversationBufferWindowMemory(return_messages=True, k=2) - self.chat_history = chat_history def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -80,8 +84,8 @@ def clear_memory(self): self.reply(reply_message) # clear transcript for new chat clients - if self.chat_history: - self.chat_history.clear() + if self._chat_history: + self._chat_history.clear() async def process_message(self, message: HumanChatMessage): self.get_llm_chain() diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index ca18becc2..b3d5212ae 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Type import nbformat -from jupyter_ai.chat_handlers import BaseChatHandler +from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.chains import LLMChain @@ -216,11 +216,13 @@ def create_notebook(outline): class GenerateChatHandler(BaseChatHandler): - """Generates a Jupyter notebook given a description.""" + id = "generate" + name = "Generate Notebook" + help = "Generates a Jupyter notebook, including name, outline, and section contents" + routing_type = SlashCommandRoutingType(slash_id="generate") - def __init__(self, root_dir: str, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.root_dir = os.path.abspath(os.path.expanduser(root_dir)) self.llm = None def create_llm_chain( diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index be89d1165..cbf4c19c9 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -4,7 +4,7 @@ from jupyter_ai.models import AgentChatMessage, HumanChatMessage -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant. You can ask me a question using the text box below. 
You can also use these commands: @@ -29,6 +29,11 @@ def HelpMessage(): class HelpChatHandler(BaseChatHandler): + id = "help" + name = "Help" + help = "Displays a help message in the chat message area" + routing_type = SlashCommandRoutingType(slash_id="help") + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 40cae643b..825acf453 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -24,19 +24,20 @@ ) from langchain.vectorstores import FAISS -from .base import BaseChatHandler +from .base import BaseChatHandler, SlashCommandRoutingType INDEX_SAVE_DIR = os.path.join(jupyter_data_dir(), "jupyter_ai", "indices") METADATA_SAVE_PATH = os.path.join(INDEX_SAVE_DIR, "metadata.json") class LearnChatHandler(BaseChatHandler): - def __init__( - self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs - ): + id = "learn" + name = "Learn Local Data" + help = "Pass a list of files and directories. Once converted to vector format, you can ask about them with /ask." + routing_type = SlashCommandRoutingType(slash_id="learn") + + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.root_dir = root_dir - self.dask_client_future = dask_client_future self.parser.prog = "/learn" self.parser.add_argument("-a", "--all-files", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 8ab8c0cc6..ec08d9962 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -1,6 +1,9 @@ +import logging +import re import time from dask.distributed import Client as DaskClient +from importlib_metadata import entry_points from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp @@ -40,7 +43,7 @@ class AiExtension(ExtensionApp): allowed_providers = List( Unicode(), default_value=None, - help="Identifiers of allow-listed providers. If `None`, all are allowed.", + help="Identifiers of allowlisted providers. If `None`, all are allowed.", allow_none=True, config=True, ) @@ -48,7 +51,7 @@ class AiExtension(ExtensionApp): blocked_providers = List( Unicode(), default_value=None, - help="Identifiers of block-listed providers. If `None`, none are blocked.", + help="Identifiers of blocklisted providers. If `None`, none are blocked.", allow_none=True, config=True, ) @@ -156,32 +159,29 @@ def initialize_settings(self): # consumers a Future that resolves to the Dask client when awaited. 
dask_client_future = loop.create_task(self._get_dask_client()) + eps = entry_points() # initialize chat handlers + chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") + chat_handler_kwargs = { "log": self.log, "config_manager": self.settings["jai_config_manager"], "root_chat_handlers": self.settings["jai_root_chat_handlers"], + "chat_history": self.settings["chat_history"], + "root_dir": self.serverapp.root_dir, + "dask_client_future": dask_client_future, "model_parameters": self.settings["model_parameters"], } - default_chat_handler = DefaultChatHandler( - **chat_handler_kwargs, chat_history=self.settings["chat_history"] - ) - clear_chat_handler = ClearChatHandler( - **chat_handler_kwargs, chat_history=self.settings["chat_history"] - ) - generate_chat_handler = GenerateChatHandler( - **chat_handler_kwargs, - root_dir=self.serverapp.root_dir, - ) - learn_chat_handler = LearnChatHandler( - **chat_handler_kwargs, - root_dir=self.serverapp.root_dir, - dask_client_future=dask_client_future, - ) + + default_chat_handler = DefaultChatHandler(**chat_handler_kwargs) + clear_chat_handler = ClearChatHandler(**chat_handler_kwargs) + generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs) + learn_chat_handler = LearnChatHandler(**chat_handler_kwargs) help_chat_handler = HelpChatHandler(**chat_handler_kwargs) retriever = Retriever(learn_chat_handler=learn_chat_handler) ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever) - self.settings["jai_chat_handlers"] = { + + jai_chat_handlers = { "default": default_chat_handler, "/ask": ask_chat_handler, "/clear": clear_chat_handler, @@ -190,6 +190,54 @@ def initialize_settings(self): "/help": help_chat_handler, } + slash_command_pattern = r"^[a-zA-Z0-9_]+$" + for chat_handler_ep in chat_handler_eps: + try: + chat_handler = chat_handler_ep.load() + except Exception as err: + self.log.error( + f"Unable to load chat handler class from entry point `{chat_handler_ep.name}`: " + + f"Unexpected {err=}, {type(err)=}" + ) + continue + + if chat_handler.routing_type.routing_method == "slash_command": + # Each slash ID must be used only once. + # Slash IDs may contain only alphanumerics and underscores. + slash_id = chat_handler.routing_type.slash_id + + if slash_id is None: + self.log.error( + f"Handler `{chat_handler_ep.name}` has an invalid slash command " + + f"`None`; only the default chat handler may use this" + ) + continue + + # Validate slash ID (/^[A-Za-z0-9_]+$/) + if re.match(slash_command_pattern, slash_id): + command_name = f"/{slash_id}" + else: + self.log.error( + f"Handler `{chat_handler_ep.name}` has an invalid slash command " + + f"`{slash_id}`; must contain only letters, numbers, " + + "and underscores" + ) + continue + + if command_name in jai_chat_handlers: + self.log.error( + f"Unable to register chat handler `{chat_handler.id}` because command `{command_name}` already has a handler" + ) + continue + + # The entry point is a class; we need to instantiate the class to send messages to it + jai_chat_handlers[command_name] = chat_handler(**chat_handler_kwargs) + self.log.info( + f"Registered chat handler `{chat_handler.id}` with command `{command_name}`." 
+ ) + + self.settings["jai_chat_handlers"] = jai_chat_handlers + latency_ms = round((time.time() - start) * 1000) self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.") diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 170ae4006..ae1498946 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -165,7 +165,7 @@ def open(self): def broadcast_message(self, message: Message): """Broadcasts message to all connected clients. - Appends message to `self.chat_history`. + Appends message to chat history. """ self.log.debug("Broadcasting message: %s to all clients...", message)