Skip to content

Commit

Permalink
py: set author as owner when creating RegisteredModel (#147)
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Jun 21, 2024
1 parent 7643a53 commit 148e1c4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
11 changes: 8 additions & 3 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def __init__(
server_address, port, user_token
)

def _register_model(self, name: str) -> RegisteredModel:
def _register_model(self, name: str, **kwargs) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
return rm

rm = RegisteredModel(name)
rm = RegisteredModel(name, **kwargs)
self._api.upsert_registered_model(rm)
return rm

Expand Down Expand Up @@ -109,6 +109,7 @@ def register_model(
storage_path: str | None = None,
service_account_name: str | None = None,
author: str | None = None,
owner: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel:
Expand All @@ -132,6 +133,7 @@ def register_model(
model_format_version: Version of the model format.
description: Description of the model.
author: Author of the model. Defaults to the client author.
owner: Owner of the model. Defaults to the client author.
storage_key: Storage key.
storage_path: Storage path.
service_account_name: Service account name.
Expand All @@ -140,7 +142,7 @@ def register_model(
Returns:
Registered model.
"""
rm = self._register_model(name)
rm = self._register_model(name, owner=owner or self._author)
mv = self._register_new_version(
rm,
version,
Expand Down Expand Up @@ -169,6 +171,7 @@ def register_hf_model(
model_format_name: str,
model_format_version: str,
author: str | None = None,
owner: str | None = None,
model_name: str | None = None,
description: str | None = None,
git_ref: str = "main",
Expand All @@ -187,6 +190,7 @@ def register_hf_model(
model_format_name: Name of the model format.
model_format_version: Version of the model format.
author: Author of the model. Defaults to repo owner.
owner: Owner of the model. Defaults to the client author.
model_name: Name of the model. Defaults to the repo name.
description: Description of the model.
git_ref: Git reference to use. Defaults to `main`.
Expand Down Expand Up @@ -244,6 +248,7 @@ def register_hf_model(
model_name or model_info.id,
source_uri,
author=author or model_author,
owner=owner or self._author,
version=version,
model_format_name=model_format_name,
model_format_version=model_format_version,
Expand Down
6 changes: 2 additions & 4 deletions clients/python/src/model_registry/types/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,12 @@ class RegisteredModel(BaseContext):
"""

name: str
owner: str = None
owner: str | None = None

@override
def map(self, type_id: int) -> Context:
mlmd_obj = super().map(type_id)
props = {
"owner": self.owner
}
props = {"owner": self.owner}
self._map_props(props, mlmd_obj.properties)
return mlmd_obj

Expand Down
1 change: 1 addition & 0 deletions clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore:
[
"description",
"state",
"owner",
],
)

Expand Down

0 comments on commit 148e1c4

Please sign in to comment.