Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #430 on branch 1.x (Model parameters option to pass in model tuning, arbitrary parameters) #453

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, Type
from typing import ClassVar, List

from jupyter_ai_magics.providers import (
AuthStrategy,
Expand All @@ -12,7 +12,6 @@
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
)
from langchain.embeddings.base import Embeddings
from pydantic import BaseModel, Extra


Expand Down
19 changes: 7 additions & 12 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,11 @@ def handle_error(self, args: ErrorArgs):

prompt = f"Explain the following error:\n\n{last_error}"
# Set CellArgs based on ErrorArgs
cell_args = CellArgs(
type="root", model_id=args.model_id, format=args.format, reset=False
)
values = args.dict()
values["type"] = "root"
values["reset"] = False
cell_args = CellArgs(**values)

return self.run_ai_cell(cell_args, prompt)

def _append_exchange_openai(self, prompt: str, output: str):
Expand Down Expand Up @@ -538,16 +540,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
provider_params["request_schema"] = args.request_schema
provider_params["response_path"] = args.response_path

# Validate that the request schema is well-formed JSON
try:
json.loads(args.request_schema)
except json.JSONDecodeError as e:
raise ValueError(
"request-schema must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
) from None
model_parameters = json.loads(args.model_parameters)

provider = Provider(**provider_params)
provider = Provider(**provider_params, **model_parameters)

# Apply a prompt template.
prompt = provider.get_prompt_template(args.format).format(prompt=prompt)
Expand Down
42 changes: 42 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Literal, Optional, get_args

import click
Expand Down Expand Up @@ -32,12 +33,21 @@
+ "does nothing with other providers."
)

# Command-line flags and help text for the option that forwards extra,
# user-supplied keyword arguments to the provider class.
MODEL_PARAMETERS_SHORT_OPTION = "-m"
MODEL_PARAMETERS_LONG_OPTION = "--model-parameters"
MODEL_PARAMETERS_HELP = (
    "A JSON value that specifies extra values that will be passed "
    "to the model. The accepted value is parsed to a dict, unpacked "
    "and passed as-is to the provider class."
)


class CellArgs(BaseModel):
type: Literal["root"] = "root"
model_id: str
format: FORMAT_CHOICES_TYPE
reset: bool
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand All @@ -49,6 +59,7 @@ class ErrorArgs(BaseModel):
type: Literal["error"] = "error"
model_id: str
format: FORMAT_CHOICES_TYPE
model_parameters: Optional[str]
# The following parameters are required only for SageMaker models
region_name: Optional[str]
request_schema: Optional[str]
Expand Down Expand Up @@ -93,6 +104,19 @@ def get_help(self, ctx):
click.echo(super().get_help(ctx))


def verify_json_value(ctx, param, value):
    """Click option callback that checks an option value is valid JSON.

    Args:
        ctx: The click context, used only to build the error hint.
        param: The click parameter being validated; its error hint names
            the offending option in the error message.
        value: The raw option string, or a falsy value when not supplied.

    Returns:
        ``value`` unchanged. The parsed result is discarded; callers parse
        the string again when they need the data.

    Raises:
        ValueError: If ``value`` is non-empty and is not valid JSON.
    """
    # Nothing to validate when the option was omitted or left empty.
    if not value:
        return value
    try:
        json.loads(value)
    except json.JSONDecodeError as e:
        # Suppress the chained JSONDecodeError so the user only sees the
        # concise, option-specific message rather than the decoder traceback.
        raise ValueError(
            f"{param.get_error_hint(ctx)} must be valid JSON. "
            f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
        ) from None
    return value


