From d0f531d52783deaa446b0de0820c9bc649476b3f Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Fri, 20 Oct 2023 21:48:23 -0700 Subject: [PATCH] Updated to use memory instead of chat history, fix for Bedrock Anthropic --- .../jupyter_ai/chat_handlers/ask.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index cad14b0e5..dbc2bd679 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -4,9 +4,19 @@ from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider from langchain.chains import ConversationalRetrievalChain +from langchain.memory import ConversationBufferWindowMemory +from langchain.prompts import PromptTemplate from .base import BaseChatHandler +PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. + +Chat History: +{chat_history} +Follow Up Input: {question} +Standalone question:""" +CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE) + class AskChatHandler(BaseChatHandler): """Processes messages prefixed with /ask. This actor will @@ -27,9 +37,15 @@ def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): self.llm = provider(**provider_params) - self.chat_history = [] + memory = ConversationBufferWindowMemory( + memory_key="chat_history", return_messages=True, k=2 + ) self.llm_chain = ConversationalRetrievalChain.from_llm( - self.llm, self._retriever, verbose=True + self.llm, + self._retriever, + memory=memory, + condense_question_prompt=CONDENSE_PROMPT, + verbose=False, ) async def _process_message(self, message: HumanChatMessage): @@ -44,14 +60,8 @@ async def _process_message(self, message: HumanChatMessage): self.get_llm_chain() try: - # limit chat history to last 2 exchanges - self.chat_history = self.chat_history[-2:] - - result = await self.llm_chain.acall( - {"question": query, "chat_history": self.chat_history} - ) + result = await self.llm_chain.acall({"question": query}) response = result["answer"] - self.chat_history.append((query, response)) self.reply(response, message) except AssertionError as e: self.log.error(e)