Skip to content

Commit

Permalink
py: verify overload signatures at runtime
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Mar 14, 2024
1 parent 3e47c8c commit ae85d22
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 43 deletions.
76 changes: 33 additions & 43 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from typing_extensions import overload

from ._utils import required_args
from .core import ModelRegistryAPIClient
from .exceptions import StoreException
from .store import ScalarType
Expand Down Expand Up @@ -80,37 +81,12 @@ def register_model(
model_format_name: str,
model_format_version: str,
version: str,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

@overload
def register_model(
self,
name: str,
uri: str,
*,
model_format_name: str,
model_format_version: str,
version: str,
storage_key: str,
storage_path: str,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

@overload
def register_model(
self,
name: str,
uri: str,
*,
model_format_name: str,
model_format_version: str,
version: str,
service_account_name: str,
storage_key: str | None = None,
storage_path: str | None = None,
bucket_name: str | None = None,
bucket_endpoint: str | None = None,
bucket_region: str | None = None,
service_account_name: str | None = None,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
Expand All @@ -127,6 +103,7 @@ def register_model(
storage_key: str,
storage_path: str,
bucket_name: str,
service_account_name: str | None = None,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
Expand All @@ -145,11 +122,29 @@ def register_model(
bucket_name: str,
bucket_endpoint: str,
bucket_region: str,
service_account_name: str | None = None,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel: ...

@required_args(
# non S3
("uri",),
# S3 only
( # pre-configured env
"storage_key",
"storage_path",
"bucket_name",
),
( # custom env or non-default bucket
"storage_key",
"storage_path",
"bucket_name",
"bucket_endpoint",
"bucket_region",
),
)
def register_model(
self,
name: str,
Expand All @@ -158,26 +153,25 @@ def register_model(
model_format_name: str,
model_format_version: str,
version: str,
author: str | None = None,
description: str | None = None,
storage_key: str | None = None,
storage_path: str | None = None,
bucket_name: str | None = None,
bucket_endpoint: str | None = None,
bucket_region: str | None = None,
service_account_name: str | None = None,
author: str | None = None,
description: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel:
"""Register a model.
This registers a model in the model registry. The model is not downloaded, and has to be stored prior to
registration.
Most models can be registered using a URI, along with optional parameters `storage_key` and `storage_path`.
For models in S3 compatible object-storage, you can simply provide a `service_account_name`.
In the absence of a `service_account_name`, you should provide a `storage_key` and `storage_path`
along with the `bucket_name`.
Most models can be registered using a URI, along with optional connection-specific parameters, `storage_key`
and `storage_path` or, simply a `service_account_name`.
However it's advised to omit the URI when using S3 object storage data connections, and to provide instead the
desired `bucket_name`.
If your environment is not set to use this bucket by default, `bucket_endpoint` and `bucket_region` must be
provided as well.
Expand All @@ -204,7 +198,7 @@ def register_model(
"""
if not uri: # S3 only
# if the bucket_name is not the default, the endpoint and region must be provided
# if the default bucket is not set, we actually want to error on the try-except block to provide a better error message
# if the default bucket is not set, we actually want to error on the next block to provide a better error message
if os.environ.get("AWS_S3_BUCKET", bucket_name) != bucket_name and not (
bucket_endpoint and bucket_region
):
Expand All @@ -222,9 +216,6 @@ def register_model(
f"s3://{bucket_name}/",
f"{storage_path}?endpoint={bucket_endpoint}&defaultRegion={bucket_region}",
)
elif parse.urlparse(uri).scheme != "s3" and service_account_name:
msg = "service_account_name can only be used with S3 URIs"
raise StoreException(msg)

rm = self._register_model(name)
mv = self._register_new_version(
Expand Down Expand Up @@ -334,7 +325,6 @@ def register_hf_model(
model_format_name=model_format_name,
model_format_version=model_format_version,
description=description,
storage_path=path,
metadata=metadata,
)

Expand Down
109 changes: 109 additions & 0 deletions clients/python/src/model_registry/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

import functools
import inspect
from collections.abc import Sequence
from typing import Any, Callable, TypeVar

CallableT = TypeVar("CallableT", bound=Callable[..., Any])


# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
size = len(seq)
if size == 0:
return ""

if size == 1:
return seq[0]

if size == 2:
return f"{seq[0]} {final} {seq[1]}"

return delim.join(seq[:-1]) + f" {final} {seq[-1]}"


def quote(string: str) -> str:
"""Add single quotation marks around the given string. Does *not* do any escaping."""
return f"'{string}'"


# copied from https://github.com/openai/openai-python
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: # noqa: C901
"""Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
Useful for enforcing runtime validation of overloaded functions.
Example usage:
```py
@overload
def foo(*, a: str) -> str:
...
@overload
def foo(*, b: bool) -> str:
...
# This enforces the same constraints that a static type checker would
# i.e. that either a or b must be passed to the function
@required_args(["a"], ["b"])
def foo(*, a: str | None = None, b: bool | None = None) -> str:
...
```
"""

def inner(func: CallableT) -> CallableT: # noqa: C901
params = inspect.signature(func).parameters
positional = [
name
for name, param in params.items()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
}
]

@functools.wraps(func)
def wrapper(*args: object, **kwargs: object) -> object:
given_params: set[str] = set()
for i, _ in enumerate(args):
try:
given_params.add(positional[i])
except IndexError:
msg = f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
raise TypeError(msg) from None

for key in kwargs:
given_params.add(key)

for variant in variants:
matches = all(param in given_params for param in variant)
if matches:
break
else: # no break
if len(variants) > 1:
variations = human_join(
[
"("
+ human_join([quote(arg) for arg in variant], final="and")
+ ")"
for variant in variants
]
)
msg = f"Missing required arguments; Expected either {variations} arguments to be given"
else:
# TODO: this error message is not deterministic
missing = list(set(variants[0]) - given_params)
if len(missing) > 1:
msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
else:
msg = f"Missing required argument: {quote(missing[0])}"
raise TypeError(msg)
return func(*args, **kwargs)

return wrapper # type: ignore

return inner

0 comments on commit ae85d22

Please sign in to comment.