Skip to content

Commit

Permalink
Setting default model providers
Browse files Browse the repository at this point in the history
  • Loading branch information
aws-khatria committed Oct 31, 2023
1 parent 7e4a2a5 commit d81338e
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 5 deletions.
25 changes: 21 additions & 4 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
provider_defaults: dict,
*args,
**kwargs,
):
Expand All @@ -111,6 +112,8 @@ def __init__(
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions
"""Provider defaults."""
self._provider_defaults = provider_defaults

"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
Expand All @@ -137,21 +140,35 @@ def _init_validator(self) -> Validator:
self.validator = Validator(schema)

def _init_config(self):
default_dict = self._init_defaults()
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
field_dict = json.loads(f.read())
# merge the default config with the config read from disk
Merger.merge(field_dict, default_dict)
config = GlobalConfig(**field_dict)
# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
return

properties = self.validator.schema.get("properties", {})
default_config = GlobalConfig(**default_dict)
self._write_config(default_config)

def _init_defaults(self):
field_list = GlobalConfig.__fields__.keys()
properties = self.validator.schema.get("properties", {})
field_dict = {
field: properties.get(field).get("default") for field in field_list
}
default_config = GlobalConfig(**field_dict)
self._write_config(default_config)
if self._provider_defaults is None:
return field_dict

for field in field_list:
default_value = self._provider_defaults.get(field)
if default_value is not None:
field_dict[field] = default_value
return field_dict

def _read_config(self) -> GlobalConfig:
"""Returns the user's current configuration as a GlobalConfig object.
Expand Down
40 changes: 39 additions & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from traitlets import List, Unicode
from traitlets import List, Unicode, Dict

from .chat_handlers import (
AskChatHandler,
Expand Down Expand Up @@ -53,13 +53,50 @@ class AiExtension(ExtensionApp):
config=True,
)

model_provider_id = Unicode(
default_value=None,
allow_none=True,
help="Default language model provider.",
config=True,
)

embeddings_provider_id = Unicode(
default_value=None,
allow_none=True,
help="Default embeddings model provider.",
config=True,
)

api_keys = Dict(
Unicode(),
Unicode(),
default_value=None,
allow_none=True,
help="API keys for language model providers.",
config=True,
)

fields = Dict(
default_value=None,
allow_none=True,
help="Sub fields required for language model providers.",
config=True,
)

def initialize_settings(self):
start = time.time()
restrictions = {
"allowed_providers": self.allowed_providers,
"blocked_providers": self.blocked_providers,
}

provider_defaults = {
"model_provider_id": self.model_provider_id,
"embeddings_provider_id": self.embeddings_provider_id,
"api_keys": self.api_keys,
"fields": self.fields,
}

self.settings["lm_providers"] = get_lm_providers(
log=self.log, restrictions=restrictions
)
Expand All @@ -74,6 +111,7 @@ def initialize_settings(self):
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
provider_defaults=provider_defaults,
)

self.log.info("Registered providers.")
Expand Down
73 changes: 73 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,34 @@ def common_cm_kwargs(config_path, schema_path):
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"provider_defaults": {
"model_provider_id": None,
"embeddings_provider_id": None,
"api_keys": None,
"fields": None,
}
}


@pytest.fixture
def cm_kargs_with_defaults(config_path, schema_path):
"""Kwargs that are commonly used when initializing the CM."""
log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
return {
"log": log,
"lm_providers": lm_providers,
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"provider_defaults": {
"model_provider_id": "bedrock-chat:anthropic.claude-v1",
"embeddings_provider_id": "bedrock:amazon.titan-embed-text-v1",
"api_keys": {"OPENAI_API_KEY": "open-ai-key-value"},
"fields": {"bedrock-chat:anthropic.claude-v1":{"credentials_profile_name": "default","region_name": "us-west-2"}},
}
}


Expand All @@ -46,6 +74,12 @@ def cm(common_cm_kwargs):
return ConfigManager(**common_cm_kwargs)


@pytest.fixture
def cm_with_defaults(cm_kargs_with_defaults):
"""The default ConfigManager instance, with an empty config and config schema."""
return ConfigManager(**cm_kargs_with_defaults)


@pytest.fixture(autouse=True)
def reset(config_path, schema_path):
"""Fixture that deletes the config and config schema after each test."""
Expand Down Expand Up @@ -114,8 +148,47 @@ def test_init_with_existing_config(
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
provider_defaults=None
)

def test_init_with_default_values(
cm_with_defaults: ConfigManager, config_path: str, schema_path: str
):
"""
Test that the ConfigManager initializes with the expected default values.
Args:
cm_with_defaults (ConfigManager): A ConfigManager instance with default values.
config_path (str): The path to the configuration file.
schema_path (str): The path to the schema file.
"""
config_response = cm_with_defaults.get_config()
#assert config response
assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1"
assert config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1"
assert config_response.api_keys == ["OPENAI_API_KEY"]
assert config_response.fields == {"bedrock-chat:anthropic.claude-v1":{"credentials_profile_name": "default","region_name": "us-west-2"}}

del cm_with_defaults

log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
cm_with_defaults_override =ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
provider_defaults={"model_provider_id": "bedrock-chat:anthropic.claude-v2"},
)

assert cm_with_defaults_override.get_config().model_provider_id == "bedrock-chat:anthropic.claude-v2"





def test_property_access_on_default_config(cm: ConfigManager):
"""Asserts that the CM behaves well with an empty, default
Expand Down

0 comments on commit d81338e

Please sign in to comment.