Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to define block and allow lists for providers #415

Merged
merged 7 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import pytest
from jupyter_ai_magics.utils import get_lm_providers

KNOWN_LM_A = "openai"
KNOWN_LM_B = "huggingface_hub"


@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": None, "blocked_providers": None},
{"allowed_providers": [], "blocked_providers": []},
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_B]},
{"allowed_providers": [KNOWN_LM_A], "blocked_providers": []},
],
)
def test_get_lm_providers_not_restricted(restrictions):
a_not_restricted = get_lm_providers(None, restrictions)
assert KNOWN_LM_A in a_not_restricted


@pytest.mark.parametrize(
"restrictions",
[
{"allowed_providers": [], "blocked_providers": [KNOWN_LM_A]},
{"allowed_providers": [KNOWN_LM_B], "blocked_providers": []},
],
)
def test_get_lm_providers_restricted(restrictions):
a_not_restricted = get_lm_providers(None, restrictions)
assert KNOWN_LM_A not in a_not_restricted
32 changes: 28 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
Expand All @@ -11,13 +11,19 @@
EmProvidersDict = Dict[str, BaseEmbeddingsProvider]
AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider]
ProviderDict = Dict[str, AnyProvider]
ProviderRestrictions = Dict[
Literal["allowed_providers", "blocked_providers"], Optional[List[str]]
]


def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
def get_lm_providers(
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> LmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())

if not restrictions:
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
Expand All @@ -29,18 +35,23 @@ def get_lm_providers(log: Optional[Logger] = None) -> LmProvidersDict:
f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
)
continue
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered model provider `{provider.id}`.")

return providers


def get_em_providers(
log: Optional[Logger] = None,
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> EmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
if not restrictions:
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers")
Expand All @@ -52,6 +63,9 @@ def get_em_providers(
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
)
continue
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered embeddings model provider `{provider.id}`.")

Expand Down Expand Up @@ -97,6 +111,16 @@ def get_em_provider(
return _get_provider(model_id, em_providers)


def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool:
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
if blocked and provider_id in blocked:
return False
if allowed and provider_id not in allowed:
return False
return True


def _get_provider(model_id: str, providers: ProviderDict):
provider_id, local_model_id = decompose_model_id(model_id, providers)
provider = providers.get(provider_id, None)
Expand Down
15 changes: 15 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
AnyProvider,
EmProvidersDict,
LmProvidersDict,
ProviderRestrictions,
get_em_provider,
get_lm_provider,
is_provider_allowed,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
Expand Down Expand Up @@ -97,6 +99,7 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
*args,
**kwargs,
):
Expand All @@ -106,6 +109,8 @@ def __init__(
self._lm_providers = lm_providers
"""List of EM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions

"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
Expand Down Expand Up @@ -176,6 +181,10 @@ def _validate_config(self, config: GlobalConfig):
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
# do not check config for blocked providers
if not is_provider_allowed(config.model_provider_id, self._restrictions):
assert not lm_provider
return
if not lm_provider:
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
Expand All @@ -187,6 +196,12 @@ def _validate_config(self, config: GlobalConfig):
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
# do not check config for blocked providers
if not is_provider_allowed(
config.embeddings_provider_id, self._restrictions
):
assert not em_provider
return
if not em_provider:
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
Expand Down
30 changes: 28 additions & 2 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jupyter_ai.chat_handlers.learn import Retriever
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
from jupyter_server.extension.application import ExtensionApp
from traitlets import List, Unicode

from .chat_handlers import (
AskChatHandler,
Expand Down Expand Up @@ -36,18 +37,43 @@ class AiExtension(ExtensionApp):
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
]

allowed_providers = List(
Unicode(),
default_value=None,
help="Identifiers of allow-listed providers. If `None`, all are allowed.",
allow_none=True,
config=True,
)

blocked_providers = List(
Unicode(),
default_value=None,
help="Identifiers of block-listed providers. If `None`, none are blocked.",
allow_none=True,
config=True,
)

def initialize_settings(self):
start = time.time()
restrictions = {
"allowed_providers": self.allowed_providers,
"blocked_providers": self.blocked_providers,
}

self.settings["lm_providers"] = get_lm_providers(log=self.log)
self.settings["em_providers"] = get_em_providers(log=self.log)
self.settings["lm_providers"] = get_lm_providers(
log=self.log, restrictions=restrictions
)
self.settings["em_providers"] = get_em_providers(
log=self.log, restrictions=restrictions
)

self.settings["jai_config_manager"] = ConfigManager(
# traitlets configuration, not JAI configuration.
config=self.config,
log=self.log,
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
)

self.log.info("Registered providers.")
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def common_cm_kwargs(config_path, schema_path):
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
}


Expand Down Expand Up @@ -112,6 +113,7 @@ def test_init_with_existing_config(
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
)


Expand Down
39 changes: 39 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import pytest
from jupyter_ai.extension import AiExtension

pytest_plugins = ["pytest_jupyter.jupyter_server"]

KNOWN_LM_A = "openai"
KNOWN_LM_B = "huggingface_hub"


@pytest.mark.parametrize(
"argv",
[
["--AiExtension.blocked_providers", KNOWN_LM_B],
["--AiExtension.allowed_providers", KNOWN_LM_A],
],
)
def test_allows_providers(argv, jp_configurable_serverapp):
server = jp_configurable_serverapp(argv=argv)
ai = AiExtension()
ai._link_jupyter_server_extension(server)
ai.initialize_settings()
assert KNOWN_LM_A in ai.settings["lm_providers"]


@pytest.mark.parametrize(
"argv",
[
["--AiExtension.blocked_providers", KNOWN_LM_A],
["--AiExtension.allowed_providers", KNOWN_LM_B],
],
)
def test_blocks_providers(argv, jp_configurable_serverapp):
server = jp_configurable_serverapp(argv=argv)
ai = AiExtension()
ai._link_jupyter_server_extension(server)
ai.initialize_settings()
assert KNOWN_LM_A not in ai.settings["lm_providers"]
1 change: 1 addition & 0 deletions packages/jupyter-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ test = [
"pytest-asyncio",
"pytest-cov",
"pytest-tornasync",
"pytest-jupyter",
"syrupy~=4.0.8"
]

Expand Down
Loading