diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 233099151..9409d4521 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -481,6 +481,9 @@ def get_llm_chat_memory( last_human_msg: HumanChatMessage, **kwargs, ) -> "BaseChatMessageHistory": + if self.ychat: + return self.llm_chat_memory + return WrappedBoundedChatHistory( history=self.llm_chat_memory, last_human_msg=last_human_msg, diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 7c02b9e40..e65b6391b 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -42,7 +42,7 @@ RootChatHandler, SlashCommandsInfoHandler, ) -from .history import BoundedChatHistory +from .history import BoundedChatHistory, YChatHistory from jupyter_collaboration import ( # type:ignore[import-untyped] # isort:skip __version__ as jupyter_collaboration_version, @@ -417,9 +417,13 @@ def initialize_settings(self): # list of chat messages to broadcast to new clients # this is only used to render the UI, and is not the conversational # memory object used by the LM chain. + # + # TODO: remove this in v3. this list is only used by the REST API to get + # history in v2 chat. self.settings["chat_history"] = [] - # conversational memory object used by LM chain + # TODO: remove this in v3. this is the history implementation that + # provides memory to the chat model in v2. self.settings["llm_chat_memory"] = BoundedChatHistory( k=self.default_max_chat_history ) @@ -512,13 +516,18 @@ def _init_chat_handlers( eps = entry_points() chat_handler_eps = eps.select(group="jupyter_ai.chat_handlers") chat_handlers = {} + if ychat: + llm_chat_memory = YChatHistory(ychat, k=self.default_max_chat_history) + else: + llm_chat_memory = self.settings["llm_chat_memory"] + chat_handler_kwargs = { "log": self.log, "config_manager": self.settings["jai_config_manager"], "model_parameters": self.settings["model_parameters"], "root_chat_handlers": self.settings["jai_root_chat_handlers"], "chat_history": self.settings["chat_history"], - "llm_chat_memory": self.settings["llm_chat_memory"], + "llm_chat_memory": llm_chat_memory, "root_dir": self.serverapp.root_dir, "dask_client_future": self.settings["dask_client_future"], "preferred_dir": self.serverapp.contents_manager.preferred_dir, diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index 0f1ba7dc0..ec007a8ce 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -2,14 +2,54 @@ from typing import List, Optional, Sequence, Set, Union from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import BaseMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.pydantic_v1 import BaseModel, PrivateAttr +from jupyterlab_chat.ychat import YChat + from .models import HumanChatMessage +from .constants import BOT HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id" +class YChatHistory(BaseChatMessageHistory): + """ + An implementation of `BaseChatMessageHistory` that yields the last `k` + exchanges (`k * 2` messages) from the given YChat model. + """ + def __init__(self, ychat: YChat, k: Optional[int]): + self.ychat = ychat + self.k = k + + @property + def messages(self) -> List[BaseMessage]: + """Returns the last `k` messages.""" + # TODO: consider bounding history based on message size (e.g. total + # char/token count) instead of message count. + all_messages = self.ychat.get_messages() + + # gather last k * 2 messages and return + # we exclude the last message since that is the HumanMessage just + # submitted by a user. + messages = [] + for message in all_messages[-self.k * 2 - 1 : -1]: + if message["sender"] == BOT["username"]: + messages.append(AIMessage(content=message["body"])) + else: + messages.append(HumanMessage(content=message["body"])) + + return messages + + def add_message(self, message: BaseMessage) -> None: + # do nothing when other LangChain objects call this method, since + # message history is maintained by the `YChat` shared document. + return + + def clear(self): + raise NotImplementedError() + + class BoundedChatHistory(BaseChatMessageHistory, BaseModel): """ An in-memory implementation of `BaseChatMessageHistory` that stores up to