@click.command()
@click.argument("model_id")
@click.option(
Expand Down Expand Up @@ -120,13 +144,22 @@ def get_help(self, ctx):
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def cell_magic_parser(**kwargs):
"""
Invokes a language model identified by MODEL_ID, with the prompt being
Expand Down Expand Up @@ -166,13 +199,22 @@ def line_magic_parser():
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
MODEL_PARAMETERS_SHORT_OPTION,
MODEL_PARAMETERS_LONG_OPTION,
required=False,
help=MODEL_PARAMETERS_HELP,
callback=verify_json_value,
default="{}",
)
def error_subparser(**kwargs):
"""
Explains the most recent error. Takes the same options (except -r) as
Expand Down
13 changes: 12 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import io
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
from typing import (
Any,
ClassVar,
Coroutine,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)

from jsonpath_ng import parse
from langchain.chat_models import (
Expand Down Expand Up @@ -621,6 +631,7 @@ def __init__(self, *args, **kwargs):
content_handler = JsonContentHandler(
request_schema=request_schema, response_path=response_path
)

super().__init__(*args, **kwargs, content_handler=content_handler)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self, retriever, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
self.llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
self.llm = provider(**provider_params, **model_parameters)
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
Expand Down
11 changes: 10 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback

# necessary to prevent circular import
from typing import TYPE_CHECKING, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from uuid import uuid4

from jupyter_ai.config_manager import ConfigManager, Logger
Expand All @@ -23,10 +23,12 @@ def __init__(
log: Logger,
config_manager: ConfigManager,
root_chat_handlers: Dict[str, "RootChatHandler"],
model_parameters: Dict[str, Dict],
):
self.log = log
self.config_manager = config_manager
self._root_chat_handlers = root_chat_handlers
self.model_parameters = model_parameters
self.parser = argparse.ArgumentParser()
self.llm = None
self.llm_params = None
Expand Down Expand Up @@ -122,6 +124,13 @@ def get_llm_chain(self):
self.llm_params = lm_provider_params
return self.llm_chain

def get_model_parameters(
    self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
    """Return the extra parameters configured for the given model.

    The lookup key is ``"<provider.id>:<model_id>"``, where ``model_id`` is
    read from ``provider_params``. An empty dict is returned when nothing is
    configured under that key.
    """
    global_model_id = f"{provider.id}:{provider_params['model_id']}"
    return self.model_parameters.get(global_model_id, {})

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

if llm.is_chat_provider:
prompt_template = ChatPromptTemplate.from_messages(
Expand Down
4 changes: 3 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def __init__(self, root_dir: str, *args, **kwargs):
def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
llm = provider(**provider_params)
model_parameters = self.get_model_parameters(provider, provider_params)
llm = provider(**provider_params, **model_parameters)

self.llm = llm
return llm

Expand Down
46 changes: 45 additions & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from traitlets import List, Unicode
from traitlets import Dict, List, Unicode

from .chat_handlers import (
AskChatHandler,
Expand Down Expand Up @@ -53,13 +53,56 @@ class AiExtension(ExtensionApp):
config=True,
)

allowed_models = List(
Unicode(),
default_value=None,
help="""
Language models to allow, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to
`None`.

Note: Currently, if `allowed_providers` is also set, then this field is
ignored. This is subject to change in a future non-major release. Using
both traits is considered to be undefined behavior at this time.
""",
allow_none=True,
config=True,
)

blocked_models = List(
Unicode(),
default_value=None,
help="""
Language models to block, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to
`None`.
""",
allow_none=True,
config=True,
)

model_parameters = Dict(
key_trait=Unicode(),
value_trait=Dict(),
default_value={},
help="""Key-value pairs for model id and corresponding parameters that
are passed to the provider class. The values are unpacked and passed to
the provider class as-is.""",
allow_none=True,
config=True,
)

def initialize_settings(self):
start = time.time()
restrictions = {
"allowed_providers": self.allowed_providers,
"blocked_providers": self.blocked_providers,
}

self.settings["model_parameters"] = self.model_parameters
self.log.info(f"Configured model parameters: {self.model_parameters}")

# Fetch LM & EM providers
self.settings["lm_providers"] = get_lm_providers(
log=self.log, restrictions=restrictions
)
Expand Down Expand Up @@ -107,6 +150,7 @@ def initialize_settings(self):
"log": self.log,
"config_manager": self.settings["jai_config_manager"],
"root_chat_handlers": self.settings["jai_root_chat_handlers"],
"model_parameters": self.settings["model_parameters"],
}
default_chat_handler = DefaultChatHandler(
**chat_handler_kwargs, chat_history=self.settings["chat_history"]
Expand Down
Loading