diff --git a/docs/source/developers/index.md b/docs/source/developers/index.md index c9062ad1d..aac923285 100644 --- a/docs/source/developers/index.md +++ b/docs/source/developers/index.md @@ -150,6 +150,84 @@ my-provider = "my_provider:MyEmbeddingsProvider" [Embeddings]: https://api.python.langchain.com/en/stable/embeddings/langchain_core.embeddings.Embeddings.html + +### Custom completion providers + +Any model provider derived from `BaseProvider` can be used as a completion provider. +However, some providers may benefit from customizing handling of completion requests. + +There are two asynchronous methods which can be overridden in subclasses of `BaseProvider`: +- `generate_inline_completions`: takes a request (`InlineCompletionRequest`) and returns `InlineCompletionReply` +- `stream_inline_completions`: takes a request and yields an initiating reply (`InlineCompletionReply`) with `isIncomplete` set to `True` followed by subsequent chunks (`InlineCompletionStreamChunk`) + +When streaming all replies and chunks for given invocation of the `stream_inline_completions()` method should include a constant and unique string token identifying the stream. All chunks except for the last chunk for a given item should have the `done` value set to `False`. + +The following example demonstrates a custom implementation of the completion provider with both a method for sending multiple completions in one go, and streaming multiple completions concurrently. +The implementation and explanation for the `merge_iterators` function used in this example can be found [here](https://stackoverflow.com/q/72445371/4877269). + +```python +class MyCompletionProvider(BaseProvider, FakeListLLM): + id = "my_provider" + name = "My Provider" + model_id_key = "model" + models = ["model_a"] + + def __init__(self, **kwargs): + kwargs["responses"] = ["This fake response will not be used for completion"] + super().__init__(**kwargs) + + async def generate_inline_completions(self, request: InlineCompletionRequest): + return InlineCompletionReply( + list=InlineCompletionList(items=[ + {"insertText": "An ant minding its own business"}, + {"insertText": "A bug searching for a snack"} + ]), + reply_to=request.number, + ) + + async def stream_inline_completions(self, request: InlineCompletionRequest): + token_1 = f"t{request.number}s0" + token_2 = f"t{request.number}s1" + + yield InlineCompletionReply( + list=InlineCompletionList( + items=[ + {"insertText": "An ", "isIncomplete": True, "token": token_1}, + {"insertText": "", "isIncomplete": True, "token": token_2} + ] + ), + reply_to=request.number, + ) + + # where merge_iterators + async for reply in merge_iterators([ + self._stream("elephant dancing in the rain", request.number, token_1, start_with="An"), + self._stream("A flock of birds flying around a mountain", request.number, token_2) + ]): + yield reply + + async def _stream(self, sentence, request_number, token, start_with = ""): + suggestion = start_with + + for fragment in sentence.split(): + await asyncio.sleep(0.75) + suggestion += " " + fragment + yield InlineCompletionStreamChunk( + type="stream", + response={"insertText": suggestion, "token": token}, + reply_to=request_number, + done=False + ) + + # finally, send a message confirming that we are done + yield InlineCompletionStreamChunk( + type="stream", + response={"insertText": suggestion, "token": token}, + reply_to=request_number, + done=True, + ) +``` + ## Prompt templates Each provider can define **prompt templates** for each supported format. A prompt diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py new file mode 100644 index 000000000..204da5e7b --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py @@ -0,0 +1,52 @@ +from typing import Dict + +from .models.completion import InlineCompletionRequest + + +def token_from_request(request: InlineCompletionRequest, suggestion: int): + """Generate a deterministic token (for matching streamed messages) + using request number and suggestion number""" + return f"t{request.number}s{suggestion}" + + +def template_inputs_from_request(request: InlineCompletionRequest) -> Dict: + suffix = request.suffix.strip() + filename = request.path.split("/")[-1] if request.path else "untitled" + + return { + "prefix": request.prefix, + "suffix": suffix, + "language": request.language, + "filename": filename, + "stop": ["\n```"], + } + + +def post_process_suggestion(suggestion: str, request: InlineCompletionRequest) -> str: + """Remove spurious fragments from the suggestion. + + While most models (especially instruct and infill models do not require + any pre-processing, some models such as gpt-4 which only have chat APIs + may require removing spurious fragments. This function uses heuristics + and request data to remove such fragments. + """ + # gpt-4 tends to add "```python" or similar + language = request.language or "python" + markdown_identifiers = {"ipython": ["ipython", "python", "py"]} + bad_openings = [ + f"```{identifier}" + for identifier in markdown_identifiers.get(language, [language]) + ] + ["```"] + for opening in bad_openings: + if suggestion.startswith(opening): + suggestion = suggestion[len(opening) :].lstrip() + # check for the prefix inclusion (only if there was a bad opening) + if suggestion.startswith(request.prefix): + suggestion = suggestion[len(request.prefix) :] + break + + # check if the suggestion ends with a closing markdown identifier and remove it + if suggestion.rstrip().endswith("```"): + suggestion = suggestion.rstrip()[:-3].rstrip() + + return suggestion diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py new file mode 100644 index 000000000..147f6ceec --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py @@ -0,0 +1,81 @@ +from typing import List, Literal, Optional + +from langchain.pydantic_v1 import BaseModel + + +class InlineCompletionRequest(BaseModel): + """Message send by client to request inline completions. + + Prefix/suffix implementation is used to avoid the need for synchronising + the notebook state at every key press (subject to change in future).""" + + # unique message ID generated by the client used to identify replies and + # to easily discard replies for older requests + number: int + # prefix should include full text of the current cell preceding the cursor + prefix: str + # suffix should include full text of the current cell preceding the cursor + suffix: str + # media type for the current language, e.g. `text/x-python` + mime: str + # whether to stream the response (if supported by the model) + stream: bool + # path to the notebook of file for which the completions are generated + path: Optional[str] + # language inferred from the document mime type (if possible) + language: Optional[str] + # identifier of the cell for which the completions are generated if in a notebook + # previous cells and following cells can be used to learn the wider context + cell_id: Optional[str] + + +class InlineCompletionItem(BaseModel): + """The inline completion suggestion to be displayed on the frontend. + + See JupyterLab `InlineCompletionItem` documentation for the details. + """ + + insertText: str + filterText: Optional[str] + isIncomplete: Optional[bool] + token: Optional[str] + + +class CompletionError(BaseModel): + type: str + traceback: str + + +class InlineCompletionList(BaseModel): + """Reflection of JupyterLab's `IInlineCompletionList`.""" + + items: List[InlineCompletionItem] + + +class InlineCompletionReply(BaseModel): + """Message sent from model to client with the infill suggestions""" + + list: InlineCompletionList + # number of request for which we are replying + reply_to: int + error: Optional[CompletionError] + + +class InlineCompletionStreamChunk(BaseModel): + """Message sent from model to client with the infill suggestions""" + + type: Literal["stream"] = "stream" + response: InlineCompletionItem + reply_to: int + done: bool + error: Optional[CompletionError] + + +__all__ = [ + "InlineCompletionRequest", + "InlineCompletionItem", + "CompletionError", + "InlineCompletionList", + "InlineCompletionReply", + "InlineCompletionStreamChunk", +] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py index 382a480e1..a1347073f 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py @@ -75,23 +75,17 @@ class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI): id = "azure-chat-openai" name = "Azure OpenAI" models = ["*"] - model_id_key = "deployment_name" + model_id_key = "azure_deployment" model_id_label = "Deployment name" pypi_package_deps = ["langchain_openai"] + # Confusingly, langchain uses both OPENAI_API_KEY and AZURE_OPENAI_API_KEY for azure + # https://github.com/langchain-ai/langchain/blob/f2579096993ae460516a0aae1d3e09f3eb5c1772/libs/partners/openai/langchain_openai/llms/azure.py#L85 auth_strategy = EnvAuthStrategy(name="AZURE_OPENAI_API_KEY") registry = True fields = [ - TextField( - key="openai_api_base", label="Base API URL (required)", format="text" - ), - TextField( - key="openai_api_version", label="API version (required)", format="text" - ), - TextField( - key="openai_organization", label="Organization (optional)", format="text" - ), - TextField(key="openai_proxy", label="Proxy (optional)", format="text"), + TextField(key="azure_endpoint", label="Base API URL (required)", format="text"), + TextField(key="api_version", label="API version (required)", format="text"), ] diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 091b78fea..3d27a4861 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -5,7 +5,17 @@ import io import json from concurrent.futures import ThreadPoolExecutor -from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union +from typing import ( + Any, + AsyncIterator, + ClassVar, + Coroutine, + Dict, + List, + Literal, + Optional, + Union, +) from jsonpath_ng import parse from langchain.chat_models.base import BaseChatModel @@ -20,6 +30,8 @@ ) from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.schema import LLMResult +from langchain.schema.output_parser import StrOutputParser +from langchain.schema.runnable import Runnable from langchain.utils import get_from_dict_or_env from langchain_community.chat_models import ( BedrockChat, @@ -46,6 +58,13 @@ except: from pydantic.main import ModelMetaclass +from . import completion_utils as completion +from .models.completion import ( + InlineCompletionList, + InlineCompletionReply, + InlineCompletionRequest, + InlineCompletionStreamChunk, +) from .models.persona import Persona CHAT_SYSTEM_PROMPT = """ @@ -405,6 +424,71 @@ def is_chat_provider(self): def allows_concurrency(self): return True + async def generate_inline_completions( + self, request: InlineCompletionRequest + ) -> InlineCompletionReply: + chain = self._create_completion_chain() + model_arguments = completion.template_inputs_from_request(request) + suggestion = await chain.ainvoke(input=model_arguments) + suggestion = completion.post_process_suggestion(suggestion, request) + return InlineCompletionReply( + list=InlineCompletionList(items=[{"insertText": suggestion}]), + reply_to=request.number, + ) + + async def stream_inline_completions( + self, request: InlineCompletionRequest + ) -> AsyncIterator[InlineCompletionStreamChunk]: + chain = self._create_completion_chain() + token = completion.token_from_request(request, 0) + model_arguments = completion.template_inputs_from_request(request) + suggestion = "" + + # send an incomplete `InlineCompletionReply`, indicating to the + # client that LLM output is about to streamed across this connection. + yield InlineCompletionReply( + list=InlineCompletionList( + items=[ + { + # insert text starts empty as we do not pre-generate any part + "insertText": "", + "isIncomplete": True, + "token": token, + } + ] + ), + reply_to=request.number, + ) + + async for fragment in chain.astream(input=model_arguments): + suggestion += fragment + if suggestion.startswith("```"): + if "\n" not in suggestion: + # we are not ready to apply post-processing + continue + else: + suggestion = completion.post_process_suggestion(suggestion, request) + elif suggestion.rstrip().endswith("```"): + suggestion = completion.post_process_suggestion(suggestion, request) + yield InlineCompletionStreamChunk( + type="stream", + response={"insertText": suggestion, "token": token}, + reply_to=request.number, + done=False, + ) + + # finally, send a message confirming that we are done + yield InlineCompletionStreamChunk( + type="stream", + response={"insertText": suggestion, "token": token}, + reply_to=request.number, + done=True, + ) + + def _create_completion_chain(self) -> Runnable: + prompt_template = self.get_completion_prompt_template() + return prompt_template | self | StrOutputParser() + class AI21Provider(BaseProvider, AI21): id = "ai21" diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py index e46038da5..383076c52 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/help.py @@ -62,4 +62,14 @@ def __init__(self, *args, chat_handlers: Dict[str, BaseChatHandler], **kwargs): self._chat_handlers = chat_handlers async def process_message(self, message: HumanChatMessage): - self.reply(_format_help_message(self._chat_handlers), message) + persona = self.config_manager.persona + lm_provider = self.config_manager.lm_provider + unsupported_slash_commands = ( + lm_provider.unsupported_slash_commands if lm_provider else set() + ) + self.reply( + _format_help_message( + self._chat_handlers, persona, unsupported_slash_commands + ), + message, + ) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index e1a22c9cc..38390a44c 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -67,17 +67,23 @@ def __init__(self, *args, **kwargs): def _load(self): """Loads the vector store.""" - embeddings = self.get_embedding_model() - if not embeddings: + if self.index is not None: return - if self.index is None: - try: - self.index = FAISS.load_local( - INDEX_SAVE_DIR, embeddings, index_name=self.index_name - ) - self.load_metadata() - except Exception as e: - self.log.error("Could not load vector index from disk.") + + try: + embeddings = self.get_embedding_model() + if not embeddings: + return + + self.index = FAISS.load_local( + INDEX_SAVE_DIR, embeddings, index_name=self.index_name + ) + self.load_metadata() + except Exception as e: + self.log.error( + "Could not load vector index from disk. Full exception details printed below." + ) + self.log.error(e) async def process_message(self, message: HumanChatMessage): # If no embedding provider has been selected @@ -118,13 +124,16 @@ async def process_message(self, message: HumanChatMessage): if args.verbose: self.reply(f"Loading and splitting files for {load_path}", message) - await self.learn_dir( - load_path, args.chunk_size, args.chunk_overlap, args.all_files - ) - self.save() - - response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. - You can ask questions about these docs by prefixing your message with **/ask**.""" + try: + await self.learn_dir( + load_path, args.chunk_size, args.chunk_overlap, args.all_files + ) + except Exception as e: + response = f"""Learn documents in **{load_path}** failed. {str(e)}.""" + else: + self.save() + response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them. + You can ask questions about these docs by prefixing your message with **/ask**.""" self.reply(response, message) def _build_list_response(self): @@ -155,7 +164,6 @@ async def learn_dir( delayed = split(path, all_files, splitter=splitter) doc_chunks = await dask_client.compute(delayed) - em_provider_cls, em_provider_args = self.get_embedding_provider() delayed = get_embeddings(doc_chunks, em_provider_cls, em_provider_args) embedding_records = await dask_client.compute(delayed) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py index c52c308db..9eb4f845a 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/base.py @@ -2,7 +2,7 @@ import time import traceback from asyncio import AbstractEventLoop -from typing import Any, AsyncIterator, Dict, Union +from typing import Union import tornado from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin @@ -14,7 +14,7 @@ InlineCompletionStreamChunk, ) from jupyter_server.base.handlers import JupyterHandler -from langchain.pydantic_v1 import BaseModel, ValidationError +from langchain.pydantic_v1 import ValidationError class BaseInlineCompletionHandler( @@ -27,12 +27,10 @@ class BaseInlineCompletionHandler( ## # Interface for subclasses ## - async def handle_request( - self, message: InlineCompletionRequest - ) -> InlineCompletionReply: + async def handle_request(self, message: InlineCompletionRequest) -> None: """ Handles an inline completion request, without streaming. Subclasses - must define this method and write a reply via `self.write_message()`. + must define this method and write a reply via `self.reply()`. The method definition does not need to be wrapped in a try/except block. """ @@ -40,14 +38,11 @@ async def handle_request( "The required method `self.handle_request()` is not defined by this subclass." ) - async def handle_stream_request( - self, message: InlineCompletionRequest - ) -> AsyncIterator[InlineCompletionStreamChunk]: + async def handle_stream_request(self, message: InlineCompletionRequest) -> None: """ Handles an inline completion request, **with streaming**. Implementations may optionally define this method. Implementations that - do so should stream replies via successive calls to - `self.write_message()`. + do so should stream replies via successive calls to `self.reply()`. The method definition does not need to be wrapped in a try/except block. """ @@ -64,14 +59,9 @@ async def handle_stream_request( def loop(self) -> AbstractEventLoop: return self.settings["jai_event_loop"] - def write_message(self, message: Union[bytes, str, Dict[str, Any], BaseModel]): - """ - Write a bytes, string, dict, or Pydantic model object to the WebSocket - connection. The base definition of this method is provided by Tornado. - """ - if isinstance(message, BaseModel): - message = message.dict() - + def reply(self, reply: Union[InlineCompletionReply, InlineCompletionStreamChunk]): + """Write a reply object to the WebSocket connection.""" + message = reply.dict() super().write_message(message) def initialize(self): @@ -144,7 +134,7 @@ async def handle_exc(self, e: Exception, request: InlineCompletionRequest): title=e.args[0] if e.args else "Exception", traceback=traceback.format_exc(), ) - self.write_message( + self.reply( InlineCompletionReply( list=InlineCompletionList(items=[]), error=error, diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index eb03df156..38676b998 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -1,154 +1,24 @@ -from typing import Dict, Type - -from jupyter_ai_magics.providers import BaseProvider -from langchain.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - PromptTemplate, - SystemMessagePromptTemplate, -) -from langchain.schema.output_parser import StrOutputParser -from langchain.schema.runnable import Runnable - -from ..models import ( - InlineCompletionList, - InlineCompletionReply, - InlineCompletionRequest, - InlineCompletionStreamChunk, -) +from ..models import InlineCompletionRequest from .base import BaseInlineCompletionHandler class DefaultInlineCompletionHandler(BaseInlineCompletionHandler): - llm_chain: Runnable - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def create_llm_chain( - self, provider: Type[BaseProvider], provider_params: Dict[str, str] - ): - unified_parameters = { - **provider_params, - **(self.get_model_parameters(provider, provider_params)), - } - llm = provider(**unified_parameters) - - prompt_template = llm.get_completion_prompt_template() - - self.llm = llm - self.llm_chain = prompt_template | llm | StrOutputParser() - - async def handle_request(self, request: InlineCompletionRequest) -> None: + async def handle_request(self, request: InlineCompletionRequest): """Handles an inline completion request without streaming.""" - self.get_llm_chain() - model_arguments = self._template_inputs_from_request(request) - suggestion = await self.llm_chain.ainvoke(input=model_arguments) - suggestion = self._post_process_suggestion(suggestion, request) - self.write_message( - InlineCompletionReply( - list=InlineCompletionList(items=[{"insertText": suggestion}]), - reply_to=request.number, - ) - ) - - def _write_incomplete_reply(self, request: InlineCompletionRequest): - """Writes an incomplete `InlineCompletionReply`, indicating to the - client that LLM output is about to streamed across this connection. - Should be called first in `self.handle_stream_request()`.""" + llm = self.get_llm() + if not llm: + raise ValueError("Please select a model for inline completion.") - token = self._token_from_request(request, 0) - reply = InlineCompletionReply( - list=InlineCompletionList( - items=[ - { - # insert text starts empty as we do not pre-generate any part - "insertText": "", - "isIncomplete": True, - "token": token, - } - ] - ), - reply_to=request.number, - ) - self.write_message(reply) + reply = await llm.generate_inline_completions(request) + self.reply(reply) async def handle_stream_request(self, request: InlineCompletionRequest): - # first, send empty initial reply. - self._write_incomplete_reply(request) - - # then, generate and stream LLM output over this connection. - self.get_llm_chain() - token = self._token_from_request(request, 0) - model_arguments = self._template_inputs_from_request(request) - suggestion = "" - - async for fragment in self.llm_chain.astream(input=model_arguments): - suggestion += fragment - if suggestion.startswith("```"): - if "\n" not in suggestion: - # we are not ready to apply post-processing - continue - else: - suggestion = self._post_process_suggestion(suggestion, request) - self.write_message( - InlineCompletionStreamChunk( - type="stream", - response={"insertText": suggestion, "token": token}, - reply_to=request.number, - done=False, - ) - ) - - # finally, send a message confirming that we are done - self.write_message( - InlineCompletionStreamChunk( - type="stream", - response={"insertText": suggestion, "token": token}, - reply_to=request.number, - done=True, - ) - ) - - def _token_from_request(self, request: InlineCompletionRequest, suggestion: int): - """Generate a deterministic token (for matching streamed messages) - using request number and suggestion number""" - return f"t{request.number}s{suggestion}" - - def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict: - suffix = request.suffix.strip() - filename = request.path.split("/")[-1] if request.path else "untitled" - - return { - "prefix": request.prefix, - "suffix": suffix, - "language": request.language, - "filename": filename, - "stop": ["\n```"], - } - - def _post_process_suggestion( - self, suggestion: str, request: InlineCompletionRequest - ) -> str: - """Remove spurious fragments from the suggestion. + llm = self.get_llm() + if not llm: + raise ValueError("Please select a model for inline completion.") - While most models (especially instruct and infill models do not require - any pre-processing, some models such as gpt-4 which only have chat APIs - may require removing spurious fragments. This function uses heuristics - and request data to remove such fragments. - """ - # gpt-4 tends to add "```python" or similar - language = request.language or "python" - markdown_identifiers = {"ipython": ["ipython", "python", "py"]} - bad_openings = [ - f"```{identifier}" - for identifier in markdown_identifiers.get(language, [language]) - ] + ["```"] - for opening in bad_openings: - if suggestion.startswith(opening): - suggestion = suggestion[len(opening) :].lstrip() - # check for the prefix inclusion (only if there was a bad opening) - if suggestion.startswith(request.prefix): - suggestion = suggestion[len(request.prefix) :] - break - return suggestion + async for reply in llm.stream_inline_completions(request): + self.reply(reply) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py index fa16920d2..573caffdc 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Type +from logging import Logger +from typing import Any, Dict, Optional, Type from jupyter_ai.config_manager import ConfigManager from jupyter_ai_magics.providers import BaseProvider @@ -7,26 +8,24 @@ class LLMHandlerMixin: """Base class containing shared methods and attributes used by LLM handler classes.""" - # This could be used to derive `BaseChatHandler` too (there is a lot of duplication!), - # but it was decided against it to avoid introducing conflicts for backports against 1.x - handler_kind: str + settings: dict + log: Logger @property - def config_manager(self) -> ConfigManager: + def jai_config_manager(self) -> ConfigManager: return self.settings["jai_config_manager"] @property def model_parameters(self) -> Dict[str, Dict[str, Any]]: return self.settings["model_parameters"] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.llm = None - self.llm_params = None - self.llm_chain = None + self._llm: Optional[BaseProvider] = None + self._llm_params = None - def get_llm_chain(self): + def get_llm_chain(self) -> Optional[BaseProvider]: lm_provider = self.config_manager.completions_lm_provider lm_provider_params = self.config_manager.completions_lm_provider_params @@ -34,7 +33,7 @@ def get_llm_chain(self): return None curr_lm_id = ( - f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None + f'{self._llm.id}:{lm_provider_params["model_id"]}' if self._llm else None ) next_lm_id = ( f'{lm_provider.id}:{lm_provider_params["model_id"]}' @@ -42,19 +41,23 @@ def get_llm_chain(self): else None ) + should_recreate_llm = False if curr_lm_id != next_lm_id: self.log.info( f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}." ) - self.create_llm_chain(lm_provider, lm_provider_params) - elif self.llm_params != lm_provider_params: + should_recreate_llm = True + elif self._llm_params != lm_provider_params: self.log.info( f"{self.handler_kind} model params changed, updating the llm chain." ) - self.create_llm_chain(lm_provider, lm_provider_params) + should_recreate_llm = True + + if should_recreate_llm: + self._llm = self.create_llm(lm_provider, lm_provider_params) + self._llm_params = lm_provider_params - self.llm_params = lm_provider_params - return self.llm_chain + return self._llm def get_model_parameters( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -63,7 +66,13 @@ def get_model_parameters( f"{provider.id}:{provider_params['model_id']}", {} ) - def create_llm_chain( + def create_llm( self, provider: Type[BaseProvider], provider_params: Dict[str, str] - ): - raise NotImplementedError("Should be implemented by subclasses") + ) -> BaseProvider: + unified_parameters = { + **provider_params, + **(self.get_model_parameters(provider, provider_params)), + } + llm = provider(**unified_parameters) + + return llm diff --git a/packages/jupyter-ai/jupyter_ai/completions/models.py b/packages/jupyter-ai/jupyter_ai/completions/models.py index 507365408..e9679379e 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/models.py +++ b/packages/jupyter-ai/jupyter_ai/completions/models.py @@ -1,71 +1,17 @@ -from typing import List, Literal, Optional - -from langchain.pydantic_v1 import BaseModel - - -class InlineCompletionRequest(BaseModel): - """Message send by client to request inline completions. - - Prefix/suffix implementation is used to avoid the need for synchronising - the notebook state at every key press (subject to change in future).""" - - # unique message ID generated by the client used to identify replies and - # to easily discard replies for older requests - number: int - # prefix should include full text of the current cell preceding the cursor - prefix: str - # suffix should include full text of the current cell preceding the cursor - suffix: str - # media type for the current language, e.g. `text/x-python` - mime: str - # whether to stream the response (if supported by the model) - stream: bool - # path to the notebook of file for which the completions are generated - path: Optional[str] - # language inferred from the document mime type (if possible) - language: Optional[str] - # identifier of the cell for which the completions are generated if in a notebook - # previous cells and following cells can be used to learn the wider context - cell_id: Optional[str] - - -class InlineCompletionItem(BaseModel): - """The inline completion suggestion to be displayed on the frontend. - - See JuptyerLab `InlineCompletionItem` documentation for the details. - """ - - insertText: str - filterText: Optional[str] - isIncomplete: Optional[bool] - token: Optional[str] - - -class CompletionError(BaseModel): - type: str - traceback: str - - -class InlineCompletionList(BaseModel): - """Reflection of JupyterLab's `IInlineCompletionList`.""" - - items: List[InlineCompletionItem] - - -class InlineCompletionReply(BaseModel): - """Message sent from model to client with the infill suggestions""" - - list: InlineCompletionList - # number of request for which we are replying - reply_to: int - error: Optional[CompletionError] - - -class InlineCompletionStreamChunk(BaseModel): - """Message sent from model to client with the infill suggestions""" - - type: Literal["stream"] = "stream" - response: InlineCompletionItem - reply_to: int - done: bool - error: Optional[CompletionError] +from jupyter_ai_magics.models.completion import ( + CompletionError, + InlineCompletionItem, + InlineCompletionList, + InlineCompletionReply, + InlineCompletionRequest, + InlineCompletionStreamChunk, +) + +__all__ = [ + "InlineCompletionRequest", + "InlineCompletionItem", + "CompletionError", + "InlineCompletionList", + "InlineCompletionReply", + "InlineCompletionStreamChunk", +] diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py index 6607d97d6..561f00a1c 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py +++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py @@ -8,13 +8,12 @@ from langchain.document_loaders import PyPDFLoader from langchain.schema import Document from langchain.text_splitter import TextSplitter -from pypdf import PdfReader # Uses pypdf which is used by PyPDFLoader from langchain def pdf_to_text(path): - reader = PdfReader(path) - text = "\n \n".join([page.extract_text() for page in reader.pages]) + pages = PyPDFLoader(path) + text = "\n \n".join([page.page_content for page in pages.load_and_split()]) return text diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 0028356a3..c5b5d1eea 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -1,8 +1,14 @@ import json from types import SimpleNamespace +from typing import Union +import pytest from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler -from jupyter_ai.completions.models import InlineCompletionRequest +from jupyter_ai.completions.models import ( + InlineCompletionReply, + InlineCompletionRequest, + InlineCompletionStreamChunk, +) from jupyter_ai_magics import BaseProvider from langchain_community.llms import FakeListLLM from pytest import fixture @@ -17,28 +23,31 @@ class MockProvider(BaseProvider, FakeListLLM): models = ["model"] def __init__(self, **kwargs): - kwargs["responses"] = ["Test response"] + if "responses" not in kwargs: + kwargs["responses"] = ["Test response"] super().__init__(**kwargs) class MockCompletionHandler(DefaultInlineCompletionHandler): - def __init__(self): + def __init__(self, lm_provider=None, lm_provider_params=None): self.request = HTTPServerRequest() self.application = Application() self.messages = [] self.tasks = [] self.settings["jai_config_manager"] = SimpleNamespace( - completions_lm_provider=MockProvider, - completions_lm_provider_params={"model_id": "model"}, + completions_lm_provider=lm_provider or MockProvider, + completions_lm_provider_params=lm_provider_params or {"model_id": "model"}, ) self.settings["jai_event_loop"] = SimpleNamespace( create_task=lambda x: self.tasks.append(x) ) self.settings["model_parameters"] = {} - self.llm_params = {} - self.create_llm_chain(MockProvider, {"model_id": "model"}) + self._llm_params = {} + self._llm = None - def write_message(self, message: str) -> None: # type: ignore + def reply( + self, message: Union[InlineCompletionReply, InlineCompletionStreamChunk] + ) -> None: self.messages.append(message) async def handle_exc(self, e: Exception, _request: InlineCompletionRequest): @@ -89,8 +98,44 @@ async def test_handle_request(inline_handler): assert suggestions[0].insertText == "Test response" -async def test_handle_stream_request(inline_handler): - inline_handler.llm_chain = FakeListLLM(responses=["test"]) +@pytest.mark.parametrize( + "response,expected_suggestion", + [ + ("```python\nTest python code\n```", "Test python code"), + ("```\ntest\n```\n \n", "test"), + ("```hello```world```", "hello```world"), + ], +) +async def test_handle_request_with_spurious_fragments(response, expected_suggestion): + inline_handler = MockCompletionHandler( + lm_provider=MockProvider, + lm_provider_params={ + "model_id": "model", + "responses": [response], + }, + ) + dummy_request = InlineCompletionRequest( + number=1, prefix="", suffix="", mime="", stream=False + ) + + await inline_handler.handle_request(dummy_request) + # should write a single reply + assert len(inline_handler.messages) == 1 + # reply should contain a single suggestion + suggestions = inline_handler.messages[0].list.items + assert len(suggestions) == 1 + # the suggestion should include insert text from LLM without spurious fragments + assert suggestions[0].insertText == expected_suggestion + + +async def test_handle_stream_request(): + inline_handler = MockCompletionHandler( + lm_provider=MockProvider, + lm_provider_params={ + "model_id": "model", + "responses": ["test"], + }, + ) dummy_request = InlineCompletionRequest( number=1, prefix="", suffix="", mime="", stream=True ) @@ -102,16 +147,16 @@ async def test_handle_stream_request(inline_handler): # first reply should be empty to start the stream first = inline_handler.messages[0].list.items[0] assert first.insertText == "" - assert first.isIncomplete == True + assert first.isIncomplete is True # second reply should be a chunk containing the token second = inline_handler.messages[1] assert second.type == "stream" - assert second.response.insertText == "Test response" - assert second.done == False + assert second.response.insertText == "test" + assert second.done is False # third reply should be a closing chunk third = inline_handler.messages[2] assert third.type == "stream" - assert third.response.insertText == "Test response" - assert third.done == True + assert third.response.insertText == "test" + assert third.done is True diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 763963c9f..f5eb5e98f 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "jupyterlab~=4.0", "aiosqlite>=0.18", "importlib_metadata>=5.2.0", - "jupyter_ai_magics", + "jupyter_ai_magics>=2.13.0", "dask[distributed]", "faiss-cpu", # Not distributed by official repo "typing_extensions>=4.5.0", @@ -54,7 +54,7 @@ test = [ dev = ["jupyter_ai_magics[dev]"] -all = ["jupyter_ai_magics[all]"] +all = ["jupyter_ai_magics[all]", "pypdf"] [tool.hatch.version] source = "nodejs"