diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index 535d57bb..6fe8a422 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -9,6 +9,7 @@ from typing_extensions import overload +from ._utils import required_args from .core import ModelRegistryAPIClient from .exceptions import StoreException from .store import ScalarType @@ -80,37 +81,12 @@ def register_model( model_format_name: str, model_format_version: str, version: str, - author: str | None = None, - description: str | None = None, - metadata: dict[str, ScalarType] | None = None, - ) -> RegisteredModel: ... - - @overload - def register_model( - self, - name: str, - uri: str, - *, - model_format_name: str, - model_format_version: str, - version: str, - storage_key: str, - storage_path: str, - author: str | None = None, - description: str | None = None, - metadata: dict[str, ScalarType] | None = None, - ) -> RegisteredModel: ... - - @overload - def register_model( - self, - name: str, - uri: str, - *, - model_format_name: str, - model_format_version: str, - version: str, - service_account_name: str, + storage_key: str | None = None, + storage_path: str | None = None, + bucket_name: str | None = None, + bucket_endpoint: str | None = None, + bucket_region: str | None = None, + service_account_name: str | None = None, author: str | None = None, description: str | None = None, metadata: dict[str, ScalarType] | None = None, @@ -127,6 +103,7 @@ def register_model( storage_key: str, storage_path: str, bucket_name: str, + service_account_name: str | None = None, author: str | None = None, description: str | None = None, metadata: dict[str, ScalarType] | None = None, @@ -145,11 +122,29 @@ def register_model( bucket_name: str, bucket_endpoint: str, bucket_region: str, + service_account_name: str | None = None, author: str | None = None, description: str | None = None, metadata: dict[str, ScalarType] | None = None, ) -> RegisteredModel: ... + @required_args( + # non S3 + ("uri",), + # S3 only + ( # pre-configured env + "storage_key", + "storage_path", + "bucket_name", + ), + ( # custom env or non-default bucket + "storage_key", + "storage_path", + "bucket_name", + "bucket_endpoint", + "bucket_region", + ), + ) def register_model( self, name: str, @@ -158,14 +153,14 @@ def register_model( model_format_name: str, model_format_version: str, version: str, - author: str | None = None, - description: str | None = None, storage_key: str | None = None, storage_path: str | None = None, bucket_name: str | None = None, bucket_endpoint: str | None = None, bucket_region: str | None = None, service_account_name: str | None = None, + author: str | None = None, + description: str | None = None, metadata: dict[str, ScalarType] | None = None, ) -> RegisteredModel: """Register a model. @@ -173,11 +168,10 @@ def register_model( This registers a model in the model registry. The model is not downloaded, and has to be stored prior to registration. - Most models can be registered using a URI, along with optional parameters `storage_key` and `storage_path`. - - For models in S3 compatible object-storage, you can simply provide a `service_account_name`. - In the absence of a `service_account_name`, you should provide a `storage_key` and `storage_path` - along with the `bucket_name`. + Most models can be registered using a URI, along with optional connection-specific parameters, `storage_key` + and `storage_path` or, simply a `service_account_name`. + However it's advised to omit the URI when using S3 object storage data connections, and to provide instead the + desired `bucket_name`. If your environment is not set to use this bucket by default, `bucket_endpoint` and `bucket_region` must be provided as well. @@ -204,7 +198,7 @@ def register_model( """ if not uri: # S3 only # if the bucket_name is not the default, the endpoint and region must be provided - # if the default bucket is not set, we actually want to error on the try-except block to provide a better error message + # if the default bucket is not set, we actually want to error on the next block to provide a better error message if os.environ.get("AWS_S3_BUCKET", bucket_name) != bucket_name and not ( bucket_endpoint and bucket_region ): @@ -222,9 +216,6 @@ def register_model( f"s3://{bucket_name}/", f"{storage_path}?endpoint={bucket_endpoint}&defaultRegion={bucket_region}", ) - elif parse.urlparse(uri).scheme != "s3" and service_account_name: - msg = "service_account_name can only be used with S3 URIs" - raise StoreException(msg) rm = self._register_model(name) mv = self._register_new_version( @@ -334,7 +325,6 @@ def register_hf_model( model_format_name=model_format_name, model_format_version=model_format_version, description=description, - storage_path=path, metadata=metadata, ) diff --git a/clients/python/src/model_registry/_utils.py b/clients/python/src/model_registry/_utils.py new file mode 100644 index 00000000..b2a32cb8 --- /dev/null +++ b/clients/python/src/model_registry/_utils.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import functools +import inspect +from collections.abc import Sequence +from typing import Any, Callable, TypeVar + +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +# copied from https://github.com/Rapptz/RoboDanny +def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: + size = len(seq) + if size == 0: + return "" + + if size == 1: + return seq[0] + + if size == 2: + return f"{seq[0]} {final} {seq[1]}" + + return delim.join(seq[:-1]) + f" {final} {seq[-1]}" + + +def quote(string: str) -> str: + """Add single quotation marks around the given string. Does *not* do any escaping.""" + return f"'{string}'" + + +# copied from https://github.com/openai/openai-python +def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: # noqa: C901 + """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. + + Useful for enforcing runtime validation of overloaded functions. + + Example usage: + ```py + @overload + def foo(*, a: str) -> str: + ... + + + @overload + def foo(*, b: bool) -> str: + ... + + + # This enforces the same constraints that a static type checker would + # i.e. that either a or b must be passed to the function + @required_args(["a"], ["b"]) + def foo(*, a: str | None = None, b: bool | None = None) -> str: + ... + ``` + """ + + def inner(func: CallableT) -> CallableT: # noqa: C901 + params = inspect.signature(func).parameters + positional = [ + name + for name, param in params.items() + if param.kind + in { + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + } + ] + + @functools.wraps(func) + def wrapper(*args: object, **kwargs: object) -> object: + given_params: set[str] = set() + for i, _ in enumerate(args): + try: + given_params.add(positional[i]) + except IndexError: + msg = f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" + raise TypeError(msg) from None + + for key in kwargs: + given_params.add(key) + + for variant in variants: + matches = all(param in given_params for param in variant) + if matches: + break + else: # no break + if len(variants) > 1: + variations = human_join( + [ + "(" + + human_join([quote(arg) for arg in variant], final="and") + + ")" + for variant in variants + ] + ) + msg = f"Missing required arguments; Expected either {variations} arguments to be given" + else: + # TODO: this error message is not deterministic + missing = list(set(variants[0]) - given_params) + if len(missing) > 1: + msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" + else: + msg = f"Missing required argument: {quote(missing[0])}" + raise TypeError(msg) + return func(*args, **kwargs) + + return wrapper # type: ignore + + return inner