diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py
index 8708b4b94..5adc604cf 100644
--- a/packages/jupyter-ai/jupyter_ai/config_manager.py
+++ b/packages/jupyter-ai/jupyter_ai/config_manager.py
@@ -3,7 +3,7 @@
 import os
 import shutil
 import time
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 from deepmerge import always_merger as Merger
 from jsonschema import Draft202012Validator as Validator
@@ -12,10 +12,8 @@
     AnyProvider,
     EmProvidersDict,
     LmProvidersDict,
-    ProviderRestrictions,
     get_em_provider,
     get_lm_provider,
-    is_provider_allowed,
 )
 from jupyter_core.paths import jupyter_data_dir
 from traitlets import Integer, Unicode
@@ -57,6 +55,10 @@ class KeyEmptyError(Exception):
     pass
 
 
+class BlockedModelError(Exception):
+    pass
+
+
 def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
     # TODO: handle non-env auth strategies
     if not provider.auth_strategy or provider.auth_strategy.type != "env":
@@ -99,27 +101,34 @@ def __init__(
         log: Logger,
         lm_providers: LmProvidersDict,
         em_providers: EmProvidersDict,
-        restrictions: ProviderRestrictions,
+        allowed_providers: Optional[List[str]],
+        blocked_providers: Optional[List[str]],
+        allowed_models: Optional[List[str]],
+        blocked_models: Optional[List[str]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.log = log
-        """List of LM providers."""
+
         self._lm_providers = lm_providers
-        """List of EM providers."""
+        """List of LM providers."""
         self._em_providers = em_providers
-        """Provider restrictions."""
-        self._restrictions = restrictions
+        """List of EM providers."""
+
+        self._allowed_providers = allowed_providers
+        self._blocked_providers = blocked_providers
+        self._allowed_models = allowed_models
+        self._blocked_models = blocked_models
 
+        self._last_read: Optional[int] = None
         """When the server last read the config file. If the file was not
         modified after this time, then we can return the cached
         `self._config`."""
-        self._last_read: Optional[int] = None
 
+        self._config: Optional[GlobalConfig] = None
         """In-memory cache of the `GlobalConfig` object parsed from the config
         file."""
-        self._config: Optional[GlobalConfig] = None
 
         self._init_config_schema()
         self._init_validator()
@@ -140,6 +149,26 @@ def _init_config(self):
         if os.path.exists(self.config_path):
             with open(self.config_path, encoding="utf-8") as f:
                 config = GlobalConfig(**json.loads(f.read()))
+                lm_id = config.model_provider_id
+                em_id = config.embeddings_provider_id
+
+                # if the currently selected language or embedding model are
+                # forbidden, set them to `None` and log a warning.
+                if lm_id is not None and not self._validate_model(
+                    lm_id, raise_exc=False
+                ):
+                    self.log.warning(
+                        f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
+                    )
+                    config.model_provider_id = None
+                if em_id is not None and not self._validate_model(
+                    em_id, raise_exc=False
+                ):
+                    self.log.warning(
+                        f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
+                    )
+                    config.embeddings_provider_id = None
+
                 # re-write to the file to validate the config and apply any
                 # updates to the config file immediately
                 self._write_config(config)
@@ -181,14 +210,17 @@ def _validate_config(self, config: GlobalConfig):
             _, lm_provider = get_lm_provider(
                 config.model_provider_id, self._lm_providers
             )
-            # do not check config for blocked providers
-            if not is_provider_allowed(config.model_provider_id, self._restrictions):
-                assert not lm_provider
-                return
+
+            # verify model is declared by some provider
             if not lm_provider:
                 raise ValueError(
                     f"No language model is associated with '{config.model_provider_id}'."
                 )
+
+            # verify model is not blocked
+            self._validate_model(config.model_provider_id)
+
+            # verify model is authenticated
             _validate_provider_authn(config, lm_provider)
 
         # validate embedding model config
@@ -196,18 +228,56 @@ def _validate_config(self, config: GlobalConfig):
             _, em_provider = get_em_provider(
                 config.embeddings_provider_id, self._em_providers
             )
-            # do not check config for blocked providers
-            if not is_provider_allowed(
-                config.embeddings_provider_id, self._restrictions
-            ):
-                assert not em_provider
-                return
+
+            # verify model is declared by some provider
             if not em_provider:
                 raise ValueError(
                     f"No embedding model is associated with '{config.embeddings_provider_id}'."
                 )
+
+            # verify model is not blocked
+            self._validate_model(config.embeddings_provider_id)
+
+            # verify model is authenticated
             _validate_provider_authn(config, em_provider)
 
+    def _validate_model(self, model_id: str, raise_exc=True):
+        """
+        Validates a model against the allow/blocklists specified by the
+        traitlets configuration. Returns `True` if the model is allowed;
+        otherwise raises `BlockedModelError`, or returns `False` when
+        `raise_exc=False`. Raises `ValueError` if `model_id` is malformed.
+        """
+
+        components = model_id.split(":", 1) if model_id else []
+        if len(components) != 2:
+            raise ValueError(f"Invalid global model ID: {model_id!r}")
+        provider_id, _ = components
+
+        try:
+            if self._allowed_providers and provider_id not in self._allowed_providers:
+                raise BlockedModelError(
+                    "Model provider not included in the provider allowlist."
+                )
+
+            if self._blocked_providers and provider_id in self._blocked_providers:
+                raise BlockedModelError(
+                    "Model provider included in the provider blocklist."
+                )
+
+            if self._allowed_models and model_id not in self._allowed_models:
+                raise BlockedModelError("Model not included in the model allowlist.")
+
+            if self._blocked_models and model_id in self._blocked_models:
+                raise BlockedModelError("Model included in the model blocklist.")
+        except BlockedModelError as e:
+            if raise_exc:
+                raise
+            else:
+                return False
+
+        return True
+
     def _write_config(self, new_config: GlobalConfig):
         """Updates configuration and persists it to disk. This accepts a
         complete `GlobalConfig` object, and should not be called publicly."""
diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py
index b92eda686..a2ecd5245 100644
--- a/packages/jupyter-ai/jupyter_ai/extension.py
+++ b/packages/jupyter-ai/jupyter_ai/extension.py
@@ -56,7 +56,15 @@ class AiExtension(ExtensionApp):
     allowed_models = List(
         Unicode(),
         default_value=None,
-        help="Language models to allow, as a list of global model IDs in the format `<provider>:<model>`. If `None`, all are allowed. Defaults to `None`.",
+        help="""
+        Language models to allow, as a list of global model IDs in the format
+        `<provider>:<model>`. If `None`, all are allowed. Defaults to
+        `None`.
+
+        Note: Currently, if `allowed_providers` is also set, then this field is
+        ignored. This is subject to change in a future non-major release. Using
+        both traits is considered to be undefined behavior at this time.
+        """,
         allow_none=True,
         config=True,
     )
@@ -64,7 +72,11 @@ class AiExtension(ExtensionApp):
     blocked_models = List(
         Unicode(),
         default_value=None,
-        help="Language models to block, as a list of global model IDs in the format `<provider>:<model>`. If `None`, none are blocked. Defaults to `None`.",
+        help="""
+        Language models to block, as a list of global model IDs in the format
+        `<provider>:<model>`. If `None`, none are blocked. Defaults to
+        `None`.
+        """,
         allow_none=True,
         config=True,
     )
@@ -98,7 +110,10 @@ def initialize_settings(self):
             log=self.log,
             lm_providers=self.settings["lm_providers"],
             em_providers=self.settings["em_providers"],
-            restrictions=restrictions,
+            allowed_providers=self.allowed_providers,
+            blocked_providers=self.blocked_providers,
+            allowed_models=self.allowed_models,
+            blocked_models=self.blocked_models,
         )
 
         self.log.info("Registered providers.")
diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
index 8cb1808fe..dafba7adf 100644
--- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
+++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
@@ -36,7 +36,10 @@ def common_cm_kwargs(config_path, schema_path):
         "em_providers": em_providers,
         "config_path": config_path,
         "schema_path": schema_path,
-        "restrictions": {"allowed_providers": None, "blocked_providers": None},
+        "allowed_providers": None,
+        "blocked_providers": None,
+        "allowed_models": None,
+        "blocked_models": None,
     }
 
 
