diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 1ff5d13dd..2fc0523ba 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -487,6 +487,9 @@ def get_llm_chat_memory( last_human_msg: HumanChatMessage, **kwargs, ) -> "BaseChatMessageHistory": + if self.ychat: + return self.llm_chat_memory + return WrappedBoundedChatHistory( history=self.llm_chat_memory, last_human_msg=last_human_msg, diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 83cc00c13..9c1aaab24 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -42,7 +42,7 @@ RootChatHandler, SlashCommandsInfoHandler, ) -from .history import BoundedChatHistory +from .history import BoundedChatHistory, YChatHistory from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip __version__ as jupyter_collaboration_version, @@ -418,9 +418,13 @@ def initialize_settings(self): # list of chat messages to broadcast to new clients # this is only used to render the UI, and is not the conversational # memory object used by the LM chain. + # + # TODO: remove this in v3. this list is only used by the REST API to get + # history in v2 chat. self.settings["chat_history"] = [] - # conversational memory object used by LM chain + # TODO: remove this in v3. this is the history implementation that + # provides memory to the chat model in v2. self.settings["llm_chat_memory"] = BoundedChatHistory( k=self.default_max_chat_history ) @@ -515,13 +519,19 @@ def _init_chat_handlers( eps = entry_points() chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") chat_handlers: Dict[str, BaseChatHandler] = {} + + if ychat: + llm_chat_memory = YChatHistory(ychat, k=self.default_max_chat_history) + else: + llm_chat_memory = self.settings["llm_chat_memory"] + chat_handler_kwargs = { "log": self.log, "config_manager": self.settings["jai_config_manager"], "model_parameters": self.settings["model_parameters"], "root_chat_handlers": self.settings["jai_root_chat_handlers"], "chat_history": self.settings["chat_history"], - "llm_chat_memory": self.settings["llm_chat_memory"], + "llm_chat_memory": llm_chat_memory, "root_dir": self.serverapp.root_dir, "dask_client_future": self.settings["dask_client_future"], "preferred_dir": self.serverapp.contents_manager.preferred_dir, diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 0f1ba7dc0..b3a7b8165 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,15 +1,61 @@ import time from typing import List, Optional, Sequence, Set, Union +from jupyterlab_chat.ychat import YChat from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import BaseMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.pydantic_v1 import BaseModel, PrivateAttr +from .constants import BOT from .models import HumanChatMessage HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id" +class YChatHistory(BaseChatMessageHistory): + """ + An implementation of `BaseChatMessageHistory` that returns the preceding `k` + exchanges (`k * 2` messages) from the given YChat model. + + If `k` is set to `None`, then this class returns all preceding messages. + """ + + def __init__(self, ychat: YChat, k: Optional[int]): + self.ychat = ychat + self.k = k + + @property + def messages(self) -> List[BaseMessage]: # type:ignore[override] + """ + Returns the last `2 * k` messages preceding the latest message. If + `k` is set to `None`, return all preceding messages. + """ + # TODO: consider bounding history based on message size (e.g. total + # char/token count) instead of message count. + all_messages = self.ychat.get_messages() + + # gather last k * 2 messages and return + # we exclude the last message since that is the HumanChatMessage just + # submitted by a user. + messages: List[BaseMessage] = [] + start_idx = 0 if self.k is None else -2 * self.k - 1 + for message in all_messages[start_idx:-1]: + if message["sender"] == BOT["username"]: + messages.append(AIMessage(content=message["body"])) + else: + messages.append(HumanMessage(content=message["body"])) + + return messages + + def add_message(self, message: BaseMessage) -> None: + # do nothing when other LangChain objects call this method, since + # message history is maintained by the `YChat` shared document. + return + + def clear(self): + raise NotImplementedError() + + class BoundedChatHistory(BaseChatMessageHistory, BaseModel): """ An in-memory implementation of `BaseChatMessageHistory` that stores up to