diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 608bea77f..71ca3f185 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -3,9 +3,9 @@ import os import shutil import time +from copy import deepcopy from typing import List, Optional, Type, Union -from copy import deepcopy from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index c807cf91b..168c675f7 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -24,6 +24,7 @@ def config_path(jp_data_dir): def schema_path(jp_data_dir): return str(jp_data_dir / "config_schema.json") + @pytest.fixture def config_file_with_model_fields(jp_data_dir): """ @@ -33,21 +34,16 @@ def config_file_with_model_fields(jp_data_dir): config_data = { "model_provider_id:": "openai-chat:gpt-4o", "embeddings_provider_id": None, - "api_keys": { - "openai_api_key": "foobar" - }, + "api_keys": {"openai_api_key": "foobar"}, "send_with_shift_enter": False, - "fields": { - "openai-chat:gpt-4o": { - "openai_api_base": "https://example.com" - } - }, + "fields": {"openai-chat:gpt-4o": {"openai_api_base": "https://example.com"}}, } config_path = jp_data_dir / "config.json" with open(config_path, "w") as file: json.dump(config_data, file) return str(config_path) + @pytest.fixture def common_cm_kwargs(config_path, schema_path): """Kwargs that are commonly used when initializing the CM.""" @@ -197,6 +193,7 @@ def configure_to_openai(cm: ConfigManager): cm.update_config(req) return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS + def configure_with_fields(cm: ConfigManager): """ Configures the ConfigManager with fields and API keys. @@ -204,22 +201,21 @@ def configure_with_fields(cm: ConfigManager): """ req = UpdateConfigRequest( model_provider_id="openai-chat:gpt-4o", - api_keys={ - "OPENAI_API_KEY": "foobar" - }, + api_keys={"OPENAI_API_KEY": "foobar"}, fields={ "openai-chat:gpt-4o": { "openai_api_base": "https://example.com", } - } + }, ) cm.update_config(req) return { "model_id": "gpt-4o", "openai_api_key": "foobar", - "openai_api_base": "https://example.com" + "openai_api_base": "https://example.com", } + def test_snapshot_default_config(cm: ConfigManager, snapshot): config_from_cm: DescribeConfigResponse = cm.get_config() assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read") @@ -448,6 +444,7 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids): assert config_desc.model_provider_id is None assert config_desc.embeddings_provider_id is None + def test_config_manager_returns_fields(cm): """ Asserts that `ConfigManager.lm_provider_params` returns model fields set by @@ -456,12 +453,16 @@ def test_config_manager_returns_fields(cm): expected_model_args = configure_with_fields(cm) assert cm.lm_provider_params == expected_model_args -def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields, schema_path): + +def test_config_manager_does_not_write_to_defaults( + config_file_with_model_fields, schema_path +): """ Asserts that `ConfigManager` does not write to the `defaults` argument when the configured chat model differs from the one specified in `defaults`. """ from copy import deepcopy + config_path = config_file_with_model_fields log = logging.getLogger() lm_providers = get_lm_providers() @@ -471,7 +472,7 @@ def test_config_manager_does_not_write_to_defaults(config_file_with_model_fields "model_provider_id": None, "embeddings_provider_id": None, "api_keys": {}, - "fields": {} + "fields": {}, } expected_defaults = deepcopy(defaults)