From c97bcd205f2795161e5c5113ab19622cbdf8de9d Mon Sep 17 00:00:00 2001 From: Isabella Basso do Amaral Date: Wed, 13 Mar 2024 14:52:44 -0300 Subject: [PATCH] py: provide overloads for type-checking register_model Signed-off-by: Isabella Basso do Amaral --- clients/python/src/model_registry/_client.py | 81 ++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index ae522f39..535d57bb 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -7,6 +7,8 @@ from urllib import parse from warnings import warn +from typing_extensions import overload + from .core import ModelRegistryAPIClient from .exceptions import StoreException from .store import ScalarType @@ -69,6 +71,85 @@ def _register_model_artifact( self._api.upsert_model_artifact(ma, mv.id) return ma + @overload + def register_model( + self, + name: str, + uri: str, + *, + model_format_name: str, + model_format_version: str, + version: str, + author: str | None = None, + description: str | None = None, + metadata: dict[str, ScalarType] | None = None, + ) -> RegisteredModel: ... + + @overload + def register_model( + self, + name: str, + uri: str, + *, + model_format_name: str, + model_format_version: str, + version: str, + storage_key: str, + storage_path: str, + author: str | None = None, + description: str | None = None, + metadata: dict[str, ScalarType] | None = None, + ) -> RegisteredModel: ... + + @overload + def register_model( + self, + name: str, + uri: str, + *, + model_format_name: str, + model_format_version: str, + version: str, + service_account_name: str, + author: str | None = None, + description: str | None = None, + metadata: dict[str, ScalarType] | None = None, + ) -> RegisteredModel: ... + + @overload + def register_model( + self, + name: str, + *, + model_format_name: str, + model_format_version: str, + version: str, + storage_key: str, + storage_path: str, + bucket_name: str, + author: str | None = None, + description: str | None = None, + metadata: dict[str, ScalarType] | None = None, + ) -> RegisteredModel: ... + + @overload + def register_model( + self, + name: str, + *, + model_format_name: str, + model_format_version: str, + version: str, + storage_key: str, + storage_path: str, + bucket_name: str, + bucket_endpoint: str, + bucket_region: str, + author: str | None = None, + description: str | None = None, + metadata: dict[str, ScalarType] | None = None, + ) -> RegisteredModel: ... + def register_model( self, name: str,