Merge branch 'main' into separate-completer-and-chat-settings
krassowski authored Apr 25, 2024
2 parents b06481a + 9c8046c commit 0edb11a
Showing 14 changed files with 469 additions and 303 deletions.
78 changes: 78 additions & 0 deletions docs/source/developers/index.md
@@ -150,6 +150,84 @@ my-provider = "my_provider:MyEmbeddingsProvider"

[Embeddings]: https://api.python.langchain.com/en/stable/embeddings/langchain_core.embeddings.Embeddings.html


### Custom completion providers

Any model provider derived from `BaseProvider` can be used as a completion provider.
However, some providers may benefit from customizing the handling of completion requests.

There are two asynchronous methods which can be overridden in subclasses of `BaseProvider`:
- `generate_inline_completions`: takes a request (`InlineCompletionRequest`) and returns `InlineCompletionReply`
- `stream_inline_completions`: takes a request and yields an initiating reply (`InlineCompletionReply`) with `isIncomplete` set to `True`, followed by subsequent chunks (`InlineCompletionStreamChunk`)

When streaming, all replies and chunks for a given invocation of the `stream_inline_completions()` method should include a constant and unique string token identifying the stream. All chunks except for the last chunk of a given item should have the `done` value set to `False`.

The following example demonstrates a custom implementation of the completion provider, with both a method that returns multiple completions in one go and a method that streams multiple completions concurrently.
The implementation and explanation for the `merge_iterators` function used in this example can be found [here](https://stackoverflow.com/q/72445371/4877269).

```python
# Note: the import paths below are assumptions added for illustration; adjust them
# to match your environment.
import asyncio

from langchain_community.llms import FakeListLLM

from jupyter_ai_magics import BaseProvider
from jupyter_ai_magics.models.completion import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
)


class MyCompletionProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["model_a"]

    def __init__(self, **kwargs):
        kwargs["responses"] = ["This fake response will not be used for completion"]
        super().__init__(**kwargs)

    async def generate_inline_completions(self, request: InlineCompletionRequest):
        return InlineCompletionReply(
            list=InlineCompletionList(items=[
                {"insertText": "An ant minding its own business"},
                {"insertText": "A bug searching for a snack"}
            ]),
            reply_to=request.number,
        )

    async def stream_inline_completions(self, request: InlineCompletionRequest):
        token_1 = f"t{request.number}s0"
        token_2 = f"t{request.number}s1"

        yield InlineCompletionReply(
            list=InlineCompletionList(
                items=[
                    {"insertText": "An ", "isIncomplete": True, "token": token_1},
                    {"insertText": "", "isIncomplete": True, "token": token_2}
                ]
            ),
            reply_to=request.number,
        )

        # `merge_iterators` (linked above) interleaves the chunks of both streams
        # as they become available, so the two suggestions stream concurrently
        async for reply in merge_iterators([
            self._stream("elephant dancing in the rain", request.number, token_1, start_with="An"),
            self._stream("A flock of birds flying around a mountain", request.number, token_2)
        ]):
            yield reply

    async def _stream(self, sentence, request_number, token, start_with=""):
        suggestion = start_with

        for fragment in sentence.split():
            await asyncio.sleep(0.75)
            suggestion += " " + fragment
            yield InlineCompletionStreamChunk(
                type="stream",
                response={"insertText": suggestion, "token": token},
                reply_to=request_number,
                done=False
            )

        # finally, send a message confirming that we are done
        yield InlineCompletionStreamChunk(
            type="stream",
            response={"insertText": suggestion, "token": token},
            reply_to=request_number,
            done=True,
        )
```

## Prompt templates

Each provider can define **prompt templates** for each supported format. A prompt
52 changes: 52 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/completion_utils.py
@@ -0,0 +1,52 @@
from typing import Dict

from .models.completion import InlineCompletionRequest


def token_from_request(request: InlineCompletionRequest, suggestion: int):
    """Generate a deterministic token (for matching streamed messages)
    using the request number and the suggestion number."""
    return f"t{request.number}s{suggestion}"


def template_inputs_from_request(request: InlineCompletionRequest) -> Dict:
    suffix = request.suffix.strip()
    filename = request.path.split("/")[-1] if request.path else "untitled"

    return {
        "prefix": request.prefix,
        "suffix": suffix,
        "language": request.language,
        "filename": filename,
        "stop": ["\n```"],
    }


def post_process_suggestion(suggestion: str, request: InlineCompletionRequest) -> str:
    """Remove spurious fragments from the suggestion.

    While most models (especially instruct and infill models) do not require
    any post-processing, some models, such as gpt-4, which only have chat APIs,
    may require removing spurious fragments. This function uses heuristics
    and request data to remove such fragments.
    """
    # gpt-4 tends to add "```python" or similar
    language = request.language or "python"
    markdown_identifiers = {"ipython": ["ipython", "python", "py"]}
    bad_openings = [
        f"```{identifier}"
        for identifier in markdown_identifiers.get(language, [language])
    ] + ["```"]
    for opening in bad_openings:
        if suggestion.startswith(opening):
            suggestion = suggestion[len(opening) :].lstrip()
            # check for the prefix inclusion (only if there was a bad opening)
            if suggestion.startswith(request.prefix):
                suggestion = suggestion[len(request.prefix) :]
            break

    # check if the suggestion ends with a closing markdown identifier and remove it
    if suggestion.rstrip().endswith("```"):
        suggestion = suggestion.rstrip()[:-3].rstrip()

    return suggestion
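
For reference, a minimal usage sketch of these helpers (not part of the changeset); the request values below are made up for illustration:

```python
from jupyter_ai_magics.completion_utils import post_process_suggestion, token_from_request
from jupyter_ai_magics.models.completion import InlineCompletionRequest

request = InlineCompletionRequest(
    number=1,
    prefix="def fib(n):\n",
    suffix="",
    mime="text/x-python",
    stream=False,
    language="python",
)

# chat-only models often wrap the completion in a Markdown fence and repeat the
# prefix; post_process_suggestion strips both
fence = "`" * 3  # build the fence indirectly so this example stays readable
raw = f"{fence}python\ndef fib(n):\n    return n if n < 2 else fib(n - 1) + fib(n - 2)\n{fence}"

print(post_process_suggestion(raw, request))
# -> "    return n if n < 2 else fib(n - 1) + fib(n - 2)"

print(token_from_request(request, 0))
# -> "t1s0"
```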
81 changes: 81 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py
@@ -0,0 +1,81 @@
from typing import List, Literal, Optional

from langchain.pydantic_v1 import BaseModel


class InlineCompletionRequest(BaseModel):
    """Message sent by the client to request inline completions.
    Prefix/suffix implementation is used to avoid the need for synchronising
    the notebook state at every key press (subject to change in future)."""

    # unique message ID generated by the client used to identify replies and
    # to easily discard replies for older requests
    number: int
    # prefix should include the full text of the current cell preceding the cursor
    prefix: str
    # suffix should include the full text of the current cell following the cursor
    suffix: str
    # media type for the current language, e.g. `text/x-python`
    mime: str
    # whether to stream the response (if supported by the model)
    stream: bool
    # path to the notebook or file for which the completions are generated
    path: Optional[str]
    # language inferred from the document mime type (if possible)
    language: Optional[str]
    # identifier of the cell for which the completions are generated (if in a notebook);
    # previous and following cells can be used to learn the wider context
    cell_id: Optional[str]


class InlineCompletionItem(BaseModel):
    """The inline completion suggestion to be displayed on the frontend.
    See JupyterLab `InlineCompletionItem` documentation for the details.
    """

    insertText: str
    filterText: Optional[str]
    isIncomplete: Optional[bool]
    token: Optional[str]


class CompletionError(BaseModel):
    type: str
    traceback: str


class InlineCompletionList(BaseModel):
    """Reflection of JupyterLab's `IInlineCompletionList`."""

    items: List[InlineCompletionItem]


class InlineCompletionReply(BaseModel):
    """Message sent from model to client with the infill suggestions."""

    list: InlineCompletionList
    # number of the request for which we are replying
    reply_to: int
    error: Optional[CompletionError]


class InlineCompletionStreamChunk(BaseModel):
    """Message sent from model to client with a chunk of a streamed infill suggestion."""

    type: Literal["stream"] = "stream"
    response: InlineCompletionItem
    reply_to: int
    done: bool
    error: Optional[CompletionError]


__all__ = [
    "InlineCompletionRequest",
    "InlineCompletionItem",
    "CompletionError",
    "InlineCompletionList",
    "InlineCompletionReply",
    "InlineCompletionStreamChunk",
]
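
As a rough sketch (not part of the changeset) of how these messages fit together, with made-up values:

```python
from jupyter_ai_magics.models.completion import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionStreamChunk,
)

# initial reply announcing one incomplete suggestion for request number 3;
# the token ties later stream chunks back to this suggestion
reply = InlineCompletionReply(
    list=InlineCompletionList(
        items=[{"insertText": "", "isIncomplete": True, "token": "t3s0"}]
    ),
    reply_to=3,
)

# follow-up chunk carrying the accumulated text; done=True marks the last chunk
chunk = InlineCompletionStreamChunk(
    response={"insertText": "import numpy as np", "token": "t3s0"},
    reply_to=3,
    done=True,
)

print(reply.json(exclude_none=True))
print(chunk.json(exclude_none=True))
```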
@@ -75,23 +75,17 @@ class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
id = "azure-chat-openai"
name = "Azure OpenAI"
models = ["*"]
model_id_key = "deployment_name"
model_id_key = "azure_deployment"
model_id_label = "Deployment name"
pypi_package_deps = ["langchain_openai"]
# Confusingly, langchain uses both OPENAI_API_KEY and AZURE_OPENAI_API_KEY for azure
# https://github.com/langchain-ai/langchain/blob/f2579096993ae460516a0aae1d3e09f3eb5c1772/libs/partners/openai/langchain_openai/llms/azure.py#L85
auth_strategy = EnvAuthStrategy(name="AZURE_OPENAI_API_KEY")
registry = True

fields = [
TextField(
key="openai_api_base", label="Base API URL (required)", format="text"
),
TextField(
key="openai_api_version", label="API version (required)", format="text"
),
TextField(
key="openai_organization", label="Organization (optional)", format="text"
),
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
TextField(key="azure_endpoint", label="Base API URL (required)", format="text"),
TextField(key="api_version", label="API version (required)", format="text"),
]


86 changes: 85 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
    Any,
    AsyncIterator,
    ClassVar,
    Coroutine,
    Dict,
    List,
    Literal,
    Optional,
    Union,
)

from jsonpath_ng import parse
from langchain.chat_models.base import BaseChatModel
@@ -20,6 +30,8 @@
)
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.schema import LLMResult
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.utils import get_from_dict_or_env
from langchain_community.chat_models import (
    BedrockChat,
@@ -46,6 +58,13 @@
except:
    from pydantic.main import ModelMetaclass

from . import completion_utils as completion
from .models.completion import (
    InlineCompletionList,
    InlineCompletionReply,
    InlineCompletionRequest,
    InlineCompletionStreamChunk,
)
from .models.persona import Persona

CHAT_SYSTEM_PROMPT = """
@@ -405,6 +424,71 @@ def is_chat_provider(self):
    def allows_concurrency(self):
        return True

    async def generate_inline_completions(
        self, request: InlineCompletionRequest
    ) -> InlineCompletionReply:
        chain = self._create_completion_chain()
        model_arguments = completion.template_inputs_from_request(request)
        suggestion = await chain.ainvoke(input=model_arguments)
        suggestion = completion.post_process_suggestion(suggestion, request)
        return InlineCompletionReply(
            list=InlineCompletionList(items=[{"insertText": suggestion}]),
            reply_to=request.number,
        )

    async def stream_inline_completions(
        self, request: InlineCompletionRequest
    ) -> AsyncIterator[InlineCompletionStreamChunk]:
        chain = self._create_completion_chain()
        token = completion.token_from_request(request, 0)
        model_arguments = completion.template_inputs_from_request(request)
        suggestion = ""

        # send an incomplete `InlineCompletionReply`, indicating to the
        # client that LLM output is about to be streamed across this connection.
        yield InlineCompletionReply(
            list=InlineCompletionList(
                items=[
                    {
                        # insert text starts empty as we do not pre-generate any part
                        "insertText": "",
                        "isIncomplete": True,
                        "token": token,
                    }
                ]
            ),
            reply_to=request.number,
        )

        async for fragment in chain.astream(input=model_arguments):
            suggestion += fragment
            if suggestion.startswith("```"):
                if "\n" not in suggestion:
                    # we are not ready to apply post-processing
                    continue
                else:
                    suggestion = completion.post_process_suggestion(suggestion, request)
            elif suggestion.rstrip().endswith("```"):
                suggestion = completion.post_process_suggestion(suggestion, request)
            yield InlineCompletionStreamChunk(
                type="stream",
                response={"insertText": suggestion, "token": token},
                reply_to=request.number,
                done=False,
            )

        # finally, send a message confirming that we are done
        yield InlineCompletionStreamChunk(
            type="stream",
            response={"insertText": suggestion, "token": token},
            reply_to=request.number,
            done=True,
        )

    def _create_completion_chain(self) -> Runnable:
        prompt_template = self.get_completion_prompt_template()
        return prompt_template | self | StrOutputParser()


class AI21Provider(BaseProvider, AI21):
id = "ai21"
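
For orientation, a rough sketch (not part of the changeset) of how a caller could exercise the default completion path defined above; the provider construction is left abstract and the request values are made up:

```python
import asyncio

from jupyter_ai_magics.models.completion import InlineCompletionRequest


async def demo(provider):
    # `provider` is any configured BaseProvider instance
    request = InlineCompletionRequest(
        number=1,
        prefix="def add(a, b):\n    return ",
        suffix="",
        mime="text/x-python",
        stream=False,
        language="python",
    )
    reply = await provider.generate_inline_completions(request)
    print(reply.list.items[0].insertText)

# run with a concrete provider instance: asyncio.run(demo(provider))
```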
12 changes: 11 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/help.py
@@ -62,4 +62,14 @@ def __init__(self, *args, chat_handlers: Dict[str, BaseChatHandler], **kwargs):
        self._chat_handlers = chat_handlers

    async def process_message(self, message: HumanChatMessage):
        self.reply(_format_help_message(self._chat_handlers), message)
        persona = self.config_manager.persona
        lm_provider = self.config_manager.lm_provider
        unsupported_slash_commands = (
            lm_provider.unsupported_slash_commands if lm_provider else set()
        )
        self.reply(
            _format_help_message(
                self._chat_handlers, persona, unsupported_slash_commands
            ),
            message,
        )
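
For context, the help message now adapts to the active provider's persona and its unsupported slash commands. A provider could presumably opt out of specific commands with a class attribute along these lines; the import paths and the listed commands are assumptions for illustration:

```python
# hypothetical sketch; import paths and command names are assumptions
from langchain_community.llms import FakeListLLM

from jupyter_ai_magics import BaseProvider


class MyProvider(BaseProvider, FakeListLLM):
    id = "my_provider"
    name = "My Provider"
    model_id_key = "model"
    models = ["model_a"]
    # slash commands listed here are reported as unsupported in the /help output
    unsupported_slash_commands = {"/learn", "/ask", "/generate"}
```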
