diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 5ef4e4ed6..1a58ba247 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -705,3 +705,34 @@ The `--region-name` parameter is set to the [AWS region code](https://docs.aws.a The `--request-schema` parameter is the JSON object the endpoint expects as input, with the prompt being substituted into any value that matches the string literal `""`. For example, the request schema `{"text_inputs":""}` will submit a JSON object with the prompt stored under the `text_inputs` key. The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonPath/index.html) string that retrieves the language model's output from the endpoint's JSON response. For example, if your endpoint returns an object with the schema `{"generated_texts":[""]}`, its response path is `generated_texts.[0]`. + + +## Configuration + +You can specify an allowlist, to only allow only a certain list of providers, or a blocklist, to block some providers. + +### Blocklisting providers +This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section. + +``` +jupyter lab --AiExtension.blocked_providers=openai +``` + +To block more than one provider in the block-list, repeat the runtime configuration. + +``` +jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21 +``` + +### Allowlisting providers +This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers. + +``` +jupyter lab --AiExtension.allowed_providers=openai +``` + +To allow more than one provider in the allowlist, repeat the runtime configuration. + +``` +jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21 +``` diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py new file mode 100644 index 000000000..e1c517ebe --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import pytest +from jupyter_ai_magics.utils import get_lm_providers + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": None, "blocked_providers": None}, + {"allowed_providers": [], "blocked_providers": []}, + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]}, + {"allowed_providers": [KNOWN_LM_A], "blocked_providers": []}, + ], +) +def test_get_lm_providers_not_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A in a_not_restricted + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]}, + {"allowed_providers": [KNOWN_LM_B], "blocked_providers": []}, + ], +) +def test_get_lm_providers_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A not in a_not_restricted diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 6a02c61c8..c651581bc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Tuple, Type, Union from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES @@ -11,13 +11,19 @@ EmProvidersDict = Dict[str, BaseEmbeddingsProvider] AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider] ProviderDict = Dict[str, AnyProvider] +ProviderRestrictions = Dict[ + Literal["allowed_providers", "blocked_providers"], Optional[List[str]] +] -def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: +def get_lm_providers( + log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None +) -> LmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) - + if not restrictions: + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -29,6 +35,9 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: f"Unable to load model provider class from entry point `{model_provider_ep.name}`." ) continue + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered model provider `{provider.id}`.") @@ -36,11 +45,13 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: def get_em_providers( - log: Optional[Logger] = None, + log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None ) -> EmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) + if not restrictions: + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") @@ -52,6 +63,9 @@ def get_em_providers( f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`." ) continue + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered embeddings model provider `{provider.id}`.") @@ -97,6 +111,16 @@ def get_em_provider( return _get_provider(model_id, em_providers) +def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool: + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] + if blocked and provider_id in blocked: + return False + if allowed and provider_id not in allowed: + return False + return True + + def _get_provider(model_id: str, providers: ProviderDict): provider_id, local_model_id = decompose_model_id(model_id, providers) provider = providers.get(provider_id, None) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index f61e07bde..8708b4b94 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -12,8 +12,10 @@ AnyProvider, EmProvidersDict, LmProvidersDict, + ProviderRestrictions, get_em_provider, get_lm_provider, + is_provider_allowed, ) from jupyter_core.paths import jupyter_data_dir from traitlets import Integer, Unicode @@ -97,6 +99,7 @@ def __init__( log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict, + restrictions: ProviderRestrictions, *args, **kwargs, ): @@ -106,6 +109,8 @@ def __init__( self._lm_providers = lm_providers """List of EM providers.""" self._em_providers = em_providers + """Provider restrictions.""" + self._restrictions = restrictions """When the server last read the config file. If the file was not modified after this time, then we can return the cached @@ -176,6 +181,10 @@ def _validate_config(self, config: GlobalConfig): _, lm_provider = get_lm_provider( config.model_provider_id, self._lm_providers ) + # do not check config for blocked providers + if not is_provider_allowed(config.model_provider_id, self._restrictions): + assert not lm_provider + return if not lm_provider: raise ValueError( f"No language model is associated with '{config.model_provider_id}'." @@ -187,6 +196,12 @@ def _validate_config(self, config: GlobalConfig): _, em_provider = get_em_provider( config.embeddings_provider_id, self._em_providers ) + # do not check config for blocked providers + if not is_provider_allowed( + config.embeddings_provider_id, self._restrictions + ): + assert not em_provider + return if not em_provider: raise ValueError( f"No embedding model is associated with '{config.embeddings_provider_id}'." diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 7b6d07b31..50865ed96 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -4,6 +4,7 @@ from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp +from traitlets import List, Unicode from .chat_handlers import ( AskChatHandler, @@ -36,11 +37,35 @@ class AiExtension(ExtensionApp): (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), ] + allowed_providers = List( + Unicode(), + default_value=None, + help="Identifiers of allow-listed providers. If `None`, all are allowed.", + allow_none=True, + config=True, + ) + + blocked_providers = List( + Unicode(), + default_value=None, + help="Identifiers of block-listed providers. If `None`, none are blocked.", + allow_none=True, + config=True, + ) + def initialize_settings(self): start = time.time() + restrictions = { + "allowed_providers": self.allowed_providers, + "blocked_providers": self.blocked_providers, + } - self.settings["lm_providers"] = get_lm_providers(log=self.log) - self.settings["em_providers"] = get_em_providers(log=self.log) + self.settings["lm_providers"] = get_lm_providers( + log=self.log, restrictions=restrictions + ) + self.settings["em_providers"] = get_em_providers( + log=self.log, restrictions=restrictions + ) self.settings["jai_config_manager"] = ConfigManager( # traitlets configuration, not JAI configuration. @@ -48,6 +73,7 @@ def initialize_settings(self): log=self.log, lm_providers=self.settings["lm_providers"], em_providers=self.settings["em_providers"], + restrictions=restrictions, ) self.log.info("Registered providers.") diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index f7afbd29a..8cb1808fe 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -36,6 +36,7 @@ def common_cm_kwargs(config_path, schema_path): "em_providers": em_providers, "config_path": config_path, "schema_path": schema_path, + "restrictions": {"allowed_providers": None, "blocked_providers": None}, } @@ -112,6 +113,7 @@ def test_init_with_existing_config( em_providers=em_providers, config_path=config_path, schema_path=schema_path, + restrictions={"allowed_providers": None, "blocked_providers": None}, ) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py new file mode 100644 index 000000000..d1a10df77 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -0,0 +1,39 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. +import pytest +from jupyter_ai.extension import AiExtension + +pytest_plugins = ["pytest_jupyter.jupyter_server"] + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_B], + ["--AiExtension.allowed_providers", KNOWN_LM_A], + ], +) +def test_allows_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A in ai.settings["lm_providers"] + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_A], + ["--AiExtension.allowed_providers", KNOWN_LM_B], + ], +) +def test_blocks_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A not in ai.settings["lm_providers"] diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 993070c04..291775703 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -51,6 +51,7 @@ test = [ "pytest-asyncio", "pytest-cov", "pytest-tornasync", + "pytest-jupyter", "syrupy~=4.0.8" ]