* Draft inline completions implementation (server side)
* Implement inline completion provider (front)
* Add default debounce delay and error handling (front)
* Add `gpt-3.5-turbo-instruct` because text- models are deprecated. OpenAI specifically recommends using `gpt-3.5-turbo-instruct` instead of text-davinci, text-ada, etc. See: https://platform.openai.com/docs/deprecations/
* Improve/fix prompt template and add simple post-processing
* Handle missing `registerInlineProvider`, handle no model in name
* Remove IPython mention to avoid confusing languages
* Disable suggestions in markdown, move language logic
* Remove unused background and clip path from jupyternaut
* Implement toggling the AI completer via statusbar item; also adds the icon for the provider, re-using the jupyternaut icon
* Implement streaming support
* Translate ipython to python for models, remove log
* Move `BaseLLMHandler` to `/completions`, rename to `LLMHandlerMixin`
* Move frontend completions code to `/completions`
* Make `IStatusBar` required for now, lint
1 parent f9c8033, commit 5d28fea
Showing 22 changed files with 1,934 additions and 1,030 deletions.
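The changes below define a small request/reply protocol between the frontend inline completion provider and the server-side handlers. As a quick orientation, here is a rough sketch of the message shapes implied by the handler code in this commit; the field names are read off the handlers shown below (the models module is not included in this view), so treat the exact schema as an assumption rather than a specification.

# Hypothetical message shapes, inferred from the handler code in this commit.

# Client -> server: an inline completion request.
request = {
    "number": 1,                    # request counter, echoed back as `reply_to`
    "prefix": "def fib(n):\n    ",  # code before the cursor
    "suffix": "",                   # code after the cursor
    "language": "python",
    "path": "analysis.ipynb",       # hypothetical document path
    "stream": True,                 # ask for streamed chunks instead of a single reply
}

# Server -> client: initial reply when streaming; insertText starts empty and
# the token ("t<request number>s<suggestion number>") ties later chunks to it.
reply = {
    "list": {"items": [{"insertText": "", "isIncomplete": True, "token": "t1s0"}]},
    "reply_to": 1,
}

# Server -> client: streamed chunks, with done=True on the final message.
chunk = {
    "type": "stream",
    "response": {"insertText": "    return fib(n - 1) + fib(n - 2)", "token": "t1s0"},
    "reply_to": 1,
    "done": False,
}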
Empty file.
4 changes: 4 additions & 0 deletions
packages/jupyter-ai/jupyter_ai/completions/handlers/__init__.py
@@ -0,0 +1,4 @@
from .base import BaseInlineCompletionHandler
from .default import DefaultInlineCompletionHandler

__all__ = ["BaseInlineCompletionHandler", "DefaultInlineCompletionHandler"]
76 changes: 76 additions & 0 deletions
packages/jupyter-ai/jupyter_ai/completions/handlers/base.py
@@ -0,0 +1,76 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, AsyncIterator, Dict

from jupyter_ai.completions.handlers.llm_mixin import LLMHandlerMixin
from jupyter_ai.completions.models import (
    CompletionError,
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
    ModelChangedNotification,
)
from jupyter_ai.config_manager import ConfigManager, Logger

if TYPE_CHECKING:
    from jupyter_ai.handlers import InlineCompletionHandler


class BaseInlineCompletionHandler(LLMHandlerMixin):
    """Class implementing completion handling."""

    handler_kind = "completion"

    def __init__(
        self,
        log: Logger,
        config_manager: ConfigManager,
        model_parameters: Dict[str, Dict],
        ws_sessions: Dict[str, "InlineCompletionHandler"],
    ):
        super().__init__(log, config_manager, model_parameters)
        self.ws_sessions = ws_sessions

    async def on_message(
        self, message: InlineCompletionRequest
    ) -> InlineCompletionReply:
        try:
            return await self.process_message(message)
        except Exception as e:
            return await self._handle_exc(e, message)

    async def process_message(
        self, message: InlineCompletionRequest
    ) -> InlineCompletionReply:
        """
        Processes an inline completion request. Completion handlers
        (subclasses) must implement this method.
        The method definition does not need to be wrapped in a try/except block.
        """
        raise NotImplementedError("Should be implemented by subclasses.")

    async def stream(
        self, message: InlineCompletionRequest
    ) -> AsyncIterator[InlineCompletionStreamChunk]:
        """
        Stream the inline completion as it is generated. Completion handlers
        (subclasses) can implement this method.
        """
        raise NotImplementedError()

    async def _handle_exc(self, e: Exception, message: InlineCompletionRequest):
        error = CompletionError(
            type=e.__class__.__name__,
            title=e.args[0] if e.args else "Exception",
            traceback=traceback.format_exc(),
        )
        return InlineCompletionReply(
            list=InlineCompletionList(items=[]), error=error, reply_to=message.number
        )

    def broadcast(self, message: ModelChangedNotification):
        for session in self.ws_sessions.values():
            session.write_message(message.dict())
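The docstrings above define the contract for concrete handlers: implement `process_message` (and optionally `stream`), and let `on_message` take care of converting exceptions into error replies. A minimal, purely hypothetical subclass (not part of this commit) might look like this:

class EchoInlineCompletionHandler(BaseInlineCompletionHandler):
    """Illustrative handler returning a canned suggestion."""

    async def process_message(
        self, message: InlineCompletionRequest
    ) -> InlineCompletionReply:
        # any exception raised here is turned into an error reply by on_message()
        return InlineCompletionReply(
            list=InlineCompletionList(items=[{"insertText": "pass  # TODO"}]),
            reply_to=message.number,
        )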
184 changes: 184 additions & 0 deletions
packages/jupyter-ai/jupyter_ai/completions/handlers/default.py
@@ -0,0 +1,184 @@
from typing import Dict, Type

from jupyter_ai_magics.providers import BaseProvider
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable

from ..models import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
    ModelChangedNotification,
)
from .base import BaseInlineCompletionHandler

SYSTEM_PROMPT = """
You are an application built to provide helpful code completion suggestions.
You should only produce code. Keep comments to minimum, use the
programming language comment syntax. Produce clean code.
The code is written in JupyterLab, a data analysis and code development
environment which can execute code extended with additional syntax for
interactive features, such as magics.
""".strip()

AFTER_TEMPLATE = """
The code after the completion request is:
```
{suffix}
```
""".strip()

DEFAULT_TEMPLATE = """
The document is called `{filename}` and written in {language}.
{after}
Complete the following code:
```
{prefix}"""


class DefaultInlineCompletionHandler(BaseInlineCompletionHandler):
    llm_chain: Runnable

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        lm_provider = self.config_manager.lm_provider
        lm_provider_params = self.config_manager.lm_provider_params
        next_lm_id = (
            f'{lm_provider.id}:{lm_provider_params["model_id"]}'
            if lm_provider
            else None
        )
        self.broadcast(ModelChangedNotification(model=next_lm_id))

        model_parameters = self.get_model_parameters(provider, provider_params)
        llm = provider(**provider_params, **model_parameters)

        if llm.is_chat_provider:
            prompt_template = ChatPromptTemplate.from_messages(
                [
                    SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT),
                    HumanMessagePromptTemplate.from_template(DEFAULT_TEMPLATE),
                ]
            )
        else:
            prompt_template = PromptTemplate(
                input_variables=["prefix", "suffix", "language", "filename"],
                template=SYSTEM_PROMPT + "\n\n" + DEFAULT_TEMPLATE,
            )

        self.llm = llm
        self.llm_chain = prompt_template | llm | StrOutputParser()

    async def process_message(
        self, request: InlineCompletionRequest
    ) -> InlineCompletionReply:
        if request.stream:
            token = self._token_from_request(request, 0)
            return InlineCompletionReply(
                list=InlineCompletionList(
                    items=[
                        {
                            # insert text starts empty as we do not pre-generate any part
                            "insertText": "",
                            "isIncomplete": True,
                            "token": token,
                        }
                    ]
                ),
                reply_to=request.number,
            )
        else:
            self.get_llm_chain()
            model_arguments = self._template_inputs_from_request(request)
            suggestion = await self.llm_chain.ainvoke(input=model_arguments)
            suggestion = self._post_process_suggestion(suggestion, request)
            return InlineCompletionReply(
                list=InlineCompletionList(items=[{"insertText": suggestion}]),
                reply_to=request.number,
            )

    async def stream(self, request: InlineCompletionRequest):
        self.get_llm_chain()
        token = self._token_from_request(request, 0)
        model_arguments = self._template_inputs_from_request(request)
        suggestion = ""
        async for fragment in self.llm_chain.astream(input=model_arguments):
            suggestion += fragment
            if suggestion.startswith("```"):
                if "\n" not in suggestion:
                    # we are not ready to apply post-processing
                    continue
                else:
                    suggestion = self._post_process_suggestion(suggestion, request)
            yield InlineCompletionStreamChunk(
                type="stream",
                response={"insertText": suggestion, "token": token},
                reply_to=request.number,
                done=False,
            )
        # at the end send a message confirming that we are done
        yield InlineCompletionStreamChunk(
            type="stream",
            response={"insertText": suggestion, "token": token},
            reply_to=request.number,
            done=True,
        )

    def _token_from_request(self, request: InlineCompletionRequest, suggestion: int):
        """Generate a deterministic token (for matching streamed messages)
        using the request number and suggestion number."""
        return f"t{request.number}s{suggestion}"

    def _template_inputs_from_request(self, request: InlineCompletionRequest) -> Dict:
        suffix = request.suffix.strip()
        # only add the suffix template if the suffix is there, to save input tokens/computation time
        after = AFTER_TEMPLATE.format(suffix=suffix) if suffix else ""
        filename = request.path.split("/")[-1] if request.path else "untitled"

        return {
            "prefix": request.prefix,
            "after": after,
            "language": request.language,
            "filename": filename,
            "stop": ["\n```"],
        }

    def _post_process_suggestion(
        self, suggestion: str, request: InlineCompletionRequest
    ) -> str:
        """Remove spurious fragments from the suggestion.
        While most models (especially instruct and infill models) do not require
        any post-processing, some models such as gpt-4 which only have chat APIs
        may require removing spurious fragments. This function uses heuristics
        and request data to remove such fragments.
        """
        # gpt-4 tends to add "```python" or similar
        language = request.language or "python"
        markdown_identifiers = {"ipython": ["ipython", "python", "py"]}
        bad_openings = [
            f"```{identifier}"
            for identifier in markdown_identifiers.get(language, [language])
        ] + ["```"]
        for opening in bad_openings:
            if suggestion.startswith(opening):
                suggestion = suggestion[len(opening) :].lstrip()
                # check for the prefix inclusion (only if there was a bad opening)
                if suggestion.startswith(request.prefix):
                    suggestion = suggestion[len(request.prefix) :]
                break
        return suggestion
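To make the post-processing heuristic concrete: a chat-only model such as gpt-4 often wraps its answer in a Markdown fence and repeats the prompt, and `_post_process_suggestion` strips both. A worked example with hypothetical values:

# Hypothetical raw completion from a chat-only model:
raw = "```python\ndef fib(n):\n    return fib(n - 1) + fib(n - 2)"

# With request.language == "python" and request.prefix == "def fib(n):\n",
# _post_process_suggestion first strips the "```python" opening (and the
# newline after it, via lstrip), then removes the repeated prefix, so the
# text actually inserted into the editor is:
expected = "    return fib(n - 1) + fib(n - 2)"

The trailing fence rarely appears in the raw output because the template inputs include `"stop": ["\n```"]`.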
73 changes: 73 additions & 0 deletions
packages/jupyter-ai/jupyter_ai/completions/handlers/llm_mixin.py
@@ -0,0 +1,73 @@
from typing import Dict, Type

from jupyter_ai.config_manager import ConfigManager, Logger
from jupyter_ai_magics.providers import BaseProvider


class LLMHandlerMixin:
    """Base class containing shared methods and attributes used by LLM handler classes."""

    # This could be used to derive `BaseChatHandler` too (there is a lot of duplication!),
    # but this was decided against to avoid introducing conflicts for backports against 1.x

    handler_kind: str

    def __init__(
        self,
        log: Logger,
        config_manager: ConfigManager,
        model_parameters: Dict[str, Dict],
    ):
        self.log = log
        self.config_manager = config_manager
        self.model_parameters = model_parameters
        self.llm = None
        self.llm_params = None
        self.llm_chain = None

    def model_changed_callback(self):
        """Method which can be overridden in subclasses to listen for model changes."""
        pass

    def get_llm_chain(self):
        lm_provider = self.config_manager.lm_provider
        lm_provider_params = self.config_manager.lm_provider_params

        curr_lm_id = (
            f'{self.llm.id}:{lm_provider_params["model_id"]}' if self.llm else None
        )
        next_lm_id = (
            f'{lm_provider.id}:{lm_provider_params["model_id"]}'
            if lm_provider
            else None
        )

        if not lm_provider or not lm_provider_params:
            return None

        if curr_lm_id != next_lm_id:
            self.log.info(
                f"Switching {self.handler_kind} language model from {curr_lm_id} to {next_lm_id}."
            )
            self.create_llm_chain(lm_provider, lm_provider_params)
            self.model_changed_callback()
        elif self.llm_params != lm_provider_params:
            self.log.info(
                f"{self.handler_kind} model params changed, updating the llm chain."
            )
            self.create_llm_chain(lm_provider, lm_provider_params)

        self.llm_params = lm_provider_params
        return self.llm_chain

    def get_model_parameters(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        return self.model_parameters.get(
            f"{provider.id}:{provider_params['model_id']}", {}
        )

    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        raise NotImplementedError("Should be implemented by subclasses")
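`get_llm_chain` is called lazily before each completion (see `process_message` and `stream` in default.py above) and only rebuilds the chain when the configured model or its parameters change. The per-model overrides it consults via `get_model_parameters` are keyed by provider and model id; a hedged sketch, with a made-up provider/model entry:

# Hypothetical server-side configuration: extra keyword arguments for a model,
# keyed by "<provider_id>:<model_id>" as get_model_parameters expects.
model_parameters = {
    "openai:gpt-3.5-turbo-instruct": {"temperature": 0.1, "max_tokens": 64},
}

# get_model_parameters(provider, {"model_id": "gpt-3.5-turbo-instruct"}) would then
# return {"temperature": 0.1, "max_tokens": 64}; create_llm_chain merges these with
# the provider params when instantiating the LLM (see default.py above).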