From 48097b1ffc1f8df0b9894c33cc4787b445a3ec38 Mon Sep 17 00:00:00 2001 From: aws-khatria <107514829+aws-khatria@users.noreply.github.com> Date: Thu, 25 Jan 2024 19:44:20 -0800 Subject: [PATCH] Backport PR #421: Setting default model providers --- docs/source/users/index.md | 22 +++++ .../jupyter-ai/jupyter_ai/config_manager.py | 34 ++++++-- packages/jupyter-ai/jupyter_ai/extension.py | 40 ++++++++++ .../jupyter_ai/tests/test_config_manager.py | 80 +++++++++++++++++++ 4 files changed, 168 insertions(+), 8 deletions(-) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 06a5ddde7..6814a087a 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -765,6 +765,28 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers. +### Configuring default models and API keys + +This configuration allows for setting a default language and embedding models, and their corresponding API keys. +These values are offered as a starting point for users, so they don't have to select the models and API keys, however, +the selections they make in the settings panel will take precedence over these values. + +Specify default language model +``` +jupyter lab --AiExtension.default_language_model=bedrock-chat:anthropic.claude-v2 +``` + +Specify default embedding model +``` +jupyter lab --AiExtension.default_embedding_model=bedrock:amazon.titan-embed-text-v1 +``` + +Specify default API keys +``` +jupyter lab --AiExtension.default_api_keys={'OPENAI_API_KEY': 'sk-abcd'} +``` + + ### Blocklisting providers This configuration allows for blocking specific providers in the settings panel. diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 82ef03126..01d3fe766 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -105,6 +105,7 @@ def __init__( blocked_providers: Optional[List[str]], allowed_models: Optional[List[str]], blocked_models: Optional[List[str]], + defaults: dict, *args, **kwargs, ): @@ -120,6 +121,8 @@ def __init__( self._blocked_providers = blocked_providers self._allowed_models = allowed_models self._blocked_models = blocked_models + self._defaults = defaults + """Provider defaults.""" self._last_read: Optional[int] = None """When the server last read the config file. If the file was not @@ -146,14 +149,20 @@ def _init_validator(self) -> Validator: self.validator = Validator(schema) def _init_config(self): + default_config = self._init_defaults() if os.path.exists(self.config_path): - self._process_existing_config() + self._process_existing_config(default_config) else: - self._create_default_config() + self._create_default_config(default_config) - def _process_existing_config(self): + def _process_existing_config(self, default_config): with open(self.config_path, encoding="utf-8") as f: - config = GlobalConfig(**json.loads(f.read())) + existing_config = json.loads(f.read()) + merged_config = Merger.merge( + default_config, + {k: v for k, v in existing_config.items() if v is not None}, + ) + config = GlobalConfig(**merged_config) validated_config = self._validate_lm_em_id(config) # re-write to the file to validate the config and apply any @@ -192,14 +201,23 @@ def _validate_lm_em_id(self, config): return config - def _create_default_config(self): - properties = self.validator.schema.get("properties", {}) + def _create_default_config(self, default_config): + self._write_config(GlobalConfig(**default_config)) + + def _init_defaults(self): field_list = GlobalConfig.__fields__.keys() + properties = self.validator.schema.get("properties", {}) field_dict = { field: properties.get(field).get("default") for field in field_list } - default_config = GlobalConfig(**field_dict) - self._write_config(default_config) + if self._defaults is None: + return field_dict + + for field in field_list: + default_value = self._defaults.get(field) + if default_value is not None: + field_dict[field] = default_value + return field_dict def _read_config(self) -> GlobalConfig: """Returns the user's current configuration as a GlobalConfig object. diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index a9385801a..dc689fe0e 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -104,6 +104,38 @@ class AiExtension(ExtensionApp): config=True, ) + default_language_model = Unicode( + default_value=None, + allow_none=True, + help=""" + Default language model to use, as string in the format + :, defaults to None. + """, + config=True, + ) + + default_embeddings_model = Unicode( + default_value=None, + allow_none=True, + help=""" + Default embeddings model to use, as string in the format + :, defaults to None. + """, + config=True, + ) + + default_api_keys = Dict( + key_trait=Unicode(), + value_trait=Unicode(), + default_value=None, + allow_none=True, + help=""" + Default API keys for model providers, as a dictionary, + in the format `:`. Defaults to None. + """, + config=True, + ) + def initialize_settings(self): start = time.time() @@ -122,6 +154,13 @@ def initialize_settings(self): self.settings["model_parameters"] = self.model_parameters self.log.info(f"Configured model parameters: {self.model_parameters}") + defaults = { + "model_provider_id": self.default_language_model, + "embeddings_provider_id": self.default_embeddings_model, + "api_keys": self.default_api_keys, + "fields": self.model_parameters, + } + # Fetch LM & EM providers self.settings["lm_providers"] = get_lm_providers( log=self.log, restrictions=restrictions @@ -140,6 +179,7 @@ def initialize_settings(self): blocked_providers=self.blocked_providers, allowed_models=self.allowed_models, blocked_models=self.blocked_models, + defaults=defaults, ) self.log.info("Registered providers.") diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index c238fc448..9aa16d2f8 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -41,6 +41,35 @@ def common_cm_kwargs(config_path, schema_path): "blocked_providers": None, "allowed_models": None, "blocked_models": None, + "restrictions": {"allowed_providers": None, "blocked_providers": None}, + "defaults": { + "model_provider_id": None, + "embeddings_provider_id": None, + "api_keys": None, + "fields": None, + }, + } + + +@pytest.fixture +def cm_kargs_with_defaults(config_path, schema_path, common_cm_kwargs): + """Kwargs that are commonly used when initializing the CM.""" + log = logging.getLogger() + lm_providers = get_lm_providers() + em_providers = get_em_providers() + return { + **common_cm_kwargs, + "defaults": { + "model_provider_id": "bedrock-chat:anthropic.claude-v1", + "embeddings_provider_id": "bedrock:amazon.titan-embed-text-v1", + "api_keys": {"OPENAI_API_KEY": "open-ai-key-value"}, + "fields": { + "bedrock-chat:anthropic.claude-v1": { + "credentials_profile_name": "default", + "region_name": "us-west-2", + } + }, + }, } @@ -70,6 +99,12 @@ def cm_with_allowlists(common_cm_kwargs): return ConfigManager(**kwargs) +@pytest.fixture +def cm_with_defaults(cm_kargs_with_defaults): + """The default ConfigManager instance, with an empty config and config schema.""" + return ConfigManager(**cm_kargs_with_defaults) + + @pytest.fixture(autouse=True) def reset(config_path, schema_path): """Fixture that deletes the config and config schema after each test.""" @@ -184,6 +219,51 @@ def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs): assert test_cm.em_gid == None +def test_init_with_default_values( + cm_with_defaults: ConfigManager, + config_path: str, + schema_path: str, + common_cm_kwargs, +): + """ + Test that the ConfigManager initializes with the expected default values. + + Args: + cm_with_defaults (ConfigManager): A ConfigManager instance with default values. + config_path (str): The path to the configuration file. + schema_path (str): The path to the schema file. + """ + config_response = cm_with_defaults.get_config() + # assert config response + assert config_response.model_provider_id == "bedrock-chat:anthropic.claude-v1" + assert ( + config_response.embeddings_provider_id == "bedrock:amazon.titan-embed-text-v1" + ) + assert config_response.api_keys == ["OPENAI_API_KEY"] + assert config_response.fields == { + "bedrock-chat:anthropic.claude-v1": { + "credentials_profile_name": "default", + "region_name": "us-west-2", + } + } + + del cm_with_defaults + + log = logging.getLogger() + lm_providers = get_lm_providers() + em_providers = get_em_providers() + kwargs = { + **common_cm_kwargs, + "defaults": {"model_provider_id": "bedrock-chat:anthropic.claude-v2"}, + } + cm_with_defaults_override = ConfigManager(**kwargs) + + assert ( + cm_with_defaults_override.get_config().model_provider_id + == "bedrock-chat:anthropic.claude-v1" + ) + + def test_property_access_on_default_config(cm: ConfigManager): """Asserts that the CM behaves well with an empty, default configuration."""