Skip to content

Commit

Permalink
dedicate a separate LangChain history object per chat
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Dec 10, 2024
1 parent a0b8e84 commit dcf03de
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 4 deletions.
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,9 @@ def get_llm_chat_memory(
last_human_msg: HumanChatMessage,
**kwargs,
) -> "BaseChatMessageHistory":
if self.ychat:
return self.llm_chat_memory

return WrappedBoundedChatHistory(
history=self.llm_chat_memory,
last_human_msg=last_human_msg,
Expand Down
15 changes: 12 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
RootChatHandler,
SlashCommandsInfoHandler,
)
from .history import BoundedChatHistory
from .history import BoundedChatHistory, YChatHistory

from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip
__version__ as jupyter_collaboration_version,
Expand Down Expand Up @@ -417,9 +417,13 @@ def initialize_settings(self):
# list of chat messages to broadcast to new clients
# this is only used to render the UI, and is not the conversational
# memory object used by the LM chain.
#
# TODO: remove this in v3. this list is only used by the REST API to get
# history in v2 chat.
self.settings["chat_history"] = []

# conversational memory object used by LM chain
# TODO: remove this in v3. this is the history implementation that
# provides memory to the chat model in v2.
self.settings["llm_chat_memory"] = BoundedChatHistory(
k=self.default_max_chat_history
)
Expand Down Expand Up @@ -512,13 +516,18 @@ def _init_chat_handlers(
eps = entry_points()
chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers")
chat_handlers = {}
if ychat:
llm_chat_memory = YChatHistory(ychat, k=self.default_max_chat_history)
else:
llm_chat_memory = self.settings["llm_chat_memory"]

chat_handler_kwargs = {
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"model_parameters": self.settings["model_parameters"],
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
"chat_history": self.settings["chat_history"],
"llm_chat_memory": self.settings["llm_chat_memory"],
"llm_chat_memory": llm_chat_memory,
"root_dir": self.serverapp.root_dir,
"dask_client_future": self.settings["dask_client_future"],
"preferred_dir": self.serverapp.contents_manager.preferred_dir,
Expand Down
42 changes: 41 additions & 1 deletion packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,54 @@
from typing import List, Optional, Sequence, Set, Union

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

from jupyterlab_chat.ychat import YChat

from .models import HumanChatMessage
from .constants import BOT

HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id"


class YChatHistory(BaseChatMessageHistory):
    """
    An implementation of `BaseChatMessageHistory` that yields the last `k`
    exchanges (`k * 2` messages) from the given YChat model.

    If `k` is `None`, the history is unbounded: every message preceding the
    most recent one is returned.

    This object never stores messages itself; it reads them from the `YChat`
    instance each time `messages` is accessed.
    """

    def __init__(self, ychat: YChat, k: Optional[int]):
        """
        Args:
            ychat: The YChat model to read messages from.
            k: Maximum number of exchanges (pairs of messages) to return, or
                `None` to return the full history.
        """
        self.ychat = ychat
        self.k = k

    @property
    def messages(self) -> List[BaseMessage]:
        """
        Returns up to the last `k * 2` messages that precede the most recent
        message, converted to LangChain message objects. Returns all preceding
        messages when `k` is `None`.
        """
        # TODO: consider bounding history based on message size (e.g. total
        # char/token count) instead of message count.
        all_messages = self.ychat.get_messages()

        # we exclude the last message since that is the HumanMessage just
        # submitted by a user.
        if self.k is None:
            # unbounded history: `k` is declared Optional, so the window
            # arithmetic below must not run on `None`.
            window = all_messages[:-1]
        else:
            # gather the last k * 2 messages before the final one. The start
            # index is written relative to the end so that `k == 0` yields an
            # empty window rather than the whole list.
            window = all_messages[-self.k * 2 - 1 : -1]

        messages: List[BaseMessage] = []
        for message in window:
            # messages sent by the bot become AI messages; anything else is
            # treated as a human message.
            if message["sender"] == BOT["username"]:
                messages.append(AIMessage(content=message["body"]))
            else:
                messages.append(HumanMessage(content=message["body"]))

        return messages

    def add_message(self, message: BaseMessage) -> None:
        """No-op; see comment below."""
        # do nothing when other LangChain objects call this method, since
        # message history is maintained by the `YChat` shared document.
        return

    def clear(self) -> None:
        """Not supported; always raises `NotImplementedError`."""
        raise NotImplementedError()


class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
"""
An in-memory implementation of `BaseChatMessageHistory` that stores up to
Expand Down

0 comments on commit dcf03de

Please sign in to comment.