Skip to content

Commit

Permalink
py: provide overloads for type-checking register_model
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Mar 13, 2024
1 parent 88a6b82 commit c97bcd2
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from urllib import parse
from warnings import warn

from typing_extensions import overload

from .core import ModelRegistryAPIClient
from .exceptions import StoreException
from .store import ScalarType
Expand Down Expand Up @@ -69,6 +71,85 @@ def _register_model_artifact(
self._api.upsert_model_artifact(ma, mv.id)
return ma

@overload
def register_model(
self,
name: str,
uri: str,
*,
model_format_name: str,
model_format_version: str,
version: str,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

@overload
def register_model(
self,
name: str,
uri: str,
*,
model_format_name: str,
model_format_version: str,
version: str,
storage_key: str,
storage_path: str,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

@overload
def register_model(
self,
name: str,
uri: str,
*,
model_format_name: str,
model_format_version: str,
version: str,
service_account_name: str,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

@overload
def register_model(
self,
name: str,
*,
model_format_name: str,
model_format_version: str,
version: str,
storage_key: str,
storage_path: str,
bucket_name: str,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

@overload
def register_model(
self,
name: str,
*,
model_format_name: str,
model_format_version: str,
version: str,
storage_key: str,
storage_path: str,
bucket_name: str,
bucket_endpoint: str,
bucket_region: str,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

def register_model(
self,
name: str,
Expand Down

0 comments on commit c97bcd2

Please sign in to comment.