Skip to content

Commit

Permalink
implement configurable prompt templates for default chat handler
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Nov 8, 2023
1 parent 9827ba7 commit ad7e07f
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 22 deletions.
28 changes: 27 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.models import AgentChatMessage, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from traitlets.config import Configurable

if TYPE_CHECKING:
from jupyter_ai.handlers import RootChatHandler


class BaseChatHandler:
class BaseChatHandler(Configurable):
"""Base ChatHandler class containing shared methods and attributes used by
multiple chat handler classes."""

Expand All @@ -23,7 +24,10 @@ def __init__(
log: Logger,
config_manager: ConfigManager,
root_chat_handlers: Dict[str, "RootChatHandler"],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.log = log
self.config_manager = config_manager
self._root_chat_handlers = root_chat_handlers
Expand Down Expand Up @@ -94,6 +98,28 @@ def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
handler.broadcast_message(agent_msg)
break

@property
def lm_id(self):
"""Retrieves the language model ID from the config manager."""
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params

if lm_provider:
return lm_provider.id + ":" + lm_provider_params["model_id"]
else:
return None

@property
def em_id(self):
"""Retrieves the embedding model ID from the config manager."""
em_provider = self.config_manager.em_provider
em_provider_params = self.config_manager.em_provider_params

if em_provider:
return em_provider.id + ":" + em_provider_params["model_id"]
else:
return None

def get_llm_chain(self):
lm_provider = self.config_manager.lm_provider
lm_provider_params = self.config_manager.lm_provider_params
Expand Down
31 changes: 10 additions & 21 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, List, Type

from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
from jupyter_ai.prompt_templates import ChatPromptTemplates
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferWindowMemory
Expand All @@ -14,29 +15,17 @@

from .base import BaseChatHandler

SYSTEM_PROMPT = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
You are talkative and you provide lots of specific details from the foundation model's context.
You may use Markdown to format your response.
Code blocks must be formatted in Markdown.
Math should be rendered with inline TeX markup, surrounded by $.
If you do not know the answer to a question, answer truthfully by responding that you do not know.
The following is a friendly conversation between you and a human.
""".strip()

DEFAULT_TEMPLATE = """Current conversation:
{history}
Human: {input}
AI:"""


class DefaultChatHandler(BaseChatHandler):
def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
super().__init__(*args, **kwargs)
self.memory = ConversationBufferWindowMemory(return_messages=True, k=2)
self.chat_history = chat_history

@property
def templates(self):
return ChatPromptTemplates(self.lm_id, config=self.config)

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
Expand All @@ -45,9 +34,9 @@ def create_llm_chain(
if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
provider_name=llm.name, local_model_id=llm.model_id
),
SystemMessagePromptTemplate.from_template(
self.templates.system
).format(provider_name=llm.name, local_model_id=llm.model_id),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{input}"),
]
Expand All @@ -56,11 +45,11 @@ def create_llm_chain(
else:
prompt_template = PromptTemplate(
input_variables=["history", "input"],
template=SYSTEM_PROMPT.format(
template=self.templates.system.format(
provider_name=llm.name, local_model_id=llm.model_id
)
+ "\n\n"
+ DEFAULT_TEMPLATE,
+ self.templates.default,
)
self.memory = ConversationBufferWindowMemory(k=2)

Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def initialize_settings(self):
# initialize chat handlers
chat_handler_kwargs = {
"log": self.log,
"config": self.config, # traitlets config
"config_manager": self.settings["jai_config_manager"],
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
}
Expand Down
75 changes: 75 additions & 0 deletions packages/jupyter-ai/jupyter_ai/prompt_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from traitlets import Dict, Unicode
from traitlets.config import Configurable

SYSTEM_TEMPLATE = """
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
You are talkative and you provide lots of specific details from the foundation model's context.
You may use Markdown to format your response.
Code blocks must be formatted in Markdown.
Math should be rendered with inline TeX markup, surrounded by $.
If you do not know the answer to a question, answer truthfully by responding that you do not know.
The following is a friendly conversation between you and a human.
""".strip()

HISTORY_TEMPLATE = """
Current conversation:
{history}
Human: {input}
AI:
""".strip()


class ChatPromptTemplates(Configurable):
system_template = Unicode(
default_value=SYSTEM_TEMPLATE,
help="The system prompt template.",
allow_none=False,
config=True,
)

system_overrides = Dict(
key_trait=Unicode(),
value_trait=Unicode(),
default_value={},
help="Defines model-specific overrides of the system prompt template.",
allow_none=False,
config=True,
)

history_template = Unicode(
default_value=HISTORY_TEMPLATE,
help="The history prompt template.",
allow_none=False,
config=True,
)

history_overrides = Dict(
key_trait=Unicode(),
value_trait=Unicode(),
default_value={},
help="Defines model-specific overrides of the history prompt template.",
allow_none=False,
config=True,
)

lm_id: str = None

def __init__(self, lm_id, *args, **kwargs):
super().__init__(*args, **kwargs)

@property
def system(self) -> str:
return self.system_overrides.get(self.lm_id, self.system_template)

@property
def history(self) -> str:
return self.history_overrides.get(self.lm_id, self.history_template)


class AskPromptTemplates(Configurable):
...


class GeneratePromptTemplates(Configurable):
...

0 comments on commit ad7e07f

Please sign in to comment.