Skip to content

Commit

Permalink
chore(py): avoid use of assert for business logic
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Mortari <[email protected]>
  • Loading branch information
tarilabs committed Jun 11, 2024
1 parent f38a86e commit 378ac27
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ select = [
ignore = [
"D105", # missing docstring in magic method
"E501", # line too long
"S101", # use of assert detected
]
mccabe.max-complexity = 8
per-file-ignores = { "tests/**/*.py" = [
"D", # missing docstring in public module
"S101", # use of assert detected
] }

[tool.ruff.lint.pydocstyle]
Expand Down
8 changes: 6 additions & 2 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def _register_model(self, name: str) -> RegisteredModel:
def _register_new_version(
self, rm: RegisteredModel, version: str, author: str, /, **kwargs
) -> ModelVersion:
assert rm.id is not None, "Registered model must have an ID"
if rm.id is None:
msg = "Registered model must have an ID"
raise ValueError(msg)
if self._api.get_model_version_by_params(rm.id, version):
msg = f"Version {version} already exists"
raise StoreException(msg)
Expand All @@ -92,7 +94,9 @@ def _register_new_version(
def _register_model_artifact(
self, mv: ModelVersion, uri: str, /, **kwargs
) -> ModelArtifact:
assert mv.id is not None, "Model version must have an ID"
if mv.id is None:
msg = "Model version must have an ID"
raise ValueError(msg)
ma = ModelArtifact(mv.model_name, uri, **kwargs)
self._api.upsert_model_artifact(ma, mv.id)
return ma
Expand Down
12 changes: 6 additions & 6 deletions clients/python/src/model_registry/types/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def map(self, type_id: int) -> Artifact:
@override
def unmap(cls, mlmd_obj: Artifact) -> BaseArtifact:
py_obj = super().unmap(mlmd_obj)
assert isinstance(
py_obj, BaseArtifact
), f"Expected BaseArtifact, got {type(py_obj)}"
if not isinstance(py_obj, BaseArtifact):
msg = f"Expected BaseArtifact, got {type(py_obj)}"
raise TypeError(msg)
py_obj.uri = mlmd_obj.uri
py_obj.state = ArtifactState(mlmd_obj.state)
return py_obj
Expand Down Expand Up @@ -120,9 +120,9 @@ def map(self, type_id: int) -> Artifact:
@classmethod
def unmap(cls, mlmd_obj: Artifact) -> ModelArtifact:
py_obj = super().unmap(mlmd_obj)
assert isinstance(
py_obj, ModelArtifact
), f"Expected ModelArtifact, got {type(py_obj)}"
if not isinstance(py_obj, ModelArtifact):
msg = f"Expected ModelArtifact, got {type(py_obj)}"
raise TypeError(msg)
py_obj.model_format_name = mlmd_obj.properties["model_format_name"].string_value
py_obj.model_format_version = mlmd_obj.properties[
"model_format_version"
Expand Down
4 changes: 3 additions & 1 deletion clients/python/src/model_registry/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def unmap(cls: type[T], mlmd_obj: ProtoType) -> T:
py_obj.id = str(mlmd_obj.id)
if isinstance(py_obj, Prefixable):
name: str = mlmd_obj.name
assert ":" in name, f"Expected {name} to be prefixed"
if ":" not in name:
msg = f"Expected {name} to be prefixed"
raise ValueError(msg)
py_obj.name = name.split(":", 1)[1]
else:
py_obj.name = mlmd_obj.name
Expand Down
24 changes: 12 additions & 12 deletions clients/python/src/model_registry/types/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def map(self, type_id: int) -> Context:
@override
def unmap(cls, mlmd_obj: Context) -> BaseContext:
py_obj = super().unmap(mlmd_obj)
assert isinstance(
py_obj, BaseContext
), f"Expected BaseContext, got {type(py_obj)}"
if not isinstance(py_obj, BaseContext):
msg = f"Expected BaseContext, got {type(py_obj)}"
raise TypeError(msg)
py_obj.state = ContextState(mlmd_obj.properties["state"].string_value)
return py_obj

Expand Down Expand Up @@ -92,9 +92,9 @@ def __attrs_post_init__(self) -> None:
@property
@override
def mlmd_name_prefix(self) -> str:
assert (
self._registered_model_id is not None
), "There's no registered model associated with this version"
if self._registered_model_id is None:
msg = "There's no registered model associated with this version"
raise ValueError(msg)
return self._registered_model_id

@override
Expand All @@ -113,9 +113,9 @@ def map(self, type_id: int) -> Context:
@override
def unmap(cls, mlmd_obj: Context) -> ModelVersion:
py_obj = super().unmap(mlmd_obj)
assert isinstance(
py_obj, ModelVersion
), f"Expected ModelVersion, got {type(py_obj)}"
if not isinstance(py_obj, ModelVersion):
msg = f"Expected ModelVersion, got {type(py_obj)}"
raise TypeError(msg)
py_obj.version = py_obj.name
py_obj.model_name = mlmd_obj.properties["model_name"].string_value
py_obj.author = mlmd_obj.properties["author"].string_value
Expand Down Expand Up @@ -150,8 +150,8 @@ def map(self, type_id: int) -> Context:
@override
def unmap(cls, mlmd_obj: Context) -> RegisteredModel:
py_obj = super().unmap(mlmd_obj)
assert isinstance(
py_obj, RegisteredModel
), f"Expected RegisteredModel, got {type(py_obj)}"
if not isinstance(py_obj, RegisteredModel):
msg = f"Expected RegisteredModel, got {type(py_obj)}"
raise TypeError(msg)
py_obj.owner = mlmd_obj.properties["owner"].string_value
return py_obj

0 comments on commit 378ac27

Please sign in to comment.