diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 5762f2560..850e24a99 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -11,7 +11,13 @@
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.sagemaker_endpoint import LLMContentHandler
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.prompts import PromptTemplate
+from langchain.prompts import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+    PromptTemplate,
+    SystemMessagePromptTemplate,
+)
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain.schema import LLMResult
 from langchain.utils import get_from_dict_or_env
@@ -42,6 +48,49 @@
 from pydantic.main import ModelMetaclass
 
 
+CHAT_SYSTEM_PROMPT = """
+You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
+You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
+You are talkative and you provide lots of specific details from the foundation model's context.
+You may use Markdown to format your response.
+Code blocks must be formatted in Markdown.
+Math should be rendered with inline TeX markup, surrounded by $.
+If you do not know the answer to a question, answer truthfully by responding that you do not know.
+The following is a friendly conversation between you and a human.
+""".strip()
+
+CHAT_DEFAULT_TEMPLATE = """Current conversation:
+{history}
+Human: {input}
+AI:"""
+
+
+COMPLETION_SYSTEM_PROMPT = """
+You are an application built to provide helpful code completion suggestions.
+You should only produce code. Keep comments to minimum, use the
+programming language comment syntax. Produce clean code.
+The code is written in JupyterLab, a data analysis and code development
+environment which can execute code extended with additional syntax for
+interactive features, such as magics.
+""".strip()
+
+# only add the suffix bit if present to save input tokens/computation time
+COMPLETION_DEFAULT_TEMPLATE = """
+The document is called `{{filename}}` and written in {{language}}.
+{% if suffix %}
+The code after the completion request is:
+
+```
+{{suffix}}
+```
+{% endif %}
+
+Complete the following code:
+
+```
+{{prefix}}"""
+
+
 class EnvAuthStrategy(BaseModel):
     """Require one auth token via an environment variable."""
 
@@ -265,6 +314,55 @@ def get_prompt_template(self, format) -> PromptTemplate:
         else:
             return self.prompt_templates["text"]  # Default to plain format
 
+    def get_chat_prompt_template(self) -> PromptTemplate:
+        """
+        Produce a prompt template optimised for chat conversation.
+        The template should take two variables: history and input.
+        """
+        name = self.__class__.name
+        if self.is_chat_provider:
+            return ChatPromptTemplate.from_messages(
+                [
+                    SystemMessagePromptTemplate.from_template(
+                        CHAT_SYSTEM_PROMPT
+                    ).format(provider_name=name, local_model_id=self.model_id),
+                    MessagesPlaceholder(variable_name="history"),
+                    HumanMessagePromptTemplate.from_template("{input}"),
+                ]
+            )
+        else:
+            return PromptTemplate(
+                input_variables=["history", "input"],
+                template=CHAT_SYSTEM_PROMPT.format(
+                    provider_name=name, local_model_id=self.model_id
+                )
+                + "\n\n"
+                + CHAT_DEFAULT_TEMPLATE,
+            )
+
+    def get_completion_prompt_template(self) -> PromptTemplate:
+        """
+        Produce a prompt template optimised for inline code or text completion.
+        The template should take variables: prefix, suffix, language, filename.
+        """
+        if self.is_chat_provider:
+            return ChatPromptTemplate.from_messages(
+                [
+                    SystemMessagePromptTemplate.from_template(COMPLETION_SYSTEM_PROMPT),
+                    HumanMessagePromptTemplate.from_template(
+                        COMPLETION_DEFAULT_TEMPLATE, template_format="jinja2"
+                    ),
+                ]
+            )
+        else:
+            return PromptTemplate(
+                input_variables=["prefix", "suffix", "language", "filename"],
+                template=COMPLETION_SYSTEM_PROMPT
+                + "\n\n"
+                + COMPLETION_DEFAULT_TEMPLATE,
+                template_format="jinja2",
+            )
+
     @property
     def is_chat_provider(self):
         return isinstance(self, BaseChatModel)
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 0db83afdd..584f0b33f 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -4,32 +4,9 @@
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import ConversationChain
 from langchain.memory import ConversationBufferWindowMemory
-from langchain.prompts import (
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    MessagesPlaceholder,
-    PromptTemplate,
-    SystemMessagePromptTemplate,
-)
 
 from .base import BaseChatHandler, SlashCommandRoutingType
 
-SYSTEM_PROMPT = """
-You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
-You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}.
-You are talkative and you provide lots of specific details from the foundation model's context.
-You may use Markdown to format your response.
-Code blocks must be formatted in Markdown.
-Math should be rendered with inline TeX markup, surrounded by $.
-If you do not know the answer to a question, answer truthfully by responding that you do not know.
-The following is a friendly conversation between you and a human.
-""".strip() - -DEFAULT_TEMPLATE = """Current conversation: -{history} -Human: {input} -AI:""" - class DefaultChatHandler(BaseChatHandler): id = "default" @@ -49,27 +26,10 @@ def create_llm_chain( model_parameters = self.get_model_parameters(provider, provider_params) llm = provider(**provider_params, **model_parameters) - if llm.is_chat_provider: - prompt_template = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( - provider_name=provider.name, local_model_id=llm.model_id - ), - MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), - ] - ) - self.memory = ConversationBufferWindowMemory(return_messages=True, k=2) - else: - prompt_template = PromptTemplate( - input_variables=["history", "input"], - template=SYSTEM_PROMPT.format( - provider_name=provider.name, local_model_id=llm.model_id - ) - + "\n\n" - + DEFAULT_TEMPLATE, - ) - self.memory = ConversationBufferWindowMemory(k=2) + prompt_template = llm.get_chat_prompt_template() + self.memory = ConversationBufferWindowMemory( + return_messages=llm.is_chat_provider, k=2 + ) self.llm = llm self.llm_chain = ConversationChain( diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index 687e41fed..552d23791 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -18,32 +18,6 @@ ) from .base import BaseInlineCompletionHandler -SYSTEM_PROMPT = """ -You are an application built to provide helpful code completion suggestions. -You should only produce code. Keep comments to minimum, use the -programming language comment syntax. Produce clean code. -The code is written in JupyterLab, a data analysis and code development -environment which can execute code extended with additional syntax for -interactive features, such as magics. -""".strip() - -AFTER_TEMPLATE = """ -The code after the completion request is: - -``` -{suffix} -``` -""".strip() - -DEFAULT_TEMPLATE = """ -The document is called `{filename}` and written in {language}. 
-{after}
-
-Complete the following code:
-
-```
-{prefix}"""
-
 
 class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
     llm_chain: Runnable
@@ -57,18 +31,7 @@ def create_llm_chain(
         model_parameters = self.get_model_parameters(provider, provider_params)
         llm = provider(**provider_params, **model_parameters)
 
-        if llm.is_chat_provider:
-            prompt_template = ChatPromptTemplate.from_messages(
-                [
-                    SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT),
-                    HumanMessagePromptTemplate.from_template(DEFAULT_TEMPLATE),
-                ]
-            )
-        else:
-            prompt_template = PromptTemplate(
-                input_variables=["prefix", "suffix", "language", "filename"],
-                template=SYSTEM_PROMPT + "\n\n" + DEFAULT_TEMPLATE,
-            )
+        prompt_template = llm.get_completion_prompt_template()
 
         self.llm = llm
         self.llm_chain = prompt_template | llm | StrOutputParser()
@@ -151,13 +114,11 @@ def _token_from_request(self, request: InlineCompletionRequest, suggestion: int)
     def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict:
         suffix = request.suffix.strip()
-        # only add the suffix template if the suffix is there to save input tokens/computation time
-        after = AFTER_TEMPLATE.format(suffix=suffix) if suffix else ""
         filename = request.path.split("/")[-1] if request.path else "untitled"
         return {
             "prefix": request.prefix,
-            "after": after,
+            "suffix": suffix,
             "language": request.language,
             "filename": filename,
             "stop": ["\n```"],
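Note (not part of the patch): the non-chat branch of `get_chat_prompt_template()` simply concatenates the two new module-level constants and pre-fills the provider metadata. A minimal sketch of the template a handler now receives, with placeholder values standing in for `self.__class__.name` and `self.model_id`:

```python
from langchain.prompts import PromptTemplate

from jupyter_ai_magics.providers import CHAT_DEFAULT_TEMPLATE, CHAT_SYSTEM_PROMPT

# Mirrors the else-branch of get_chat_prompt_template(); the provider name and
# model ID below are placeholders, not values taken from this patch.
prompt_template = PromptTemplate(
    input_variables=["history", "input"],
    template=CHAT_SYSTEM_PROMPT.format(
        provider_name="ExampleProvider", local_model_id="example-model"
    )
    + "\n\n"
    + CHAT_DEFAULT_TEMPLATE,
)

print(prompt_template.format(history="Human: Hi\nAI: Hello!", input="What is a magic?"))
```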
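Likewise, the `{% if suffix %}` guard in `COMPLETION_DEFAULT_TEMPLATE` is what keeps the old behaviour of only spending input tokens on the suffix when one exists, now that `_template_inputs_from_request` passes `suffix` through unchanged. A quick sanity check, assuming the patch is applied and the `jinja2` package is available:

```python
from langchain.prompts import PromptTemplate

from jupyter_ai_magics.providers import (
    COMPLETION_DEFAULT_TEMPLATE,
    COMPLETION_SYSTEM_PROMPT,
)

# Mirrors the non-chat branch of get_completion_prompt_template().
template = PromptTemplate(
    input_variables=["prefix", "suffix", "language", "filename"],
    template=COMPLETION_SYSTEM_PROMPT + "\n\n" + COMPLETION_DEFAULT_TEMPLATE,
    template_format="jinja2",
)

inputs = {"prefix": "def mean(xs):\n    ", "language": "python", "filename": "stats.py"}
with_suffix = template.format(suffix="print(mean([1, 2]))", **inputs)
without_suffix = template.format(suffix="", **inputs)

# The suffix block renders only when a suffix is present.
assert "The code after the completion request is:" in with_suffix
assert "The code after the completion request is:" not in without_suffix
```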