@@ -46,6 +49,26 @@ def cm(common_cm_kwargs):
     return ConfigManager(**common_cm_kwargs)
 
 
+@pytest.fixture
+def cm_with_blocklists(common_cm_kwargs):
+    kwargs = {
+        **common_cm_kwargs,
+        "blocked_providers": ["ai21"],
+        "blocked_models": ["cohere:medium"],
+    }
+    return ConfigManager(**kwargs)
+
+
+@pytest.fixture
+def cm_with_allowlists(common_cm_kwargs):
+    kwargs = {
+        **common_cm_kwargs,
+        "allowed_providers": ["ai21"],
+        "allowed_models": ["cohere:medium"],
+    }
+    return ConfigManager(**kwargs)
+
+
 @pytest.fixture(autouse=True)
 def reset(config_path, schema_path):
     """Fixture that deletes the config and config schema after each test."""
@@ -98,23 +121,43 @@ def test_snapshot_default_config(cm: ConfigManager, snapshot):
     assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read")
 
 
-def test_init_with_existing_config(
-    cm: ConfigManager, config_path: str, schema_path: str
-):
+def test_init_with_existing_config(cm: ConfigManager, common_cm_kwargs):
     configure_to_cohere(cm)
     del cm
 
-    log = logging.getLogger()
-    lm_providers = get_lm_providers()
-    em_providers = get_em_providers()
-    ConfigManager(
-        log=log,
-        lm_providers=lm_providers,
-        em_providers=em_providers,
-        config_path=config_path,
-        schema_path=schema_path,
-        restrictions={"allowed_providers": None, "blocked_providers": None},
-    )
+    ConfigManager(**common_cm_kwargs)
+
+
+def test_init_with_blocklists(cm: ConfigManager, common_cm_kwargs):
+    configure_to_openai(cm)
+    del cm
+
+    blocked_providers = ["openai"]  # blocks EM
+    blocked_models = ["openai-chat-new:gpt-3.5-turbo"]  # blocks LM
+    kwargs = {
+        **common_cm_kwargs,
+        "blocked_providers": blocked_providers,
+        "blocked_models": blocked_models,
+    }
+    test_cm = ConfigManager(**kwargs)
+    assert test_cm._blocked_providers == blocked_providers
+    assert test_cm._blocked_models == blocked_models
+    assert test_cm.lm_gid is None
+    assert test_cm.em_gid is None
+
+
+def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs):
+    configure_to_cohere(cm)
+    del cm
+
+    allowed_providers = ["openai"]  # blocks both LM & EM
+
+    kwargs = {**common_cm_kwargs, "allowed_providers": allowed_providers}
+    test_cm = ConfigManager(**kwargs)
+    assert test_cm._allowed_providers == allowed_providers
+    assert test_cm._allowed_models is None
+    assert test_cm.lm_gid is None
+    assert test_cm.em_gid is None
 
 
 def test_property_access_on_default_config(cm: ConfigManager):