
Commit

Updated to use memory instead of chat history, fix for Bedrock Anthropic

(cherry picked from commit 7c09863)
3coins committed Oct 23, 2023
1 parent ab12857 commit 5f21903
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -4,9 +4,19 @@
 from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import ConversationalRetrievalChain
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.prompts import PromptTemplate
 
 from .base import BaseChatHandler
 
+PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
+Chat History:
+{chat_history}
+Follow Up Input: {question}
+Standalone question:"""
+CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)
+
 
 class AskChatHandler(BaseChatHandler):
     """Processes messages prefixed with /ask. This actor will
@@ -27,9 +37,15 @@ def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
         self.llm = provider(**provider_params)
-        self.chat_history = []
+        memory = ConversationBufferWindowMemory(
+            memory_key="chat_history", return_messages=True, k=2
+        )
         self.llm_chain = ConversationalRetrievalChain.from_llm(
-            self.llm, self._retriever, verbose=True
+            self.llm,
+            self._retriever,
+            memory=memory,
+            condense_question_prompt=CONDENSE_PROMPT,
+            verbose=False,
         )
 
     async def _process_message(self, message: HumanChatMessage):
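
The ConversationBufferWindowMemory created above is what replaces the manual self.chat_history list: with k=2 it exposes only the last two question/answer exchanges under the chat_history key, so the handler no longer trims the list itself. A minimal sketch of that behavior in isolation, with hypothetical questions and answers:

from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(
    memory_key="chat_history", return_messages=True, k=2
)

# Save three hypothetical exchanges; the window retains only the last two.
memory.save_context({"question": "What is Jupyter AI?"}, {"answer": "A chat extension."})
memory.save_context({"question": "Does it index files?"}, {"answer": "Yes, via /learn."})
memory.save_context({"question": "How do I ask about them?"}, {"answer": "Use /ask."})

# return_messages=True means the history comes back as message objects,
# which the chain substitutes for {chat_history} in the condense prompt.
history = memory.load_memory_variables({})["chat_history"]
print(len(history))  # 4 messages: the last 2 human/AI exchanges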
@@ -44,14 +60,8 @@ async def _process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
 
         try:
-            # limit chat history to last 2 exchanges
-            self.chat_history = self.chat_history[-2:]
-
-            result = await self.llm_chain.acall(
-                {"question": query, "chat_history": self.chat_history}
-            )
+            result = await self.llm_chain.acall({"question": query})
             response = result["answer"]
-            self.chat_history.append((query, response))
             self.reply(response, message)
         except AssertionError as e:
             self.log.error(e)
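
Because the memory is attached to the chain, _process_message now passes only the question: the chain loads prior turns from memory before the call and saves the new question/answer pair afterwards, which is the bookkeeping the removed self.chat_history code did by hand. A rough sketch of the calling pattern, assuming a chain wired up as in create_llm_chain above:

# Hypothetical usage from inside an async method, assuming `chain` was built
# with memory=memory and condense_question_prompt=CONDENSE_PROMPT as above.
result = await chain.acall({"question": "Which models does it support?"})
answer = result["answer"]
# No manual history tracking is needed: the memory recorded this exchange,
# so the next acall() sees it injected as {chat_history} automatically.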