-
-
Notifications
You must be signed in to change notification settings - Fork 340
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Inline code completions (#465) * Draft inline completions implementation (server side) * Implement inline completion provider (front) * Add default debounce delay and error handling (front) * Add `gpt-3.5-turbo-instruct` because text- models are deprecated. OpenAI specifically recommends using `gpt-3.5-turbo-instruct` in favour of text-davinci, text-ada, etc. See: https://platform.openai.com/docs/deprecations/ * Improve/fix prompt template and add simple post-processing * Handle missing `registerInlineProvider`, handle no model in name * Remove IPython mention to avoid confusing languages * Disable suggestions in markdown, move language logic * Remove unused background and clip path from jupyternaut * Implement toggling the AI completer via statusbar item also adds the icon for provider re-using jupyternaut icon * Implement streaming support * Translate ipython to python for models, remove log * Move `BaseLLMHandler` to `/completions` rename to `LLMHandlerMixin` * Move frontend completions code to `/completions` * Make `IStatusBar` required for now, lint * Simplify inline completion backend (#553) * do not import from pydantic directly * refactor inline completion backend * Autocomplete frontend fixes (#583) * remove duplicate definition of inline completion provider * rename completion variables, plugins, token to be more accurate * abbreviate JupyterAIInlineProvider => JaiInlineProvider * bump @jupyterlab/completer and typescript * WIP: fix Jupyter AI completion settings * Fix issues with settings population * read from settings directly instead of using a cache * disable Jupyter AI completion by default * improve completion plugin menu items * revert unnecessary edits to package manifest * Update packages/jupyter-ai/src/components/statusbar-item.tsx Co-authored-by: Michał Krassowski <[email protected]> * tweak wording --------- Co-authored-by: krassowski <[email protected]> --------- Co-authored-by: David L. Qiu <[email protected]>
- Loading branch information
1 parent
f56496c
commit d1a9f40
Showing
23 changed files
with
2,460 additions
and
1,373 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
4 changes: 4 additions & 0 deletions
4
packages/jupyter-ai/jupyter_ai/completions/handlers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .base import BaseInlineCompletionHandler | ||
from .default import DefaultInlineCompletionHandler | ||
|
||
__all__ = ["BaseInlineCompletionHandler", "DefaultInlineCompletionHandler"] |
169 changes: 169 additions & 0 deletions
169
packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
import json | ||
import time | ||
import traceback | ||
from asyncio import AbstractEventLoop | ||
from typing import Any, AsyncIterator, Dict, Union | ||
|
||
import tornado | ||
from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin | ||
from jupyter_ai.completions.models import ( | ||
CompletionError, | ||
InlineCompletionList, | ||
InlineCompletionReply, | ||
InlineCompletionRequest, | ||
InlineCompletionStreamChunk, | ||
) | ||
from jupyter_server.base.handlers import JupyterHandler | ||
from langchain.pydantic_v1 import BaseModel, ValidationError | ||
|
||
|
||
class BaseInlineCompletionHandler( | ||
LLMHandlerMixin, JupyterHandler, tornado.websocket.WebSocketHandler | ||
): | ||
"""A Tornado WebSocket handler that receives inline completion requests and | ||
fulfills them accordingly. This class is instantiated once per WebSocket | ||
connection.""" | ||
|
||
## | ||
# Interface for subclasses | ||
## | ||
async def handle_request( | ||
self, message: InlineCompletionRequest | ||
) -> InlineCompletionReply: | ||
""" | ||
Handles an inline completion request, without streaming. Subclasses | ||
must define this method and write a reply via `self.write_message()`. | ||
The method definition does not need to be wrapped in a try/except block. | ||
""" | ||
raise NotImplementedError( | ||
"The required method `self.handle_request()` is not defined by this subclass." | ||
) | ||
|
||
async def handle_stream_request( | ||
self, message: InlineCompletionRequest | ||
) -> AsyncIterator[InlineCompletionStreamChunk]: | ||
""" | ||
Handles an inline completion request, **with streaming**. | ||
Implementations may optionally define this method. Implementations that | ||
do so should stream replies via successive calls to | ||
`self.write_message()`. | ||
The method definition does not need to be wrapped in a try/except block. | ||
""" | ||
raise NotImplementedError( | ||
"The optional method `self.handle_stream_request()` is not defined by this subclass." | ||
) | ||
|
||
## | ||
# Definition of base class | ||
## | ||
handler_kind = "completion" | ||
|
||
@property | ||
def loop(self) -> AbstractEventLoop: | ||
return self.settings["jai_event_loop"] | ||
|
||
def write_message(self, message: Union[bytes, str, Dict[str, Any], BaseModel]): | ||
""" | ||
Write a bytes, string, dict, or Pydantic model object to the WebSocket | ||
connection. The base definition of this method is provided by Tornado. | ||
""" | ||
if isinstance(message, BaseModel): | ||
message = message.dict() | ||
|
||
super().write_message(message) | ||
|
||
def initialize(self): | ||
self.log.debug("Initializing websocket connection %s", self.request.path) | ||
|
||
def pre_get(self): | ||
"""Handles authentication/authorization.""" | ||
# authenticate the request before opening the websocket | ||
user = self.current_user | ||
if user is None: | ||
self.log.warning("Couldn't authenticate WebSocket connection") | ||
raise tornado.web.HTTPError(403) | ||
|
||
# authorize the user. | ||
if not self.authorizer.is_authorized(self, user, "execute", "events"): | ||
raise tornado.web.HTTPError(403) | ||
|
||
async def get(self, *args, **kwargs): | ||
"""Get an event socket.""" | ||
self.pre_get() | ||
res = super().get(*args, **kwargs) | ||
await res | ||
|
||
async def on_message(self, message): | ||
"""Public Tornado method that is called when the client sends a message | ||
over this connection. This should **not** be overriden by subclasses.""" | ||
|
||
# first, verify that the message is an `InlineCompletionRequest`. | ||
self.log.debug("Message received: %s", message) | ||
try: | ||
message = json.loads(message) | ||
request = InlineCompletionRequest(**message) | ||
except ValidationError as e: | ||
self.log.error(e) | ||
return | ||
|
||
# next, dispatch the request to the correct handler and create the | ||
# `handle_request` coroutine object | ||
handle_request = None | ||
if request.stream: | ||
try: | ||
handle_request = self._handle_stream_request(request) | ||
except NotImplementedError: | ||
self.log.error( | ||
"Unable to handle stream request. The current `InlineCompletionHandler` does not implement the `handle_stream_request()` method." | ||
) | ||
return | ||
|
||
else: | ||
handle_request = self._handle_request(request) | ||
|
||
# finally, wrap `handle_request` in an exception handler, and start the | ||
# task on the event loop. | ||
async def handle_request_and_catch(): | ||
try: | ||
await handle_request | ||
except Exception as e: | ||
await self.handle_exc(e, request) | ||
|
||
self.loop.create_task(handle_request_and_catch()) | ||
|
||
async def handle_exc(self, e: Exception, request: InlineCompletionRequest): | ||
""" | ||
Handles an exception raised in either `handle_request()` or | ||
`handle_stream_request()`. This base class provides a default | ||
implementation, which may be overriden by subclasses. | ||
""" | ||
error = CompletionError( | ||
type=e.__class__.__name__, | ||
title=e.args[0] if e.args else "Exception", | ||
traceback=traceback.format_exc(), | ||
) | ||
self.write_message( | ||
InlineCompletionReply( | ||
list=InlineCompletionList(items=[]), | ||
error=error, | ||
reply_to=request.number, | ||
) | ||
) | ||
|
||
async def _handle_request(self, request: InlineCompletionRequest): | ||
"""Private wrapper around `self.handle_request()`.""" | ||
start = time.time() | ||
await self.handle_request(request) | ||
latency_ms = round((time.time() - start) * 1000) | ||
self.log.info(f"Inline completion handler resolved in {latency_ms} ms.") | ||
|
||
async def _handle_stream_request(self, request: InlineCompletionRequest): | ||
"""Private wrapper around `self.handle_stream_request()`.""" | ||
start = time.time() | ||
await self._handle_stream_request(request) | ||
async for chunk in self.stream(request): | ||
self.write_message(chunk.dict()) | ||
latency_ms = round((time.time() - start) * 1000) | ||
self.log.info(f"Inline completion streaming completed in {latency_ms} ms.") |
192 changes: 192 additions & 0 deletions
192
packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
from typing import Dict, Type | ||
|
||
from jupyter_ai_magics.providers import BaseProvider | ||
from langchain.prompts import ( | ||
ChatPromptTemplate, | ||
HumanMessagePromptTemplate, | ||
PromptTemplate, | ||
SystemMessagePromptTemplate, | ||
) | ||
from langchain.schema.output_parser import StrOutputParser | ||
from langchain.schema.runnable import Runnable | ||
|
||
from ..models import ( | ||
InlineCompletionList, | ||
InlineCompletionReply, | ||
InlineCompletionRequest, | ||
InlineCompletionStreamChunk, | ||
) | ||
from .base import BaseInlineCompletionHandler | ||
|
||
SYSTEM_PROMPT = """ | ||
You are an application built to provide helpful code completion suggestions. | ||
You should only produce code. Keep comments to minimum, use the | ||
programming language comment syntax. Produce clean code. | ||
The code is written in JupyterLab, a data analysis and code development | ||
environment which can execute code extended with additional syntax for | ||
interactive features, such as magics. | ||
""".strip() | ||
|
||
AFTER_TEMPLATE = """ | ||
The code after the completion request is: | ||
``` | ||
{suffix} | ||
``` | ||
""".strip() | ||
|
||
DEFAULT_TEMPLATE = """ | ||
The document is called `{filename}` and written in {language}. | ||
{after} | ||
Complete the following code: | ||
``` | ||
{prefix}""" | ||
|
||
|
||
class DefaultInlineCompletionHandler(BaseInlineCompletionHandler): | ||
llm_chain: Runnable | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
def create_llm_chain( | ||
self, provider: Type[BaseProvider], provider_params: Dict[str, str] | ||
): | ||
model_parameters = self.get_model_parameters(provider, provider_params) | ||
llm = provider(**provider_params, **model_parameters) | ||
|
||
if llm.is_chat_provider: | ||
prompt_template = ChatPromptTemplate.from_messages( | ||
[ | ||
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT), | ||
HumanMessagePromptTemplate.from_template(DEFAULT_TEMPLATE), | ||
] | ||
) | ||
else: | ||
prompt_template = PromptTemplate( | ||
input_variables=["prefix", "suffix", "language", "filename"], | ||
template=SYSTEM_PROMPT + "\n\n" + DEFAULT_TEMPLATE, | ||
) | ||
|
||
self.llm = llm | ||
self.llm_chain = prompt_template | llm | StrOutputParser() | ||
|
||
async def handle_request( | ||
self, request: InlineCompletionRequest | ||
) -> InlineCompletionReply: | ||
"""Handles an inline completion request without streaming.""" | ||
self.get_llm_chain() | ||
model_arguments = self._template_inputs_from_request(request) | ||
suggestion = await self.llm_chain.ainvoke(input=model_arguments) | ||
suggestion = self._post_process_suggestion(suggestion, request) | ||
self.write_message( | ||
InlineCompletionReply( | ||
list=InlineCompletionList(items=[{"insertText": suggestion}]), | ||
reply_to=request.number, | ||
) | ||
) | ||
|
||
def _write_incomplete_reply(self, request: InlineCompletionRequest): | ||
"""Writes an incomplete `InlineCompletionReply`, indicating to the | ||
client that LLM output is about to streamed across this connection. | ||
Should be called first in `self.handle_stream_request()`.""" | ||
|
||
token = self._token_from_request(request, 0) | ||
reply = InlineCompletionReply( | ||
list=InlineCompletionList( | ||
items=[ | ||
{ | ||
# insert text starts empty as we do not pre-generate any part | ||
"insertText": "", | ||
"isIncomplete": True, | ||
"token": token, | ||
} | ||
] | ||
), | ||
reply_to=request.number, | ||
) | ||
self.write_message(reply) | ||
|
||
async def handle_stream_request(self, request: InlineCompletionRequest): | ||
# first, send empty initial reply. | ||
self._write_incomplete_reply() | ||
|
||
# then, generate and stream LLM output over this connection. | ||
self.get_llm_chain() | ||
token = self._token_from_request(request, 0) | ||
model_arguments = self._template_inputs_from_request(request) | ||
suggestion = "" | ||
|
||
async for fragment in self.llm_chain.astream(input=model_arguments): | ||
suggestion += fragment | ||
if suggestion.startswith("```"): | ||
if "\n" not in suggestion: | ||
# we are not ready to apply post-processing | ||
continue | ||
else: | ||
suggestion = self._post_process_suggestion(suggestion, request) | ||
self.write_message( | ||
InlineCompletionStreamChunk( | ||
type="stream", | ||
response={"insertText": suggestion, "token": token}, | ||
reply_to=request.number, | ||
done=False, | ||
) | ||
) | ||
|
||
# finally, send a message confirming that we are done | ||
self.write_message( | ||
InlineCompletionStreamChunk( | ||
type="stream", | ||
response={"insertText": suggestion, "token": token}, | ||
reply_to=request.number, | ||
done=True, | ||
) | ||
) | ||
|
||
def _token_from_request(self, request: InlineCompletionRequest, suggestion: int): | ||
"""Generate a deterministic token (for matching streamed messages) | ||
using request number and suggestion number""" | ||
return f"t{request.number}s{suggestion}" | ||
|
||
def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict: | ||
suffix = request.suffix.strip() | ||
# only add the suffix template if the suffix is there to save input tokens/computation time | ||
after = AFTER_TEMPLATE.format(suffix=suffix) if suffix else "" | ||
filename = request.path.split("/")[-1] if request.path else "untitled" | ||
|
||
return { | ||
"prefix": request.prefix, | ||
"after": after, | ||
"language": request.language, | ||
"filename": filename, | ||
"stop": ["\n```"], | ||
} | ||
|
||
def _post_process_suggestion( | ||
self, suggestion: str, request: InlineCompletionRequest | ||
) -> str: | ||
"""Remove spurious fragments from the suggestion. | ||
While most models (especially instruct and infill models do not require | ||
any pre-processing, some models such as gpt-4 which only have chat APIs | ||
may require removing spurious fragments. This function uses heuristics | ||
and request data to remove such fragments. | ||
""" | ||
# gpt-4 tends to add "```python" or similar | ||
language = request.language or "python" | ||
markdown_identifiers = {"ipython": ["ipython", "python", "py"]} | ||
bad_openings = [ | ||
f"```{identifier}" | ||
for identifier in markdown_identifiers.get(language, [language]) | ||
] + ["```"] | ||
for opening in bad_openings: | ||
if suggestion.startswith(opening): | ||
suggestion = suggestion[len(opening) :].lstrip() | ||
# check for the prefix inclusion (only if there was a bad opening) | ||
if suggestion.startswith(request.prefix): | ||
suggestion = suggestion[len(request.prefix) :] | ||
break | ||
return suggestion |
Oops, something went wrong.