Skip to content

Commit

Permalink
Fix "No language model is associated with" issue
Browse files Browse the repository at this point in the history
This was appearing because the models which are blocked were not
returned (correctly!) but the previous validation logic did not
know that sometimes models may be missing for a valid reason even
if there are existing settings for these.
  • Loading branch information
krassowski committed Oct 21, 2023
1 parent ac122de commit 874f477
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 31 deletions.
51 changes: 20 additions & 31 deletions packages/jupyter-ai-magics/jupyter_ai_magics/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional, Tuple, Type, Union, Literal, List
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

from importlib_metadata import entry_points
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
Expand All @@ -11,28 +11,19 @@
EmProvidersDict = Dict[str, BaseEmbeddingsProvider]
AnyProvider = Union[BaseProvider, BaseEmbeddingsProvider]
ProviderDict = Dict[str, AnyProvider]


ProviderRestrictions = Dict[
Literal["allowed_providers", "blocked_providers"],
Optional[List[str]]
Literal["allowed_providers", "blocked_providers"], Optional[List[str]]
]


def get_lm_providers(
log: Optional[Logger] = None,
restrictions: Optional[ProviderRestrictions] = None
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> LmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
if not restrictions:
restrictions = {
"allowed_providers": None,
"blocked_providers": None
}
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.model_providers")
Expand All @@ -44,11 +35,8 @@ def get_lm_providers(
f"Unable to load model provider class from entry point `{model_provider_ep.name}`."
)
continue
if blocked and provider.id in blocked:
log.info(f"Skipping provider not on block-list `{provider.id}`.")
continue
if allowed and provider.id not in allowed:
log.info(f"Skipping provider not on allow-list `{provider.id}`.")
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered model provider `{provider.id}`.")
Expand All @@ -57,19 +45,13 @@ def get_lm_providers(


def get_em_providers(
log: Optional[Logger] = None,
restrictions: Optional[ProviderRestrictions] = None
log: Optional[Logger] = None, restrictions: Optional[ProviderRestrictions] = None
) -> EmProvidersDict:
if not log:
log = logging.getLogger()
log.addHandler(logging.NullHandler())
if not restrictions:
restrictions = {
"allowed_providers": None,
"blocked_providers": None
}
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
restrictions = {"allowed_providers": None, "blocked_providers": None}
providers = {}
eps = entry_points()
model_provider_eps = eps.select(group="jupyter_ai.embeddings_model_providers")
Expand All @@ -81,11 +63,8 @@ def get_em_providers(
f"Unable to load embeddings model provider class from entry point `{model_provider_ep.name}`."
)
continue
if blocked and provider.id in blocked:
log.info(f"Skipping provider not on block-list `{provider.id}`.")
continue
if allowed and provider.id not in allowed:
log.info(f"Skipping provider not on allow-list `{provider.id}`.")
if not is_provider_allowed(provider.id, restrictions):
log.info(f"Skipping blocked provider `{provider.id}`.")
continue
providers[provider.id] = provider
log.info(f"Registered embeddings model provider `{provider.id}`.")
Expand Down Expand Up @@ -132,6 +111,16 @@ def get_em_provider(
return _get_provider(model_id, em_providers)


def is_provider_allowed(provider_id: str, restrictions: ProviderRestrictions) -> bool:
allowed = restrictions["allowed_providers"]
blocked = restrictions["blocked_providers"]
if blocked and provider_id in blocked:
return False
if allowed and provider_id not in allowed:
return False
return True


def _get_provider(model_id: str, providers: ProviderDict):
provider_id, local_model_id = decompose_model_id(model_id, providers)
provider = providers.get(provider_id, None)
Expand Down
15 changes: 15 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
AnyProvider,
EmProvidersDict,
LmProvidersDict,
ProviderRestrictions,
get_em_provider,
get_lm_provider,
is_provider_allowed,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
Expand Down Expand Up @@ -97,6 +99,7 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
*args,
**kwargs,
):
Expand All @@ -106,6 +109,8 @@ def __init__(
self._lm_providers = lm_providers
"""List of EM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions

"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
Expand Down Expand Up @@ -176,6 +181,10 @@ def _validate_config(self, config: GlobalConfig):
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
# do not check config for blocked providers
if not is_provider_allowed(config.model_provider_id, self._restrictions):
assert not lm_provider
return
if not lm_provider:
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
Expand All @@ -187,6 +196,12 @@ def _validate_config(self, config: GlobalConfig):
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
# do not check config for blocked providers
if not is_provider_allowed(
config.embeddings_provider_id, self._restrictions
):
assert not em_provider
return
if not em_provider:
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def initialize_settings(self):
log=self.log,
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
)

self.log.info("Registered providers.")
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def common_cm_kwargs(config_path, schema_path):
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
}


Expand Down Expand Up @@ -112,6 +113,7 @@ def test_init_with_existing_config(
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
)


Expand Down

0 comments on commit 874f477

Please sign in to comment.