Skip to content

Commit

Permalink
Py: Fix misleading errors on missing list results (#65)
Browse files Browse the repository at this point in the history
* py: fix type annotations to return concrete types

Signed-off-by: Isabella Basso do Amaral <[email protected]>

* py: fix missing type check on empty list results

Signed-off-by: Isabella Basso do Amaral <[email protected]>

---------

Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Apr 16, 2024
1 parent 5f7fac0 commit 65613f9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 16 deletions.
9 changes: 4 additions & 5 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Client for the model registry."""
from __future__ import annotations

from collections.abc import Sequence
from __future__ import annotations

from ml_metadata.proto import MetadataStoreClientConfig

Expand Down Expand Up @@ -130,7 +129,7 @@ def get_registered_model_by_params(

def get_registered_models(
self, options: ListOptions | None = None
) -> Sequence[RegisteredModel]:
) -> list[RegisteredModel]:
"""Fetch registered models.
Args:
Expand Down Expand Up @@ -194,7 +193,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:

def get_model_versions(
self, registered_model_id: str, options: ListOptions | None = None
) -> Sequence[ModelVersion]:
) -> list[ModelVersion]:
"""Fetch model versions by registered model ID.
Args:
Expand Down Expand Up @@ -344,7 +343,7 @@ def get_model_artifacts(
self,
model_version_id: str | None = None,
options: ListOptions | None = None,
) -> Sequence[ModelArtifact]:
) -> list[ModelArtifact]:
"""Fetches model artifacts.
Args:
Expand Down
32 changes: 21 additions & 11 deletions clients/python/src/model_registry/store/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def put_context(self, context: Context) -> int:

def _filter_type(
self, type_name: str, protos: Sequence[ProtoType]
) -> Sequence[ProtoType]:
) -> list[ProtoType]:
return [proto for proto in protos if proto.type == type_name]

def get_context(
Expand Down Expand Up @@ -168,9 +168,7 @@ def get_context(

return None

def get_contexts(
self, ctx_type_name: str, options: ListOptions
) -> Sequence[Context]:
def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context]:
"""Get contexts from the store.
Args:
Expand All @@ -179,6 +177,11 @@ def get_contexts(
Returns:
Contexts.
Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
# TODO: should we make options optional?
# if options is not None:
Expand All @@ -195,9 +198,11 @@ def get_contexts(
# else:
# contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name)

if not contexts:
if not contexts and ctx_type_name not in [
t.name for t in self._mlmd_store.get_context_types()
]:
msg = f"Context type {ctx_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return contexts

Expand Down Expand Up @@ -309,9 +314,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact:

return None

def get_artifacts(
self, art_type_name: str, options: ListOptions
) -> Sequence[Artifact]:
def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifact]:
"""Get artifacts from the store.
Args:
Expand All @@ -320,6 +323,11 @@ def get_artifacts(
Returns:
Artifacts.
Raises:
TypeNotFoundException: If the type doesn't exist.
ServerException: If there was an error getting the type.
StoreException: Invalid arguments.
"""
try:
artifacts = self._mlmd_store.get_artifacts(options)
Expand All @@ -331,8 +339,10 @@ def get_artifacts(
raise ServerException(msg) from e

artifacts = self._filter_type(art_type_name, artifacts)
if not artifacts:
if not artifacts and art_type_name not in [
t.name for t in self._mlmd_store.get_artifact_types()
]:
msg = f"Artifact type {art_type_name} does not exist"
raise StoreException(msg)
raise TypeNotFoundException(msg)

return artifacts
23 changes: 23 additions & 0 deletions clients/python/tests/store/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TypeNotFoundException,
)
from model_registry.store import MLMDStore
from model_registry.types.options import MLMDListOptions


@pytest.fixture()
Expand Down Expand Up @@ -53,6 +54,28 @@ def test_get_undefined_context_type_id(plain_wrapper: MLMDStore):
plain_wrapper.get_type_id(Context, "undefined")


@pytest.mark.usefixtures("artifact")
def test_get_no_artifacts(plain_wrapper: MLMDStore):
arts = plain_wrapper.get_artifacts("test_artifact", MLMDListOptions())
assert arts == []


def test_get_undefined_artifacts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_artifacts("undefined", MLMDListOptions())


@pytest.mark.usefixtures("context")
def test_get_no_contexts(plain_wrapper: MLMDStore):
ctxs = plain_wrapper.get_contexts("test_context", MLMDListOptions())
assert ctxs == []


def test_get_undefined_contexts(plain_wrapper: MLMDStore):
with pytest.raises(TypeNotFoundException):
plain_wrapper.get_contexts("undefined", MLMDListOptions())


def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact):
artifact.properties["null"].int_value = 0

Expand Down

0 comments on commit 65613f9

Please sign in to comment.