Skip to content

Commit

Permalink
Consistent arguments to chat handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonWeill committed Oct 25, 2023
1 parent 0ac2581 commit 613bcc3
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 31 deletions.
13 changes: 11 additions & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import argparse
import os
import time
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Type
from typing import TYPE_CHECKING, Awaitable, ClassVar, Dict, List, Optional, Type
from uuid import uuid4

from dask.distributed import Client as DaskClient
from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai.models import AgentChatMessage, HumanChatMessage
from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from traitlets.config import Configurable

Expand Down Expand Up @@ -50,15 +52,22 @@ def __init__(
log: Logger,
config_manager: ConfigManager,
root_chat_handlers: Dict[str, "RootChatHandler"],
chat_history: List[ChatMessage],
root_dir: str,
dask_client_future: Awaitable[DaskClient],
):
self.log = log
self.config_manager = config_manager
self._root_chat_handlers = root_chat_handlers
self._chat_history = chat_history
self.parser = argparse.ArgumentParser()
self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
self.dask_client_future = dask_client_future
self.llm = None
self.llm_params = None
self.llm_chain = None


async def process_message(self, message: HumanChatMessage):
"""Processes the message passed by the root chat handler."""
try:
Expand Down
3 changes: 1 addition & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ class ClearChatHandler(BaseChatHandler):
routing_method = "slash_command"
slash_id = "clear"

def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._chat_history = chat_history

async def _process_message(self, _):
self._chat_history.clear()
Expand Down
7 changes: 3 additions & 4 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@


class DefaultChatHandler(BaseChatHandler):
def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.memory = ConversationBufferWindowMemory(return_messages=True, k=2)
self.chat_history = chat_history

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
Expand Down Expand Up @@ -79,8 +78,8 @@ def clear_memory(self):
self.reply(reply_message)

# clear transcript for new chat clients
if self.chat_history:
self.chat_history.clear()
if self._chat_history:
self._chat_history.clear()

async def _process_message(self, message: HumanChatMessage):
self.get_llm_chain()
Expand Down
3 changes: 1 addition & 2 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ class GenerateChatHandler(BaseChatHandler):
routing_method = "slash_command"
slash_id = "generate"

def __init__(self, root_dir: str, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.root_dir = os.path.abspath(os.path.expanduser(root_dir))
self.llm = None

def create_llm_chain(
Expand Down
6 changes: 1 addition & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,8 @@ class LearnChatHandler(BaseChatHandler):
routing_method = "slash_command"
slash_id = "learn"

def __init__(
self, root_dir: str, dask_client_future: Awaitable[DaskClient], *args, **kwargs
):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.root_dir = root_dir
self.dask_client_future = dask_client_future
self.parser.prog = "/learn"
self.parser.add_argument("-v", "--verbose", action="store_true")
self.parser.add_argument("-d", "--delete", action="store_true")
Expand Down
22 changes: 7 additions & 15 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,15 @@ def initialize_settings(self):
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
"chat_history": self.settings["chat_history"],
"root_dir": self.serverapp.root_dir,
"dask_client_future": dask_client_future,
}

default_chat_handler = DefaultChatHandler(
**chat_handler_kwargs, chat_history=self.settings["chat_history"]
)
clear_chat_handler = ClearChatHandler(
**chat_handler_kwargs, chat_history=self.settings["chat_history"]
)
generate_chat_handler = GenerateChatHandler(
**chat_handler_kwargs,
root_dir=self.serverapp.root_dir,
)
learn_chat_handler = LearnChatHandler(
**chat_handler_kwargs,
root_dir=self.serverapp.root_dir,
dask_client_future=dask_client_future,
)
default_chat_handler = DefaultChatHandler(**chat_handler_kwargs)
clear_chat_handler = ClearChatHandler(**chat_handler_kwargs)
generate_chat_handler = GenerateChatHandler(**chat_handler_kwargs)
learn_chat_handler = LearnChatHandler(**chat_handler_kwargs)
help_chat_handler = HelpChatHandler(**chat_handler_kwargs)
retriever = Retriever(learn_chat_handler=learn_chat_handler)
ask_chat_handler = AskChatHandler(**chat_handler_kwargs, retriever=retriever)
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def open(self):

def broadcast_message(self, message: Message):
"""Broadcasts message to all connected clients.
Appends message to `self.chat_history`.
Appends message to chat history.
"""

self.log.debug("Broadcasting message: %s to all clients...", message)
Expand Down

0 comments on commit 613bcc3

Please sign in to comment.