From 3c8245eb61b7e05d9a96a3a2109167cd63a47699 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:25:42 +0100 Subject: [PATCH 1/7] Allow to block or allow-list providers by id --- .../jupyter_ai_magics/utils.py | 41 +++++++++++++++++-- packages/jupyter-ai/jupyter_ai/extension.py | 29 ++++++++++++- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 6a02c61c8..6a21eb763 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, Optional, Tuple, Type, Union, Literal, List from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES @@ -13,11 +13,26 @@ ProviderDict = Dict[str, AnyProvider] -def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: +ProviderRestrictions = Dict[ + Literal["allowed_providers", "blocked_providers"], + Optional[List[str]] +] + + +def get_lm_providers( + log: Optional[Logger] = None, + restrictions: Optional[ProviderRestrictions] = None +) -> LmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) - + if not restrictions: + restrictions = { + "allowed_providers": None, + "blocked_providers": None + } + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -29,6 +44,12 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: f"Unable to load model provider class from entry point `{model_provider_ep.name}`." 
) continue + if blocked and provider.id in blocked: + log.info(f"Skipping provider not on block-list `{provider.id}`.") + continue + if allowed and provider.id not in allowed: + log.info(f"Skipping provider not on allow-list `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered model provider `{provider.id}`.") @@ -37,10 +58,18 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict: def get_em_providers( log: Optional[Logger] = None, + restrictions: Optional[ProviderRestrictions] = None ) -> EmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) + if not restrictions: + restrictions = { + "allowed_providers": None, + "blocked_providers": None + } + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") @@ -52,6 +81,12 @@ def get_em_providers( f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`." 
) continue + if blocked and provider.id in blocked: + log.info(f"Skipping provider not on block-list `{provider.id}`.") + continue + if allowed and provider.id not in allowed: + log.info(f"Skipping provider not on allow-list `{provider.id}`.") + continue providers[provider.id] = provider log.info(f"Registered embeddings model provider `{provider.id}`.") diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 7b6d07b31..71e2f3978 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -4,6 +4,7 @@ from jupyter_ai.chat_handlers.learn import Retriever from jupyter_ai_magics.utils import get_em_providers, get_lm_providers from jupyter_server.extension.application import ExtensionApp +from traitlets import List, Unicode from .chat_handlers import ( AskChatHandler, @@ -36,11 +37,35 @@ class AiExtension(ExtensionApp): (r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler), ] + allowed_providers = List( + Unicode(), + default_value=None, + help="Identifiers of allow-listed providers. If `None`, all are allowed.", + allow_none=True, + config=True, + ) + + blocked_providers = List( + Unicode(), + default_value=None, + help="Identifiers of block-listed providers. If `None`, none are blocked.", + allow_none=True, + config=True, + ) + def initialize_settings(self): start = time.time() + restrictions = { + "allowed_providers": self.allowed_providers, + "blocked_providers": self.blocked_providers, + } - self.settings["lm_providers"] = get_lm_providers(log=self.log) - self.settings["em_providers"] = get_em_providers(log=self.log) + self.settings["lm_providers"] = get_lm_providers( + log=self.log, restrictions=restrictions + ) + self.settings["em_providers"] = get_em_providers( + log=self.log, restrictions=restrictions + ) self.settings["jai_config_manager"] = ConfigManager( # traitlets configuration, not JAI configuration. 
From ac122de8f90e3048063ff7ea324a78a666199a78 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Sat, 21 Oct 2023 14:56:17 +0100 Subject: [PATCH 2/7] Add tests for block/allow-lists --- .../jupyter_ai_magics/tests/test_utils.py | 34 ++++++++++++++++ .../jupyter_ai/tests/test_extension.py | 39 +++++++++++++++++++ packages/jupyter-ai/pyproject.toml | 1 + 3 files changed, 74 insertions(+) create mode 100644 packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py create mode 100644 packages/jupyter-ai/jupyter_ai/tests/test_extension.py diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py new file mode 100644 index 000000000..e1c517ebe --- /dev/null +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py @@ -0,0 +1,34 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. + +import pytest +from jupyter_ai_magics.utils import get_lm_providers + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": None, "blocked_providers": None}, + {"allowed_providers": [], "blocked_providers": []}, + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]}, + {"allowed_providers": [KNOWN_LM_A], "blocked_providers": []}, + ], +) +def test_get_lm_providers_not_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A in a_not_restricted + + +@pytest.mark.parametrize( + "restrictions", + [ + {"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]}, + {"allowed_providers": [KNOWN_LM_B], "blocked_providers": []}, + ], +) +def test_get_lm_providers_restricted(restrictions): + a_not_restricted = get_lm_providers(None, restrictions) + assert KNOWN_LM_A not in a_not_restricted diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_extension.py 
b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py new file mode 100644 index 000000000..d1a10df77 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tests/test_extension.py @@ -0,0 +1,39 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. +import pytest +from jupyter_ai.extension import AiExtension + +pytest_plugins = ["pytest_jupyter.jupyter_server"] + +KNOWN_LM_A = "openai" +KNOWN_LM_B = "huggingface_hub" + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_B], + ["--AiExtension.allowed_providers", KNOWN_LM_A], + ], +) +def test_allows_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A in ai.settings["lm_providers"] + + +@pytest.mark.parametrize( + "argv", + [ + ["--AiExtension.blocked_providers", KNOWN_LM_A], + ["--AiExtension.allowed_providers", KNOWN_LM_B], + ], +) +def test_blocks_providers(argv, jp_configurable_serverapp): + server = jp_configurable_serverapp(argv=argv) + ai = AiExtension() + ai._link_jupyter_server_extension(server) + ai.initialize_settings() + assert KNOWN_LM_A not in ai.settings["lm_providers"] diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index 2e56e0b07..99dd13ddd 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -51,6 +51,7 @@ test = [ "pytest-asyncio", "pytest-cov", "pytest-tornasync", + "pytest-jupyter", "syrupy~=4.0.8" ] From 874f4775f9efbaddf8000ddfdccf8e12474e71c2 Mon Sep 17 00:00:00 2001 From: krassowski <5832902+krassowski@users.noreply.github.com> Date: Sat, 21 Oct 2023 15:10:29 +0100 Subject: [PATCH 3/7] Fix "No language model is associated with" issue This was appearing because the models which are blocked were not returned (correctly!) 
but the previous validation logic did not know that sometimes models may be missing for a valid reason even if there are existing settings for these. --- .../jupyter_ai_magics/utils.py | 51 ++++++++----------- .../jupyter-ai/jupyter_ai/config_manager.py | 15 ++++++ packages/jupyter-ai/jupyter_ai/extension.py | 1 + .../jupyter_ai/tests/test_config_manager.py | 2 + 4 files changed, 38 insertions(+), 31 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py index 6a21eb763..c651581bc 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Tuple, Type, Union, Literal, List +from typing import Dict, List, Literal, Optional, Tuple, Type, Union from importlib_metadata import entry_points from jupyter_ai_magics.aliases import MODEL_ID_ALIASES @@ -11,28 +11,19 @@ EmProvidersDict = Dict[str, BaseEmbeddingsProvider] AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider] ProviderDict = Dict[str, AnyProvider] - - ProviderRestrictions = Dict[ - Literal["allowed_providers", "blocked_providers"], - Optional[List[str]] + Literal["allowed_providers", "blocked_providers"], Optional[List[str]] ] def get_lm_providers( - log: Optional[Logger] = None, - restrictions: Optional[ProviderRestrictions] = None + log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None ) -> LmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) if not restrictions: - restrictions = { - "allowed_providers": None, - "blocked_providers": None - } - allowed = restrictions["allowed_providers"] - blocked = restrictions["blocked_providers"] + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.model_providers") @@ -44,11 +35,8 @@ def 
get_lm_providers( f"Unable to load model provider class from entry point `{model_provider_ep.name}`." ) continue - if blocked and provider.id in blocked: - log.info(f"Skipping provider not on block-list `{provider.id}`.") - continue - if allowed and provider.id not in allowed: - log.info(f"Skipping provider not on allow-list `{provider.id}`.") + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") continue providers[provider.id] = provider log.info(f"Registered model provider `{provider.id}`.") @@ -57,19 +45,13 @@ def get_lm_providers( def get_em_providers( - log: Optional[Logger] = None, - restrictions: Optional[ProviderRestrictions] = None + log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None ) -> EmProvidersDict: if not log: log = logging.getLogger() log.addHandler(logging.NullHandler()) if not restrictions: - restrictions = { - "allowed_providers": None, - "blocked_providers": None - } - allowed = restrictions["allowed_providers"] - blocked = restrictions["blocked_providers"] + restrictions = {"allowed_providers": None, "blocked_providers": None} providers = {} eps = entry_points() model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers") @@ -81,11 +63,8 @@ def get_em_providers( f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`." 
) continue - if blocked and provider.id in blocked: - log.info(f"Skipping provider not on block-list `{provider.id}`.") - continue - if allowed and provider.id not in allowed: - log.info(f"Skipping provider not on allow-list `{provider.id}`.") + if not is_provider_allowed(provider.id, restrictions): + log.info(f"Skipping blocked provider `{provider.id}`.") continue providers[provider.id] = provider log.info(f"Registered embeddings model provider `{provider.id}`.") @@ -132,6 +111,16 @@ def get_em_provider( return _get_provider(model_id, em_providers) +def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool: + allowed = restrictions["allowed_providers"] + blocked = restrictions["blocked_providers"] + if blocked and provider_id in blocked: + return False + if allowed and provider_id not in allowed: + return False + return True + + def _get_provider(model_id: str, providers: ProviderDict): provider_id, local_model_id = decompose_model_id(model_id, providers) provider = providers.get(provider_id, None) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index f61e07bde..8708b4b94 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -12,8 +12,10 @@ AnyProvider, EmProvidersDict, LmProvidersDict, + ProviderRestrictions, get_em_provider, get_lm_provider, + is_provider_allowed, ) from jupyter_core.paths import jupyter_data_dir from traitlets import Integer, Unicode @@ -97,6 +99,7 @@ def __init__( log: Logger, lm_providers: LmProvidersDict, em_providers: EmProvidersDict, + restrictions: ProviderRestrictions, *args, **kwargs, ): @@ -106,6 +109,8 @@ def __init__( self._lm_providers = lm_providers """List of EM providers.""" self._em_providers = em_providers + """Provider restrictions.""" + self._restrictions = restrictions """When the server last read the config file. 
If the file was not modified after this time, then we can return the cached @@ -176,6 +181,10 @@ def _validate_config(self, config: GlobalConfig): _, lm_provider = get_lm_provider( config.model_provider_id, self._lm_providers ) + # do not check config for blocked providers + if not is_provider_allowed(config.model_provider_id, self._restrictions): + assert not lm_provider + return if not lm_provider: raise ValueError( f"No language model is associated with '{config.model_provider_id}'." @@ -187,6 +196,12 @@ def _validate_config(self, config: GlobalConfig): _, em_provider = get_em_provider( config.embeddings_provider_id, self._em_providers ) + # do not check config for blocked providers + if not is_provider_allowed( + config.embeddings_provider_id, self._restrictions + ): + assert not em_provider + return if not em_provider: raise ValueError( f"No embedding model is associated with '{config.embeddings_provider_id}'." diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index 71e2f3978..50865ed96 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -73,6 +73,7 @@ def initialize_settings(self): log=self.log, lm_providers=self.settings["lm_providers"], em_providers=self.settings["em_providers"], + restrictions=restrictions, ) self.log.info("Registered providers.") diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index f7afbd29a..8cb1808fe 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -36,6 +36,7 @@ def common_cm_kwargs(config_path, schema_path): "em_providers": em_providers, "config_path": config_path, "schema_path": schema_path, + "restrictions": {"allowed_providers": None, "blocked_providers": None}, } @@ -112,6 +113,7 @@ def test_init_with_existing_config( em_providers=em_providers, 
config_path=config_path, schema_path=schema_path, + restrictions={"allowed_providers": None, "blocked_providers": None}, ) From 8293cb54b8af13a45c775dea11a8b23a9a74223e Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 23 Oct 2023 11:22:57 -0700 Subject: [PATCH 4/7] Add docs for allow listing and block listing providers --- docs/source/users/index.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 28f73193d..f0af39cda 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -748,3 +748,32 @@ The `--region-name` parameter is set to the [AWS region code](https://docs.aws.a The `--request-schema` parameter is the JSON object the endpoint expects as input, with the prompt being substituted into any value that matches the string literal `""`. For example, the request schema `{"text_inputs":""}` will submit a JSON object with the prompt stored under the `text_inputs` key. The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonPath/index.html) string that retrieves the language model's output from the endpoint's JSON response. For example, if your endpoint returns an object with the schema `{"generated_texts":[""]}`, its response path is `generated_texts.[0]`. + + +## Configuration + +### Block-listing providers +This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section. + +``` +jupyter lab --Ai.blocked_providers=openai +``` + +To block more than one provider in the block-list, repeat the runtime configuration. + +``` +jupyter lab --Ai.blocked_providers=openai --Ai.blocked_providers=ai21 +``` + +### Allow-listing providers +This configuration allows for filtering the list of providers in the settings panel to only an allow-listed set of providers. 
+ +``` +jupyter lab --Ai.allowed_providers=openai +``` + +To allow more than one provider in the allow-list, repeat the runtime configuration. + +``` +jupyter lab --Ai.allowed_providers=openai --Ai.allowed_providers=ai21 +``` From 0c06599ca2b44171e06c2078cf461be13979b768 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 23 Oct 2023 12:04:14 -0700 Subject: [PATCH 5/7] Updated docs --- docs/source/users/index.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index f0af39cda..e1fc02ecd 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -752,7 +752,7 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP ## Configuration -### Block-listing providers +### Blocklisting providers This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section. ``` @@ -765,14 +765,14 @@ To block more than one provider in the block-list, repeat the runtime configurat jupyter lab --Ai.blocked_providers=openai --Ai.blocked_providers=ai21 ``` -### Allow-listing providers -This configuration allows for filtering the list of providers in the settings panel to only an allow-listed set of providers. +### Allowlisting providers +This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers. ``` jupyter lab --Ai.allowed_providers=openai ``` -To allow more than one provider in the allow-list, repeat the runtime configuration. +To allow more than one provider in the allowlist, repeat the runtime configuration. 
``` jupyter lab --Ai.allowed_providers=openai --Ai.allowed_providers=ai21 From 7270751edc05cbbb7841a9ceaa93c1fea50db014 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 23 Oct 2023 12:17:01 -0700 Subject: [PATCH 6/7] Added an intro block to docs --- docs/source/users/index.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index e1fc02ecd..bd52bb1a6 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -752,6 +752,8 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP ## Configuration +You can specify an allowlist, to only allow a certain list of providers, or a blocklist, to block some providers. + ### Blocklisting providers This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section. From 53d0d9eef20d6eb9bdcef5c2615fc14c42e6e9ba Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 23 Oct 2023 12:28:25 -0700 Subject: [PATCH 7/7] Updated the docs --- docs/source/users/index.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/users/index.md b/docs/source/users/index.md index bd52bb1a6..72b826a67 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -752,30 +752,30 @@ The `--response-path` option is a [JSONPath](https://goessner.net/articles/JsonP ## Configuration -You can specify an allowlist, to only allow a certain list of providers, or a blocklist, to block some providers. +You can specify an allowlist, to allow only a certain list of providers, or a blocklist, to block some providers. ### Blocklisting providers This configuration allows for blocking specific providers in the settings panel. This list takes precedence over the allowlist in the next section. 
``` -jupyter lab --Ai.blocked_providers=openai +jupyter lab --AiExtension.blocked_providers=openai ``` To block more than one provider in the block-list, repeat the runtime configuration. ``` -jupyter lab --Ai.blocked_providers=openai --Ai.blocked_providers=ai21 +jupyter lab --AiExtension.blocked_providers=openai --AiExtension.blocked_providers=ai21 ``` ### Allowlisting providers This configuration allows for filtering the list of providers in the settings panel to only an allowlisted set of providers. ``` -jupyter lab --Ai.allowed_providers=openai +jupyter lab --AiExtension.allowed_providers=openai ``` To allow more than one provider in the allowlist, repeat the runtime configuration. ``` -jupyter lab --Ai.allowed_providers=openai --Ai.allowed_providers=ai21 +jupyter lab --AiExtension.allowed_providers=openai --AiExtension.allowed_providers=ai21 ```