diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index 2fc0523ba..962b37fe6 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -37,6 +37,8 @@ ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider +from jupyterlab_chat.models import Message as YMessage +from jupyterlab_chat.models import NewMessage, User from jupyterlab_chat.ychat import YChat from langchain.pydantic_v1 import BaseModel from langchain_core.messages import AIMessageChunk @@ -182,13 +184,6 @@ def __init__( self.context_providers = context_providers self.message_interrupted = message_interrupted self.ychat = ychat - self.indexes_by_id: Dict[str, int] = {} - """ - Indexes of messages in the YChat document by message ID. - - TODO: Remove this once `jupyterlab-chat` can update messages by ID - without an index. - """ self.llm: Optional[BaseProvider] = None self.llm_params: Optional[dict] = None @@ -282,13 +277,14 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage): ) self.reply(response, message) - def write_message(self, body: str, id: Optional[str] = None) -> str: + def write_message(self, body: str, stream_id: Optional[str] = None) -> str: """ - [Jupyter Chat only] Writes a message to the YChat shared document - that this chat handler is assigned to. + [Jupyter Chat only] Adds a message to the YChat shared document that + this chat handler is assigned to. If `stream_id` is passed, then this + method appends to the message referenced by `stream_id`. - Returns the new message ID. This will be identical to the `id` argument - if passed. + Returns the new message ID. This will be identical to the `stream_id` + argument if passed. """ # TODO: remove this once `ychat` becomes a required attribute. if not self.ychat: @@ -296,24 +292,23 @@ def write_message(self, body: str, id: Optional[str] = None) -> str: bot = self.ychat.get_user(BOT["username"]) if not bot: - self.ychat.set_user(BOT) - - index = self.indexes_by_id.get(id, None) if id else None - id = id if id else str(uuid4()) - new_index = self.ychat.set_message( - { - "type": "msg", - "body": body, - "id": id if id else str(uuid4()), - "time": time.time(), - "sender": BOT["username"], - "raw_time": False, - }, - index=index, - append=True, - ) + self.ychat.set_user(User(**BOT)) + + if stream_id: + self.ychat.update_message( + YMessage( + body=body, + id=stream_id, + time=time.time(), + sender=BOT["username"], + raw_time=False, + ), + append=True, + ) + id = stream_id + else: + id = self.ychat.add_message(NewMessage(body=body, sender=BOT["username"])) - self.indexes_by_id[id] = new_index return id def broadcast_message(self, message: Message): diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index b3a7b8165..d1cc80cc1 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -40,10 +40,10 @@ def messages(self) -> List[BaseMessage]: # type:ignore[override] messages: List[BaseMessage] = [] start_idx = 0 if self.k is None else -2 * self.k - 1 for message in all_messages[start_idx:-1]: - if message["sender"] == BOT["username"]: - messages.append(AIMessage(content=message["body"])) + if message.sender == BOT["username"]: + messages.append(AIMessage(content=message.body)) else: - messages.append(HumanMessage(content=message["body"])) + messages.append(HumanMessage(content=message.body)) return messages diff --git a/packages/jupyter-ai/package.json b/packages/jupyter-ai/package.json index e105f1752..4f5b5bb4d 100644 --- a/packages/jupyter-ai/package.json +++ b/packages/jupyter-ai/package.json @@ -61,7 +61,7 @@ "dependencies": { "@emotion/react": "^11.10.5", "@emotion/styled": "^11.10.5", - "@jupyter/chat": "^0.6.0", + "@jupyter/chat": "^0.7.0", "@jupyterlab/application": "^4.2.0", "@jupyterlab/apputils": "^4.2.0", "@jupyterlab/codeeditor": "^4.2.0", diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 90581c718..da1d1144b 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "typing_extensions>=4.5.0", "traitlets>=5.0", "deepmerge>=2.0,<3", - "jupyterlab-chat>=0.6.0", + "jupyterlab-chat>=0.7.0,<1.0.0", ] dynamic = ["version", "description", "authors", "urls", "keywords"] diff --git a/yarn.lock b/yarn.lock index 6208fc3c6..a7dbcc5d6 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2220,7 +2220,7 @@ __metadata: "@babel/preset-env": ^7.0.0 "@emotion/react": ^11.10.5 "@emotion/styled": ^11.10.5 - "@jupyter/chat": ^0.6.0 + "@jupyter/chat": ^0.7.0 "@jupyterlab/application": ^4.2.0 "@jupyterlab/apputils": ^4.2.0 "@jupyterlab/builder": ^4.2.0 @@ -2284,9 +2284,9 @@ __metadata: languageName: unknown linkType: soft -"@jupyter/chat@npm:^0.6.0": - version: 0.6.0 - resolution: "@jupyter/chat@npm:0.6.0" +"@jupyter/chat@npm:^0.7.0": + version: 0.7.0 + resolution: "@jupyter/chat@npm:0.7.0" dependencies: "@emotion/react": ^11.10.5 "@emotion/styled": ^11.10.5 @@ -2298,6 +2298,7 @@ __metadata: "@jupyterlab/rendermime": ^4.2.0 "@jupyterlab/ui-components": ^4.2.0 "@lumino/commands": ^2.0.0 + "@lumino/coreutils": ^2.0.0 "@lumino/disposable": ^2.0.0 "@lumino/signaling": ^2.0.0 "@mui/icons-material": ^5.11.0 @@ -2305,7 +2306,7 @@ __metadata: clsx: ^2.1.0 react: ^18.2.0 react-dom: ^18.2.0 - checksum: 1af125488113fbe9089014ae2fd6bf39f5645e6a8387ee513cbae2ed24183ebd6611347e6c6bca05832a6d90fdb001ebce508f7b08e3ae92edd9946b9253dec9 + checksum: 0317fda48c447cf82b00c55d2e51f1d2942a4b05e1308ab0654538c675592baf2e8dd04e93643ae44801fca4068e91e1b563bcb095b78c7aaa82d04a7f8dd2b5 languageName: node linkType: hard @@ -3331,7 +3332,7 @@ __metadata: languageName: node linkType: hard -"@lumino/coreutils@npm:^1.11.0 || ^2.0.0, @lumino/coreutils@npm:^1.11.0 || ^2.1.2, @lumino/coreutils@npm:^2.1.2, @lumino/coreutils@npm:^2.2.0": +"@lumino/coreutils@npm:^1.11.0 || ^2.0.0, @lumino/coreutils@npm:^1.11.0 || ^2.1.2, @lumino/coreutils@npm:^2.0.0, @lumino/coreutils@npm:^2.1.2, @lumino/coreutils@npm:^2.2.0": version: 2.2.0 resolution: "@lumino/coreutils@npm:2.2.0" dependencies: