Commit

revert changes to chat handlers

dlqqq committed Dec 4, 2024
1 parent 0c70a37 commit 956164b
Showing 9 changed files with 43 additions and 61 deletions.
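Every diff in this commit applies the same reversion: the `chat` parameter (typed `YChat` or `Optional[YChat]`) is removed from `process_message` and related methods, and the extra `chat` argument is dropped from helper calls such as `self.reply(...)`, `self.parse_args(...)`, `self.pending(...)`, and `self.stream_reply(...)`. As a minimal sketch of a custom slash-command handler written against the reverted interface: the `EchoChatHandler` name and its `/echo` command are hypothetical, modeled on the TestSlashCommand in the first diff below, and the class attributes follow the pattern used by jupyter-ai's built-in handlers.

    from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
    from jupyter_ai.models import HumanChatMessage


    class EchoChatHandler(BaseChatHandler):
        """Hypothetical /echo handler illustrating the reverted interface."""

        id = "echo"
        name = "Echo"
        help = "Echoes the message body back into the chat"
        routing_type = SlashCommandRoutingType(slash_id="echo")

        uses_llm = False  # this sketch does not call a language model

        async def process_message(self, message: HumanChatMessage):
            # After the revert, handlers no longer receive a YChat instance;
            # a reply is sent with self.reply(body, message) alone.
            self.reply(f"You said: {message.body}", message)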
@@ -1,6 +1,5 @@
 from jupyter_ai.chat_handlers.base import BaseChatHandler, SlashCommandRoutingType
 from jupyter_ai.models import HumanChatMessage
-from jupyterlab_chat.ychat import YChat


 class TestSlashCommand(BaseChatHandler):
@@ -26,5 +25,5 @@ class TestSlashCommand(BaseChatHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-    async def process_message(self, message: HumanChatMessage, chat: YChat):
-        self.reply("This is the `/test` slash command.", chat)
+    async def process_message(self, message: HumanChatMessage):
+        self.reply("This is the `/test` slash command.")
15 changes: 7 additions & 8 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -1,9 +1,8 @@
 import argparse
-from typing import Dict, Optional, Type
+from typing import Dict, Type

 from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
-from jupyterlab_chat.ychat import YChat
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferWindowMemory
 from langchain_core.prompts import PromptTemplate
@@ -60,32 +59,32 @@ def create_llm_chain(
             verbose=False,
         )

-    async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]):
-        args = self.parse_args(message, chat)
+    async def process_message(self, message: HumanChatMessage):
+        args = self.parse_args(message)
         if args is None:
             return
         query = " ".join(args.query)
         if not query:
-            self.reply(f"{self.parser.format_usage()}", chat, message)
+            self.reply(f"{self.parser.format_usage()}", message)
             return

         self.get_llm_chain()

         try:
-            with self.pending("Searching learned documents", message, chat=chat):
+            with self.pending("Searching learned documents", message):
                 assert self.llm_chain
                 # TODO: migrate this class to use a LCEL `Runnable` instead of
                 # `Chain`, then remove the below ignore comment.
                 result = await self.llm_chain.acall(  # type:ignore[attr-defined]
                     {"question": query}
                 )
                 response = result["answer"]
-            self.reply(response, chat, message)
+            self.reply(response, message)
         except AssertionError as e:
             self.log.error(e)
             response = """Sorry, an error occurred while reading the from the learned documents.
             If you have changed the embedding provider, try deleting the existing index by running
             `/learn -d` command and then re-submitting the `learn <directory>` to learn the documents,
             and then asking the question again.
             """
-            self.reply(response, chat, message)
+            self.reply(response, message)
7 changes: 2 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
@@ -1,7 +1,4 @@
-from typing import Optional
-
 from jupyter_ai.models import ClearRequest
-from jupyterlab_chat.ychat import YChat

 from .base import BaseChatHandler, SlashCommandRoutingType

@@ -19,11 +16,11 @@ class ClearChatHandler(BaseChatHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-    async def process_message(self, _, chat: Optional[YChat]):
+    async def process_message(self, _):
         # Clear chat by triggering `RootChatHandler.on_clear_request()`.
         for handler in self._root_chat_handlers.values():
             if not handler:
                 continue

-            handler.on_clear_request(ClearRequest(target=None))
+            handler.on_clear_request(ClearRequest())
             break
9 changes: 4 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,9 +1,8 @@
 import asyncio
-from typing import Dict, Optional, Type
+from typing import Dict, Type

 from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
-from jupyterlab_chat.ychat import YChat
 from langchain_core.runnables import ConfigurableFieldSpec
 from langchain_core.runnables.history import RunnableWithMessageHistory
@@ -54,7 +53,7 @@ def create_llm_chain(
         )
         self.llm_chain = runnable

-    async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]):
+    async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
         assert self.llm_chain

@@ -64,12 +63,12 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]
         try:
             context_prompt = await self.make_context_prompt(message)
         except ContextProviderException as e:
-            self.reply(str(e), chat, message)
+            self.reply(str(e), message)
             return
         inputs["context"] = context_prompt
         inputs["input"] = self.replace_prompt(inputs["input"])

-        await self.stream_reply(inputs, message, chat=chat)
+        await self.stream_reply(inputs, message)

     async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
         return "\n\n".join(
9 changes: 4 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/export.py
@@ -1,10 +1,9 @@
 import argparse
 import os
 from datetime import datetime
-from typing import List, Optional
+from typing import List

 from jupyter_ai.models import AgentChatMessage, AgentStreamMessage, HumanChatMessage
-from jupyterlab_chat.ychat import YChat

 from .base import BaseChatHandler, SlashCommandRoutingType

@@ -32,11 +31,11 @@ def chat_message_to_markdown(self, message):
         return ""

     # Write the chat history to a markdown file with a timestamp
-    async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]):
+    async def process_message(self, message: HumanChatMessage):
         markdown_content = "\n\n".join(
             self.chat_message_to_markdown(msg) for msg in self._chat_history
         )
-        args = self.parse_args(message, chat)
+        args = self.parse_args(message)
         chat_filename = (  # if no filename, use "chat_history" + timestamp
             args.path[0]
             if (args.path and args.path[0] != "")
@@ -47,4 +46,4 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]
         )  # Do not use timestamp if filename is entered as argument
         with open(chat_file, "w") as chat_history:
             chat_history.write(markdown_content)
-        self.reply(f"File saved to `{chat_file}`", chat)
+        self.reply(f"File saved to `{chat_file}`")
8 changes: 3 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
@@ -1,8 +1,7 @@
-from typing import Dict, Optional, Type
+from typing import Dict, Type

 from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
-from jupyterlab_chat.ychat import YChat
 from langchain.prompts import PromptTemplate

 from .base import BaseChatHandler, SlashCommandRoutingType
@@ -80,11 +79,10 @@ def create_llm_chain(
         runnable = prompt_template | llm  # type:ignore
         self.llm_chain = runnable

-    async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]):
+    async def process_message(self, message: HumanChatMessage):
         if not (message.selection and message.selection.type == "cell-with-error"):
             self.reply(
                 "`/fix` requires an active code cell with error output. Please click on a cell with error output and retry.",
-                chat,
                 message,
             )
             return
@@ -106,5 +104,5 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]
             "error_value": selection.error.value,
         }
         await self.stream_reply(
-            inputs, message, pending_msg="Analyzing error", chat=chat
+            inputs, message, pending_msg="Analyzing error"
         )
11 changes: 5 additions & 6 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -9,7 +9,6 @@
 from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
 from jupyter_ai.models import HumanChatMessage
 from jupyter_ai_magics.providers import BaseProvider
-from jupyterlab_chat.ychat import YChat
 from langchain.chains import LLMChain
 from langchain.llms import BaseLLM
 from langchain.output_parsers import PydanticOutputParser
@@ -263,19 +262,19 @@ async def _generate_notebook(self, prompt: str):
         nbformat.write(notebook, final_path)
         return final_path

-    async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]):
+    async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()

         # first send a verification message to user
         response = "👍 Great, I will get started on your notebook. It may take a few minutes, but I will reply here when the notebook is ready. In the meantime, you can continue to ask me other questions."
-        self.reply(response, chat, message)
+        self.reply(response, message)

         final_path = await self._generate_notebook(prompt=message.body)
         response = f"""🎉 I have created your notebook and saved it to the location {final_path}. I am still learning how to create notebooks, so please review all code before running it."""
-        self.reply(response, chat, message)
+        self.reply(response, message)

     async def handle_exc(
-        self, e: Exception, message: HumanChatMessage, chat: Optional[YChat]
+        self, e: Exception, message: HumanChatMessage
     ):
         timestamp = time.strftime("%Y-%m-%d-%H.%M.%S")
         default_log_dir = Path(self.output_dir) / "jupyter-ai-logs"
@@ -286,4 +285,4 @@ async def handle_exc(
             traceback.print_exc(file=log)

         response = f"An error occurred while generating the notebook. The error details have been saved to `./{log_path}`.\n\nTry running `/generate` again, as some language models require multiple attempts before a notebook is generated."
-        self.reply(response, chat, message)
+        self.reply(response, message)
7 changes: 2 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
@@ -1,7 +1,4 @@
-from typing import Optional
-
 from jupyter_ai.models import HumanChatMessage
-from jupyterlab_chat.ychat import YChat

 from .base import BaseChatHandler, SlashCommandRoutingType

@@ -18,5 +15,5 @@ class HelpChatHandler(BaseChatHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-    async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]):
-        self.send_help_message(chat, message)
+    async def process_message(self, message: HumanChatMessage):
+        self.send_help_message(message)
33 changes: 14 additions & 19 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -21,7 +21,6 @@
 )
 from jupyter_core.paths import jupyter_data_dir
 from jupyter_core.utils import ensure_dir_exists
-from jupyterlab_chat.ychat import YChat
 from langchain.schema import BaseRetriever, Document
 from langchain.text_splitter import (
     LatexTextSplitter,
@@ -129,29 +128,28 @@ def _load(self):
             )
             self.log.error(e)

-    async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]):
+    async def process_message(self, message: HumanChatMessage):
         # If no embedding provider has been selected
         em_provider_cls, em_provider_args = self.get_embedding_provider()
         if not em_provider_cls:
             self.reply(
                 "Sorry, please select an embedding provider before using the `/learn` command.",
-                chat,
             )
             return

-        args = self.parse_args(message, chat)
+        args = self.parse_args(message)
         if args is None:
             return

         if args.delete:
             self.delete()
             self.reply(
-                f"👍 I have deleted everything I previously learned.", chat, message
+                f"👍 I have deleted everything I previously learned.", message
             )
             return

         if args.list:
-            self.reply(self._build_list_response(), chat)
+            self.reply(self._build_list_response())
             return

         if args.remote:
@@ -162,23 +160,20 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]
                     args.path = [arxiv_to_text(id, self.output_dir)]
                     self.reply(
                         f"Learning arxiv file with id **{id}**, saved in **{args.path[0]}**.",
-                        chat,
                         message,
                     )
                 except ModuleNotFoundError as e:
                     self.log.error(e)
                     self.reply(
                         "No `arxiv` package found. "
                         "Install with `pip install arxiv`.",
-                        chat,
                     )
                     return
                 except Exception as e:
                     self.log.error(e)
                     self.reply(
                         "An error occurred while processing the arXiv file. "
                         f"Please verify that the arxiv id {id} is correct.",
-                        chat,
                     )
                     return

@@ -194,7 +189,7 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]
                 "- Learn on files in the root directory: `/learn *`\n"
                 "- Learn all python files under the root directory recursively: `/learn **/*.py`"
             )
-            self.reply(f"{self.parser.format_usage()}\n\n {no_path_arg_message}", chat)
+            self.reply(f"{self.parser.format_usage()}\n\n {no_path_arg_message}")
             return
         short_path = args.path[0]
         load_path = os.path.join(self.output_dir, short_path)
@@ -204,14 +199,14 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]
             next(iglob(load_path))
         except StopIteration:
             response = f"Sorry, that path doesn't exist: {load_path}"
-            self.reply(response, chat, message)
+            self.reply(response, message)
             return

         # delete and relearn index if embedding model was changed
-        await self.delete_and_relearn(chat)
+        await self.delete_and_relearn()

         with self.pending(
-            f"Loading and splitting files for {load_path}", message, chat=chat
+            f"Loading and splitting files for {load_path}", message
         ):
             try:
                 await self.learn_dir(
@@ -228,7 +223,7 @@ async def process_message(self, message: HumanChatMessage, chat: Optional[YChat]
         You can ask questions about these docs by prefixing your message with **/ask**.""" % (
             load_path.replace("*", r"\*")
         )
-        self.reply(response, chat, message)
+        self.reply(response, message)

     def _build_list_response(self):
         if not self.metadata.dirs:
@@ -282,7 +277,7 @@ def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int):
         )
         self.metadata.dirs = dirs

-    async def delete_and_relearn(self, chat: Optional[YChat] = None):
+    async def delete_and_relearn(self):
         """Delete the vector store and relearn all indexed directories if
         necessary. If the embedding model is unchanged, this method does
         nothing."""
@@ -309,11 +304,11 @@ async def delete_and_relearn(self, chat: Optional[YChat] = None):
         documents you had previously submitted for learning. Please wait to use
         the **/ask** command until I am done with this task."""

-        self.reply(message, chat)
+        self.reply(message)

         metadata = self.metadata
         self.delete()
-        await self.relearn(metadata, chat)
+        await self.relearn(metadata)
         self.prev_em_id = curr_em_id

     def delete(self):
@@ -327,7 +322,7 @@ def delete(self):
         if os.path.isfile(path):
             os.remove(path)

-    async def relearn(self, metadata: IndexMetadata, chat: Optional[YChat]):
+    async def relearn(self, metadata: IndexMetadata):
         # Index all dirs in the metadata
         if not metadata.dirs:
             return
@@ -347,7 +342,7 @@ async def relearn(self, metadata: IndexMetadata, chat: Optional[YChat]):
         message = f"""🎉 I am done learning docs in these directories:
         {dir_list} I am ready to answer questions about them.
         You can ask me about these documents by starting your message with **/ask**."""
-        self.reply(message, chat)
+        self.reply(message)

     def create(
         self,
