Skip to content

Commit

Permalink
Endpoint args for SM endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Nov 3, 2023
1 parent 2bc0c2b commit 5a9ec2d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, Type
from typing import ClassVar, List

from jupyter_ai_magics.providers import (
AuthStrategy,
Expand All @@ -12,7 +12,6 @@
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
)
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel, Extra


Expand Down
22 changes: 22 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@
+ "does nothing with other providers."
)

ENDPOINT_ARGS_SHORT_OPTION = "-e"
ENDPOINT_ARGS_LONG_OPTION = "--endpoint-args"
ENDPOINT_ARGS_HELP = (
"A JSON value that specifies extra values that will be passed "
"to the SageMaker Endpoint invoke function."
)

MODEL_ARGS_SHORT_OPTION = "-m"
MODEL_ARGS_LONG_OPTION = "--model-args"
MODEL_ARGS_HELP = (
"A JSON value that specifies extra values that will be passed to"
"the payload body of the invoke function. This can be useful to"
"pass model tuning parameters such as token count, temperature "
"etc., that affects the response generated by of a model."
)


class CellArgs(BaseModel):
type: Literal["root"] = "root"
Expand Down Expand Up @@ -127,6 +143,12 @@ def get_help(self, ctx):
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
ENDPOINT_ARGS_SHORT_OPTION,
ENDPOINT_ARGS_LONG_OPTION,
required=False,
help=ENDPOINT_ARGS_HELP,
)
def cell_magic_parser(**kwargs):
"""
Invokes a language model identified by MODEL_ID, with the prompt being
Expand Down
35 changes: 33 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
Any,
ClassVar,
Coroutine,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)

from jsonpath_ng import parse
from langchain.chat_models import (
Expand Down Expand Up @@ -232,6 +242,15 @@ def allows_concurrency(self):
return True


def pop_with_default(model: Mapping[str, Any], name: str, default: Any) -> Any:
try:
value = model.pop(name)
except KeyError as e:
return default

return value


class AI21Provider(BaseProvider, AI21):
id = "ai21"
name = "AI21"
Expand Down Expand Up @@ -613,6 +632,9 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
TextField(
key="response_path", label="Response path (required)", format="jsonpath"
),
MultilineTextField(
key="endpoint_kwargs", label="Endpoint arguments", format="json"
),
]

def __init__(self, *args, **kwargs):
Expand All @@ -621,7 +643,16 @@ def __init__(self, *args, **kwargs):
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)
super().__init__(*args, **kwargs, content_handler=content_handler)

endpoint_kwargs = pop_with_default(kwargs, "endpoint_kwargs", "{}")
endpoint_kwargs = json.loads(endpoint_kwargs)

super().__init__(
*args,
**kwargs,
content_handler=content_handler,
endpoint_kwargs=endpoint_kwargs,
)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)
Expand Down

0 comments on commit 5a9ec2d

Please sign in to comment.