Skip to content

Commit

Permalink
Updated magics, added model_parameters, removed model_kwargs and endp…
Browse files Browse the repository at this point in the history
…oint_kwargs.
  • Loading branch information
3coins committed Nov 8, 2023
1 parent 628641f commit 400e3cf
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 43 deletions.
4 changes: 2 additions & 2 deletions docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -888,8 +888,8 @@ This configuration allows specifying arbitrary parameters that are unpacked and
This is useful for passing parameters such as model tuning that affect the response generation by the model.
This is also an appropriate place to pass in custom attributes required by certain providers/models.

The accepted value should be a dictionary, with top level keys as the model id (provider:model_id), and value
should be any arbitrary dictionary which is unpacked and passed as is to the provider class.
The accepted value is a dictionary, with top level keys as the model id (provider:model_id), and value
should be any arbitrary dictionary which is unpacked and passed as-is to the provider class.

#### Configuring as a startup option
In this sample, the `bedrock` provider will be created with the value for `model_kwargs` when `ai21.j2-mid-v1` model is selected.
Expand Down
8 changes: 2 additions & 6 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,13 +518,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
provider_params["request_schema"] = args.request_schema
provider_params["response_path"] = args.response_path

if args.model_kwargs:
provider_params["model_kwargs"] = args.model_kwargs
model_parameters = json.loads(args.model_parameters)

if args.endpoint_kwargs:
provider_params["endpoint_kwargs"] = args.endpoint_kwargs

provider = Provider(**provider_params)
provider = Provider(**provider_params, **model_parameters)

# Apply a prompt template.
prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
Expand Down
49 changes: 14 additions & 35 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,12 @@
+ "does nothing with other providers."
)

ENDPOINT_ARGS_SHORT_OPTION = "-e"
ENDPOINT_ARGS_LONG_OPTION = "--endpoint-kwargs"
ENDPOINT_ARGS_HELP = (
MODEL_PARAMETERS_SHORT_OPTION = "-m"
MODEL_PARAMETERS_LONG_OPTION = "--model-parameters"
MODEL_PARAMETERS_HELP = (
"A JSON value that specifies extra values that will be passed "
"to the SageMaker Endpoint invoke function."
)

MODEL_ARGS_SHORT_OPTION = "-m"
MODEL_ARGS_LONG_OPTION = "--model-kwargs"
MODEL_ARGS_HELP = (
"A JSON value that specifies extra values that will be passed to"
"the payload body of the invoke function. This can be useful to"
"pass model tuning parameters such as token count, temperature "
"etc., that affects the response generated by of a model."
"to the model. The accepted value parsed to a dict, unpacked "
"and passed as-is to the provider class."
)


Expand All @@ -59,8 +51,7 @@ class CellArgs(BaseModel):
region_name: Optional[str]
request_schema: Optional[str]
response_path: Optional[str]
model_kwargs: Optional[str]
endpoint_kwargs: Optional[str]
model_parameters: Optional[str]


# Should match CellArgs, but without "reset"
Expand Down Expand Up @@ -161,18 +152,12 @@ def verify_json_value(ctx, param, value):
help=RESPONSE_PATH_HELP,
)
@click.option(
ENDPOINT_ARGS_SHORT_OPTION,
ENDPOINT_ARGS_LONG_OPTION,
required=False,
help=ENDPOINT_ARGS_HELP,
callback=verify_json_value,
)
@click.option(
MODEL_ARGS_SHORT_OPTION,
MODEL_ARGS_LONG_OPTION,
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_ARGS_HELP,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def cell_magic_parser(**kwargs):
"""
Expand Down Expand Up @@ -222,18 +207,12 @@ def line_magic_parser():
help=RESPONSE_PATH_HELP,
)
@click.option(
ENDPOINT_ARGS_SHORT_OPTION,
ENDPOINT_ARGS_LONG_OPTION,
required=False,
help=ENDPOINT_ARGS_HELP,
callback=verify_json_value,
)
@click.option(
MODEL_ARGS_SHORT_OPTION,
MODEL_ARGS_LONG_OPTION,
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_ARGS_HELP,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def error_subparser(**kwargs):
"""
Expand Down

0 comments on commit 400e3cf

Please sign in to comment.