From 148e1c49d628b6388afc134d441982eae5e9608a Mon Sep 17 00:00:00 2001 From: Isabella Basso Date: Fri, 21 Jun 2024 08:41:06 -0300 Subject: [PATCH] py: set author as owner when creating RegisteredModel (#147) Signed-off-by: Isabella do Amaral --- clients/python/src/model_registry/_client.py | 11 ++++++++--- clients/python/src/model_registry/types/contexts.py | 6 ++---- clients/python/tests/conftest.py | 1 + 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index f8095684..6c70726a 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -69,11 +69,11 @@ def __init__( server_address, port, user_token ) - def _register_model(self, name: str) -> RegisteredModel: + def _register_model(self, name: str, **kwargs) -> RegisteredModel: if rm := self._api.get_registered_model_by_params(name): return rm - rm = RegisteredModel(name) + rm = RegisteredModel(name, **kwargs) self._api.upsert_registered_model(rm) return rm @@ -109,6 +109,7 @@ def register_model( storage_path: str | None = None, service_account_name: str | None = None, author: str | None = None, + owner: str | None = None, description: str | None = None, metadata: dict[str, ScalarType] | None = None, ) -> RegisteredModel: @@ -132,6 +133,7 @@ def register_model( model_format_version: Version of the model format. description: Description of the model. author: Author of the model. Defaults to the client author. + owner: Owner of the model. Defaults to the client author. storage_key: Storage key. storage_path: Storage path. service_account_name: Service account name. @@ -140,7 +142,7 @@ def register_model( Returns: Registered model. """ - rm = self._register_model(name) + rm = self._register_model(name, owner=owner or self._author) mv = self._register_new_version( rm, version, @@ -169,6 +171,7 @@ def register_hf_model( model_format_name: str, model_format_version: str, author: str | None = None, + owner: str | None = None, model_name: str | None = None, description: str | None = None, git_ref: str = "main", @@ -187,6 +190,7 @@ def register_hf_model( model_format_name: Name of the model format. model_format_version: Version of the model format. author: Author of the model. Defaults to repo owner. + owner: Owner of the model. Defaults to the client author. model_name: Name of the model. Defaults to the repo name. description: Description of the model. git_ref: Git reference to use. Defaults to `main`. @@ -244,6 +248,7 @@ def register_hf_model( model_name or model_info.id, source_uri, author=author or model_author, + owner=owner or self._author, version=version, model_format_name=model_format_name, model_format_version=model_format_version, diff --git a/clients/python/src/model_registry/types/contexts.py b/clients/python/src/model_registry/types/contexts.py index 04fbd061..ecd083c7 100644 --- a/clients/python/src/model_registry/types/contexts.py +++ b/clients/python/src/model_registry/types/contexts.py @@ -135,14 +135,12 @@ class RegisteredModel(BaseContext): """ name: str - owner: str = None + owner: str | None = None @override def map(self, type_id: int) -> Context: mlmd_obj = super().map(type_id) - props = { - "owner": self.owner - } + props = {"owner": self.owner} self._map_props(props, mlmd_obj.properties) return mlmd_obj diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 9cb4f397..b5ceb1e2 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -160,6 +160,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore: [ "description", "state", + "owner", ], )