Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Py: Fix misleading errors on missing list results #65

Merged
merged 2 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Client for the model registry."""
from __future__ import annotations

from collections.abc import Sequence
from __future__ import annotations

from ml_metadata.proto import MetadataStoreClientConfig

Expand Down Expand Up @@ -130,7 +129,7 @@ def get_registered_model_by_params(

def get_registered_models(
self, options: ListOptions | None = None
) -> Sequence[RegisteredModel]:
) -> list[RegisteredModel]:
"""Fetch registered models.

Args:
Expand Down Expand Up @@ -194,7 +193,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:

def get_model_versions(
self, registered_model_id: str, options: ListOptions | None = None
) -> Sequence[ModelVersion]:
) -> list[ModelVersion]:
"""Fetch model versions by registered model ID.

Args:
Expand Down Expand Up @@ -344,7 +343,7 @@ def get_model_artifacts(
self,
model_version_id: str | None = None,
options: ListOptions | None = None,
) -> Sequence[ModelArtifact]:
) -> list[ModelArtifact]:
"""Fetches model artifacts.

Args:
Expand Down
32 changes: 21 additions & 11 deletions clients/python/src/model_registry/store/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def put_context(self, context: Context) -> int:

def _filter_type(
self, type_name: str, protos: Sequence[ProtoType]
) -> Sequence[ProtoType]:
) -> list[ProtoType]:
return [proto for proto in protos if proto.type == type_name]

def get_context(
Expand Down Expand Up @@ -168,9 +168,7 @@ def get_context(

return None

def get_contexts(
self, ctx_type_name: str, options: ListOptions
) -> Sequence[Context]:
def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context]:
"""Get contexts from the store.

Args:
Expand All @@ -179,6 +177,11 @@ def get_contexts(

Returns:
Contexts.

Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
# TODO: should we make options optional?
# if options is not None:
Expand All @@ -195,9 +198,11 @@ def get_contexts(
# else:
# contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name)

if not contexts:
if not contexts and ctx_type_name not in [
t.name for t in self._mlmd_store.get_context_types()
]:
msg = f"Context type {ctx_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return contexts

Expand Down Expand Up @@ -309,9 +314,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact:

return None

def get_artifacts(
self, art_type_name: str, options: ListOptions
) -> Sequence[Artifact]:
def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifact]:
"""Get artifacts from the store.

Args:
Expand All @@ -320,6 +323,11 @@ def get_artifacts(

Returns:
Artifacts.

Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
try:
artifacts = self._mlmd_store.get_artifacts(options)
Expand All @@ -331,8 +339,10 @@ def get_artifacts(
raise ServerException(msg) from e

artifacts = self._filter_type(art_type_name, artifacts)
if not artifacts:
if not artifacts and art_type_name not in [
t.name for t in self._mlmd_store.get_artifact_types()
]:
msg = f"Artifact type {art_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return artifacts
23 changes: 23 additions & 0 deletions clients/python/tests/store/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TypeNotFoundException,
)
from model_registry.store import MLMDStore
from model_registry.types.options import MLMDListOptions


@pytest.fixture()
Expand Down Expand Up @@ -53,6 +54,28 @@ def test_get_undefined_context_type_id(plain_wrapper: MLMDStore):
plain_wrapper.get_type_id(Context, "undefined")


@pytest.mark.usefixtures("artifact")
def test_get_no_artifacts(plain_wrapper: MLMDStore):
arts = plain_wrapper.get_artifacts("test_artifact", MLMDListOptions())
assert arts == []


def test_get_undefined_artifacts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_artifacts("undefined", MLMDListOptions())


@pytest.mark.usefixtures("context")
def test_get_no_contexts(plain_wrapper: MLMDStore):
ctxs = plain_wrapper.get_contexts("test_context", MLMDListOptions())
assert ctxs == []


def test_get_undefined_contexts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_contexts("undefined", MLMDListOptions())


def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact):
artifact.properties["null"].int_value = 0

Expand Down
Loading