pre-commit
dlqqq committed Dec 24, 2024
1 parent cfb1ff2 commit 2f131eb
Showing 14 changed files with 55 additions and 55 deletions.
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -2,10 +2,10 @@
 from typing import Dict, Type
 
 from jupyter_ai_magics.providers import BaseProvider
+from jupyterlab_chat.models import Message
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferWindowMemory
 from langchain_core.prompts import PromptTemplate
-from jupyterlab_chat.models import Message
 
 from .base import BaseChatHandler, SlashCommandRoutingType
13 changes: 5 additions & 8 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -20,7 +20,7 @@
 from jupyter_ai.config_manager import ConfigManager, Logger
 from jupyter_ai.constants import BOT
 from jupyter_ai_magics.providers import BaseProvider
-from jupyterlab_chat.models import NewMessage, Message, User
+from jupyterlab_chat.models import Message, NewMessage, User
 from jupyterlab_chat.ychat import YChat
 from langchain.pydantic_v1 import BaseModel
 from langchain_core.messages import AIMessageChunk
@@ -255,7 +255,7 @@ async def _default_handle_exc(self, e: Exception, _human_message: Message):
         )
         self.reply(response, _human_message)
 
-    def reply(self, body: str, _human_message = None) -> str:
+    def reply(self, body: str, _human_message=None) -> str:
         """
         Adds a message to the YChat shared document that this chat handler is
         assigned to. Returns the new message ID.
@@ -269,7 +269,7 @@ def reply(self, body: str, _human_message = None) -> str:
 
         id = self.ychat.add_message(NewMessage(body=body, sender=BOT["username"]))
         return id
-
+
     @property
     def persona(self):
         return self.config_manager.persona
@@ -391,7 +391,7 @@ def start_reply_stream(self):
         finally:
             # close the `ReplyStream` on exit.
             reply_stream.close()
-
+
     async def stream_reply(
         self,
         input: Input,
@@ -450,11 +450,8 @@ async def stream_reply(
 
         # if stream was interrupted, add a tombstone
         if stream_interrupted:
-            stream_tombstone = (
-                "\n\n(AI response stopped by user)"
-            )
+            stream_tombstone = "\n\n(AI response stopped by user)"
            reply_stream.write(stream_tombstone)
-
 
 
 class GenerationInterrupted(asyncio.CancelledError):
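The `reply()` docstring above states that the method appends a bot message to the shared YChat document and returns the new message ID. A minimal sketch of that contract, at a hypothetical call site; `handler` stands in for a configured chat handler instance, which this diff does not show:

```python
# Hypothetical usage of BaseChatHandler.reply(), based on its docstring above.
# `handler` is assumed to be an instance of a BaseChatHandler subclass that is
# bound to a YChat shared document.
msg_id = handler.reply("Hello! How can I help you?")
# reply() returns the ID of the message it just added to the chat document.
print(f"bot message id: {msg_id}")
```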
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -1,8 +1,8 @@
 import asyncio
 from typing import Dict, Type
 
-from jupyterlab_chat.models import Message
 from jupyter_ai_magics.providers import BaseProvider
+from jupyterlab_chat.models import Message
 from langchain_core.runnables.history import RunnableWithMessageHistory
 
 from ..context_providers import ContextProviderException, find_commands
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -9,13 +9,13 @@
 import nbformat
 from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
 from jupyter_ai_magics.providers import BaseProvider
+from jupyterlab_chat.models import Message
 from langchain.chains import LLMChain
 from langchain.llms import BaseLLM
 from langchain.output_parsers import PydanticOutputParser
 from langchain.pydantic_v1 import BaseModel
 from langchain.schema.output_parser import BaseOutputParser
 from langchain_core.prompts import PromptTemplate
-from jupyterlab_chat.models import Message
 
 
 class OutlineSection(BaseModel):
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
@@ -1,4 +1,5 @@
 from jupyterlab_chat.models import Message
+
 from .base import BaseChatHandler, SlashCommandRoutingType
 
 
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
@@ -20,6 +20,7 @@
 )
 from jupyter_core.paths import jupyter_data_dir
 from jupyter_core.utils import ensure_dir_exists
+from jupyterlab_chat.models import Message
 from langchain.schema import BaseRetriever, Document
 from langchain.text_splitter import (
     LatexTextSplitter,
@@ -28,7 +29,6 @@
     RecursiveCharacterTextSplitter,
 )
 from langchain_community.vectorstores import FAISS
-from jupyterlab_chat.models import Message
 
 from .base import BaseChatHandler, SlashCommandRoutingType
 
17 changes: 10 additions & 7 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/utils/streaming.py
@@ -2,13 +2,14 @@
 from typing import Optional
 
 from jupyter_ai.constants import BOT
+from jupyterlab_chat.models import Message, NewMessage, User
 from jupyterlab_chat.ychat import YChat
-from jupyterlab_chat.models import User, NewMessage, Message
 
 
 class ReplyStreamClosed(Exception):
     pass
 
+
 class ReplyStream:
     """
     Object yielded by the `BaseChatHandler.start_reply_stream()` context
@@ -35,17 +36,17 @@ def __init__(self, ychat: YChat):
         self.ychat = ychat
         self._is_open = False
         self._stream_id: Optional[str] = None
-
+
     def _set_user(self):
         bot = self.ychat.get_user(BOT["username"])
         if not bot:
             self.ychat.set_user(User(**BOT))
-
+
     def open(self):
         self._set_user()
         self.ychat.awareness.set_local_state_field("isWriting", True)
         self._is_open = True
-
+
     def write(self, chunk: str) -> str:
         """
         Writes a string chunk to the current reply stream. Returns the ID of the
@@ -55,10 +56,12 @@ def write(self, chunk: str) -> str:
             assert self._is_open
         except:
             raise ReplyStreamClosed("Reply stream must be opened first.") from None
-
+
         if not self._stream_id:
             self._set_user()
-            self._stream_id = self.ychat.add_message(NewMessage(body="", sender=BOT["username"]))
+            self._stream_id = self.ychat.add_message(
+                NewMessage(body="", sender=BOT["username"])
+            )
 
         self._set_user()
         self.ychat.update_message(
@@ -73,7 +76,7 @@ def write(self, chunk: str) -> str:
         )
 
         return self._stream_id
-
+
     def close(self):
         self.ychat.awareness.set_local_state_field("isWriting", False)
         self._is_open = False
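The methods above define a simple lifecycle: `open()` sets the bot's "isWriting" awareness flag, the first `write()` creates an empty bot message that later writes append to, and `close()` clears the flag. A minimal sketch of driving a `ReplyStream` directly, assuming only that lifecycle and an existing `ychat` document:

```python
# Minimal sketch of the ReplyStream lifecycle shown above; `ychat` is assumed
# to be an existing jupyterlab_chat YChat document.
stream = ReplyStream(ychat)
stream.open()  # sets the "isWriting" awareness flag
try:
    for chunk in ["Hello", ", ", "world!"]:
        # the first write creates the bot message; later writes append to it
        message_id = stream.write(chunk)
finally:
    stream.close()  # clears the "isWriting" awareness flag
```

Calling `write()` before `open()` raises `ReplyStreamClosed`, per the guard at the top of `write()`.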
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/context_providers/file.py
@@ -3,9 +3,9 @@
 from typing import List
 
 import nbformat
-from jupyterlab_chat.models import Message
 from jupyter_ai.document_loaders.directory import SUPPORTED_EXTS
 from jupyter_ai.models import ListOptionsEntry
+from jupyterlab_chat.models import Message
 
 from .base import (
     BaseCommandContextProvider,
4 changes: 1 addition & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
@@ -450,9 +450,7 @@ async def _stop_extension(self):
         await dask_client.close()
         self.log.debug("Closed Dask client.")
 
-    def _init_chat_handlers(
-        self, ychat: YChat
-    ) -> Dict[str, BaseChatHandler]:
+    def _init_chat_handlers(self, ychat: YChat) -> Dict[str, BaseChatHandler]:
         """
         Initializes a set of chat handlers. May accept a YChat instance for
         collaborative chats.
16 changes: 12 additions & 4 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -1,6 +1,13 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, cast, Type
-
-from jupyter_ai.chat_handlers import SlashCommandRoutingType, AskChatHandler, DefaultChatHandler, LearnChatHandler, GenerateChatHandler, HelpChatHandler
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast
+
+from jupyter_ai.chat_handlers import (
+    AskChatHandler,
+    DefaultChatHandler,
+    GenerateChatHandler,
+    HelpChatHandler,
+    LearnChatHandler,
+    SlashCommandRoutingType,
+)
 from jupyter_ai.config_manager import ConfigManager, KeyEmptyError, WriteConflictError
 from jupyter_ai.context_providers import BaseCommandContextProvider, ContextCommand
 from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
@@ -32,9 +39,10 @@
     "/ask": AskChatHandler,
     "/learn": LearnChatHandler,
     "/generate": GenerateChatHandler,
-    "/help": HelpChatHandler
+    "/help": HelpChatHandler,
 }
 
 
 class ProviderHandler(BaseAPIHandler):
     """
     Helper base class used for HTTP handlers hosting endpoints relating to
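The dictionary above maps slash commands to chat handler classes. A sketch of how such a table can drive dispatch; note that the dictionary's variable name and the fallback to `DefaultChatHandler` are illustrative assumptions, since neither appears in this hunk:

```python
# Hypothetical dispatch over the slash-command table shown above; the name
# SLASH_COMMAND_HANDLERS and the DefaultChatHandler fallback are assumptions.
def resolve_handler_class(body: str):
    first_word = body.split()[0] if body.strip() else ""
    if first_word.startswith("/"):
        return SLASH_COMMAND_HANDLERS.get(first_word, DefaultChatHandler)
    return DefaultChatHandler
```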
10 changes: 4 additions & 6 deletions packages/jupyter-ai/jupyter_ai/history.py
@@ -1,11 +1,10 @@
 from typing import List, Optional
 
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
-
 from jupyter_ai.constants import BOT
-from jupyterlab_chat.ychat import YChat
 from jupyterlab_chat.models import Message as JChatMessage
+from jupyterlab_chat.ychat import YChat
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 
 
 class YChatHistory(BaseChatMessageHistory):
@@ -51,7 +50,7 @@ def _convert_to_langchain_messages(self, jchat_messages: List[JChatMessage]):
                 messages.append(AIMessage(content=jchat_message.body))
             else:
                 messages.append(HumanMessage(content=jchat_message.body))
-
+
         return messages
 
     def add_message(self, message: BaseMessage) -> None:
@@ -61,4 +60,3 @@ def add_message(self, message: BaseMessage) -> None:
 
     def clear(self):
         raise NotImplementedError()
-
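`YChatHistory` adapts a YChat document to LangChain's `BaseChatMessageHistory` interface, which is what allows `default.py` (above) to use `RunnableWithMessageHistory`. A minimal wiring sketch, assuming a generic runnable; the `chain` object and the key names here are illustrative, not taken from this diff:

```python
from langchain_core.runnables.history import RunnableWithMessageHistory

# `chain` is assumed to be any LangChain runnable built with "input" and
# "history" placeholders; `ychat` is an existing YChat document.
history = YChatHistory(ychat=ychat)
chain_with_history = RunnableWithMessageHistory(
    chain,
    lambda session_id: history,  # one shared chat document per session
    input_messages_key="input",
    history_messages_key="history",
)
```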
6 changes: 3 additions & 3 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -1,14 +1,14 @@
 from typing import Any, Dict, List, Optional
 
-from jupyter_ai_magics.providers import AuthStrategy, Field
-from langchain.pydantic_v1 import BaseModel, validator
-
 # unused import: exports Persona from this module
 from jupyter_ai_magics.models.persona import Persona
+from jupyter_ai_magics.providers import AuthStrategy, Field
+from langchain.pydantic_v1 import BaseModel, validator
+
 DEFAULT_CHUNK_SIZE = 2000
 DEFAULT_CHUNK_OVERLAP = 100
 
 
 class ListProvidersEntry(BaseModel):
     """Model provider with supported models
     and provider's authentication strategy
10 changes: 2 additions & 8 deletions packages/jupyter-ai/jupyter_ai/tests/test_context_providers.py
@@ -19,12 +19,7 @@ def human_message() -> Message:
         "@file:'test7.py test\"\n"  # do not allow for mixed quotes
         "```\n@file:fail2.py\n```\n"  # do not look within backticks
     )
-    return Message(
-        id="fake-message-uuid",
-        time=0,
-        body=prompt,
-        sender="fake-user-uuid"
-    )
+    return Message(id="fake-message-uuid", time=0, body=prompt, sender="fake-user-uuid")
 
 
 @pytest.fixture
@@ -53,8 +48,7 @@ def test_find_instances(file_context_provider, human_message):
         "@file:'test7.py",
     ]
     commands = [
-        cmd.cmd
-        for cmd in find_commands(file_context_provider, human_message.body)
+        cmd.cmd for cmd in find_commands(file_context_provider, human_message.body)
     ]
     assert commands == expected
 
23 changes: 12 additions & 11 deletions packages/jupyter-ai/jupyter_ai/tests/test_handlers.py
@@ -1,21 +1,19 @@
 import logging
 import os
 import stat
-from typing import Optional, List
+from typing import List, Optional
 from unittest import mock
 
-from jupyterlab_chat.models import NewMessage
 from jupyter_ai.chat_handlers import DefaultChatHandler, learn
 from jupyter_ai.config_manager import ConfigManager
 from jupyter_ai.extension import DEFAULT_HELP_MESSAGE_TEMPLATE
-from jupyter_ai.models import (
-    Persona,
-)
 from jupyter_ai.history import YChatHistory
+from jupyter_ai.models import Persona
 from jupyter_ai_magics import BaseProvider
-from langchain_community.llms import FakeListLLM
-from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
+from jupyterlab_chat.models import NewMessage
 from jupyterlab_chat.ychat import YChat
+from langchain_community.llms import FakeListLLM
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 from pycrdt import Awareness, Doc
 
 
@@ -50,7 +48,7 @@ def __init__(self, lm_provider=None, lm_provider_params=None):
         self.ychat = YChat(ydoc=ydoc, awareness=awareness)
         self.ychat_history = YChatHistory(ychat=self.ychat)
 
-        # initialize & configure mock ConfigManager
+        # initialize & configure mock ConfigManager
         config_manager = mock.create_autospec(ConfigManager)
         config_manager.lm_provider = lm_provider or MockProvider
         config_manager.lm_provider_params = lm_provider_params or {"model_id": "model"}
@@ -78,8 +76,10 @@ def messages(self) -> List[BaseMessage]:
         the last message.
         """
 
-        return self.ychat_history._convert_to_langchain_messages(self.ychat.get_messages())
-
+        return self.ychat_history._convert_to_langchain_messages(
+            self.ychat.get_messages()
+        )
+
     async def send_human_message(self, body: str = "Hello!"):
         """
         Test helper method that sends a human message to this chat handler.
@@ -101,9 +101,11 @@ def is_writing(self) -> bool:
         """
         return self.ychat.awareness.get_local_state()["isWriting"]
 
+
 class TestException(Exception):
     pass
 
+
 def test_learn_index_permissions(tmp_path):
     test_dir = tmp_path / "test"
     with mock.patch.object(learn, "INDEX_SAVE_DIR", new=test_dir):
@@ -142,4 +144,3 @@ async def test_default_stops_writing_on_error():
     assert isinstance(handler.messages[0], HumanMessage)
     assert isinstance(handler.messages[1], AIMessage)
     assert not handler.is_writing
-