diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index ec007a8ce..fbc7d5934 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -1,14 +1,13 @@ import time from typing import List, Optional, Sequence, Set, Union +from jupyterlab_chat.ychat import YChat from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.pydantic_v1 import BaseModel, PrivateAttr -from jupyterlab_chat.ychat import YChat - -from .models import HumanChatMessage from .constants import BOT +from .models import HumanChatMessage HUMAN_MSG_ID_KEY = "_jupyter_ai_human_msg_id" @@ -18,10 +17,11 @@ class YChatHistory(BaseChatMessageHistory): An implementation of `BaseChatMessageHistory` that yields the last `k` exchanges (`k * 2` messages) from the given YChat model. """ + def __init__(self, ychat: YChat, k: Optional[int]): self.ychat = ychat self.k = k - + @property def messages(self) -> List[BaseMessage]: """Returns the last `k` messages.""" @@ -40,16 +40,16 @@ def messages(self) -> List[BaseMessage]: messages.append(HumanMessage(content=message["body"])) return messages - + def add_message(self, message: BaseMessage) -> None: # do nothing when other LangChain objects call this method, since # message history is maintained by the `YChat` shared document. return - + def clear(self): raise NotImplementedError() - - + + class BoundedChatHistory(BaseChatMessageHistory, BaseModel): """ An in-memory implementation of `BaseChatMessageHistory` that stores up to