Skip to content

Commit

Permalink
Updated to use memory instead of chat history, fix for Bedrock Anthr…
Browse files Browse the repository at this point in the history
…opic
  • Loading branch information
3coins committed Oct 21, 2023
1 parent 0d247fd commit 7c09863
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate

from .base import BaseChatHandler

PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)


class AskChatHandler(BaseChatHandler):
"""Processes messages prefixed with /ask. This actor will
Expand All @@ -27,9 +37,15 @@ def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
self.llm = provider(**provider_params)
self.chat_history = []
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
self.llm_chain = ConversationalRetrievalChain.from_llm(
self.llm, self._retriever, verbose=True
self.llm,
self._retriever,
memory=memory,
condense_question_prompt=CONDENSE_PROMPT,
verbose=False,
)

async def _process_message(self, message: HumanChatMessage):
Expand All @@ -44,14 +60,8 @@ async def _process_message(self, message: HumanChatMessage):
self.get_llm_chain()

try:
# limit chat history to last 2 exchanges
self.chat_history = self.chat_history[-2:]

result = await self.llm_chain.acall(
{"question": query, "chat_history": self.chat_history}
)
result = await self.llm_chain.acall({"question": query})
response = result["answer"]
self.chat_history.append((query, response))
self.reply(response, message)
except AssertionError as e:
self.log.error(e)
Expand Down

0 comments on commit 7c09863

Please sign in to comment.