Skip to content

Commit

Permalink
py: provide default URI builder for S3 in register_model
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Mar 14, 2024
1 parent 962c78d commit a42eebc
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 48 deletions.
61 changes: 44 additions & 17 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Standard client for the model registry."""

from __future__ import annotations

import os
from typing import get_args
from urllib import parse
from warnings import warn

from .core import ModelRegistryAPIClient
Expand Down Expand Up @@ -70,7 +72,7 @@ def _register_model_artifact(
def register_model(
self,
name: str,
uri: str,
uri: str | None = None,
*,
model_format_name: str,
model_format_version: str,
Expand All @@ -79,12 +81,24 @@ def register_model(
description: str | None = None,
storage_key: str | None = None,
storage_path: str | None = None,
bucket_name: str | None = None,
bucket_endpoint: str | None = None,
bucket_region: str | None = None,
service_account_name: str | None = None,
metadata: dict[str, ScalarType] | None = None,
) -> RegisteredModel:
"""Register a model.
Either `storage_key` and `storage_path`, or `service_account_name` must be provided.
This registers a model in the model registry. The model is not downloaded, and has to be stored prior to
registration.
Most models can be registered using a URI, along with optional parameters `storage_key` and `storage_path`.
For models in S3 compatible object-storage, you can simply provide a `service_account_name`.
In the absence of a `service_account_name`, you should provide a `storage_key` and `storage_path`
along with the `bucket_name`.
If your environment is not set to use this bucket by default, `bucket_endpoint` and `bucket_region` must be
provided as well.
Args:
name: Name of the model.
Expand All @@ -98,19 +112,46 @@ def register_model(
author: Author of the model. Defaults to the client author.
storage_key: Storage key.
storage_path: Storage path.
bucket_name: Name of the S3 bucket.
bucket_endpoint: Endpoint of the S3 bucket. Must be provided if `bucket_name` doesn't match the default.
bucket_region: Region of the S3 bucket. Must be provided if `bucket_name` doesn't match the default.
service_account_name: Service account name.
metadata: Additional version metadata. Defaults to values returned by `default_metadata()`.
Returns:
Registered model.
"""
if not uri:  # S3 only: derive the URI from the bucket parameters
    # Without a bucket and an object path there is nothing for the URI to point at;
    # fail early instead of formatting a literal "None" into the URI.
    if not bucket_name or not storage_path:
        msg = "bucket_name and storage_path must be provided when uri is omitted"
        raise StoreException(msg)

    # A bucket that differs from the environment's default (AWS_S3_BUCKET) cannot
    # rely on the environment's endpoint/region, so both must be passed explicitly.
    # When AWS_S3_BUCKET is unset, the lookup falls back to bucket_name itself and
    # this check never fires; the missing-value check further down reports the problem.
    if os.environ.get("AWS_S3_BUCKET", bucket_name) != bucket_name and not (
        bucket_endpoint and bucket_region
    ):
        msg = "bucket_endpoint and bucket_region must be provided for non-default bucket"
        raise StoreException(msg)

    bucket_endpoint = bucket_endpoint or os.getenv("AWS_S3_ENDPOINT")
    bucket_region = bucket_region or os.getenv("AWS_DEFAULT_REGION")

    if not bucket_endpoint or not bucket_region:
        msg = "Missing environment variables: bucket_endpoint and bucket_region are required"
        raise StoreException(msg)

    # urllib.parse.urljoin only resolves relative references for schemes listed in
    # urllib.parse.uses_relative. "s3" is not among them, so urljoin would return the
    # relative part alone and drop "s3://<bucket>/". Assemble the URI directly instead.
    # A leading "/" on the path is stripped so the bucket and key are separated by
    # exactly one slash.
    uri = (
        f"s3://{bucket_name}/{storage_path.lstrip('/')}"
        f"?endpoint={bucket_endpoint}&defaultRegion={bucket_region}"
    )
elif parse.urlparse(uri).scheme != "s3" and service_account_name:
    # Service accounts are only meaningful for S3-backed artifacts.
    msg = "service_account_name can only be used with S3 URIs"
    raise StoreException(msg)

rm = self._register_model(name)
mv = self._register_new_version(
rm,
version,
author or self._author,
description=description,
metadata=metadata or self.default_metadata(),
metadata=metadata or {},
)
self._register_model_artifact(
mv,
Expand All @@ -124,19 +165,6 @@ def register_model(

return rm

def default_metadata(self) -> dict[str, ScalarType]:
    """Collect the default metadata values from the process environment.

    These values are used when the caller does not supply metadata explicitly.

    Returns:
        A mapping holding each of `AWS_S3_ENDPOINT`, `AWS_S3_BUCKET` and
        `AWS_DEFAULT_REGION` that is currently set in the environment; variables
        that are not set are left out.
    """
    env_keys = ("AWS_S3_ENDPOINT", "AWS_S3_BUCKET", "AWS_DEFAULT_REGION")
    found: dict[str, ScalarType] = {}
    for env_key in env_keys:
        env_value = os.environ.get(env_key)
        # os.environ only stores strings, so None means "not set".
        if env_value is not None:
            found[env_key] = env_value
    return found

def register_hf_model(
self,
repo: str,
Expand Down Expand Up @@ -202,7 +230,6 @@ def register_hf_model(
model_author = author
source_uri = hf_hub_url(repo, path, revision=git_ref)
metadata = {
**self.default_metadata(),
"repo": repo,
"source_uri": source_uri,
"model_origin": "huggingface_hub",
Expand Down
88 changes: 57 additions & 31 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_register_new(mr_client: ModelRegistry):
version = "1.0.0"
rm = mr_client.register_model(
name,
"s3",
"not-s3://test-model",
model_format_name="test_format",
model_format_version="test_version",
version=version,
Expand All @@ -31,10 +31,49 @@ def test_register_new(mr_client: ModelRegistry):
assert mr_api.get_model_artifact_by_params(mv.id) is not None


def test_register_many_uri_formats(mr_client: ModelRegistry):
    """Register one model per combination of URI and storage arguments.

    The test makes no assertions on the results; it passes when none of the
    registrations raises.
    """
    # Arguments shared by every registration.
    common = {
        "model_format_name": "test_format",
        "model_format_version": "test_version",
    }
    key_and_path = {"storage_key": "test-key", "storage_path": "test-path"}

    # Order matters only in that it mirrors the version numbers 1..4.
    scenarios = [
        {"name": "plain", "uri": "not-s3://no-model", "version": "1"},
        {
            "name": "has key and path",
            "uri": "not-s3://no-model",
            "version": "2",
            **key_and_path,
        },
        {
            "name": "s3 with bucket info",
            "uri": "s3://no-model",
            "version": "3",
            **key_and_path,
            "bucket_name": "test-bucket",
            "bucket_endpoint": "test-endpoint",
            "bucket_region": "test-region",
        },
        {
            "name": "s3 with service account",
            "uri": "s3://no-model",
            "version": "4",
            "service_account_name": "test account",
        },
    ]

    for scenario in scenarios:
        mr_client.register_model(**scenario, **common)


def test_register_existing_version(mr_client: ModelRegistry):
params = {
"name": "test_model",
"uri": "s3",
"uri": "not-s3://test-model",
"model_format_name": "test_format",
"model_format_version": "test_version",
"version": "1.0.0",
Expand All @@ -52,11 +91,11 @@ def test_get(mr_client: ModelRegistry):

rm = mr_client.register_model(
name,
"s3",
"not-s3://test-model",
model_format_name="test_format",
model_format_version="test_version",
version=version,
metadata=metadata
metadata=metadata,
)

assert (_rm := mr_client.get_registered_model(name))
Expand All @@ -73,28 +112,6 @@ def test_get(mr_client: ModelRegistry):
assert ma.id == _ma.id


def test_default_md(mr_client: ModelRegistry):
    """Check that the S3 environment variables become the version metadata by default.

    The model is registered without a `metadata` argument while `AWS_S3_ENDPOINT`,
    `AWS_S3_BUCKET` and `AWS_DEFAULT_REGION` are set, and the stored version
    metadata is expected to equal those environment values.
    """
    name = "test_model"
    version = "1.0.0"
    env_values = {
        "AWS_S3_ENDPOINT": "value1",
        "AWS_S3_BUCKET": "value2",
        "AWS_DEFAULT_REGION": "value3",
    }

    # Remember what was there before so the environment can be put back exactly,
    # rather than deleting variables the surrounding process may have set.
    previous = {k: os.environ.get(k) for k in env_values}
    os.environ.update(env_values)
    try:
        assert mr_client.register_model(
            name,
            "s3",
            model_format_name="test_format",
            model_format_version="test_version",
            version=version,
            # metadata is deliberately left out so the defaults are used
        )
        assert (mv := mr_client.get_model_version(name, version))
        assert mv.metadata == env_values
    finally:
        # Restore even when an assertion fails, so the variables do not leak
        # into tests that run afterwards.
        for k, old in previous.items():
            if old is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = old


def test_hf_import(mr_client: ModelRegistry):
pytest.importorskip("huggingface_hub")
name = "openai-community/gpt2"
Expand All @@ -113,19 +130,25 @@ def test_hf_import(mr_client: ModelRegistry):
assert mv.author == author
assert mv.metadata["model_author"] == author
assert mv.metadata["model_origin"] == "huggingface_hub"
assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx"
assert (
mv.metadata["source_uri"]
== "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx"
)
assert mv.metadata["repo"] == name
assert mr_client.get_model_artifact(name, version)


def test_hf_import_default_env(mr_client: ModelRegistry):
"""Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata
"""
"""Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata"""
pytest.importorskip("huggingface_hub")
name = "openai-community/gpt2"
version = "1.2.3"
author = "test author"
env_values = {"AWS_S3_ENDPOINT": "value1", "AWS_S3_BUCKET": "value2", "AWS_DEFAULT_REGION": "value3"}
env_values = {
"AWS_S3_ENDPOINT": "value1",
"AWS_S3_BUCKET": "value2",
"AWS_DEFAULT_REGION": "value3",
}
for k, v in env_values.items():
os.environ[k] = v

Expand All @@ -140,7 +163,10 @@ def test_hf_import_default_env(mr_client: ModelRegistry):
assert (mv := mr_client.get_model_version(name, version))
assert mv.metadata["model_author"] == author
assert mv.metadata["model_origin"] == "huggingface_hub"
assert mv.metadata["source_uri"] == "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx"
assert (
mv.metadata["source_uri"]
== "https://huggingface.co/openai-community/gpt2/resolve/main/onnx/decoder_model.onnx"
)
assert mv.metadata["repo"] == name
assert mr_client.get_model_artifact(name, version)

Expand Down

0 comments on commit a42eebc

Please sign in to comment.