From 68f7d41f445ccf9234a1d6aab0b792a92c3a48e7 Mon Sep 17 00:00:00 2001 From: Isabella do Amaral Date: Thu, 31 Oct 2024 18:32:25 -0300 Subject: [PATCH] upload as image Signed-off-by: Isabella do Amaral --- Makefile | 2 +- omlmd/cli.py | 10 +++- omlmd/constants.py | 6 ++- omlmd/helpers.py | 78 ++++++++++++++++++++++++++++---- omlmd/listener.py | 4 +- omlmd/model_metadata.py | 13 ------ tests/test_e2e_model_registry.py | 4 +- tests/test_helpers.py | 50 ++++++++++++++++---- tests/test_omlmd.py | 6 +-- 9 files changed, 131 insertions(+), 42 deletions(-) diff --git a/Makefile b/Makefile index 5803fa4..3329828 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ test: .PHONY: test-e2e test-e2e: - poetry run pytest --e2e -s -x -rA + poetry run pytest --e2e -s -x -rA -v .PHONY: test-e2e-model-registry test-e2e-model-registry: diff --git a/omlmd/cli.py b/omlmd/cli.py index 8d66eca..20a7400 100644 --- a/omlmd/cli.py +++ b/omlmd/cli.py @@ -73,6 +73,11 @@ def crawl(plain_http: bool, targets: tuple[str]): required=True, type=click.Path(path_type=Path, exists=True, resolve_path=True), ) +@click.option( + "--as-artifact", + is_flag=True, + help="Push as an artifact (default is as a blob)", +) @cloup.option_group( "Metadata options", cloup.option( @@ -88,6 +93,7 @@ def push( plain_http: bool, target: str, path: Path, + as_artifact: bool, metadata: Path | None, empty_metadata: bool, ): @@ -96,4 +102,6 @@ def push( if empty_metadata: logger.warning(f"Pushing to {target} with empty metadata.") md = deserialize_mdfile(metadata) if metadata else {} - click.echo(Helper.from_default_registry(plain_http).push(target, path, **md)) + click.echo( + Helper.from_default_registry(plain_http).push(target, path, as_artifact, **md) + ) diff --git a/omlmd/constants.py b/omlmd/constants.py index 05db48e..d20dad0 100644 --- a/omlmd/constants.py +++ b/omlmd/constants.py @@ -1,3 +1,7 @@ +from oras.defaults import default_blob_media_type + FILENAME_METADATA_JSON = "model_metadata.omlmd.json" 
-MIME_APPLICATION_CONFIG = "application/x-config" MIME_APPLICATION_MLMODEL = "application/x-mlmodel" +MIME_APPLICATION_MLMETADATA = "application/x-mlmetadata+json" +MIME_BLOB = default_blob_media_type +MIME_MANIFEST_CONFIG = "application/vnd.oci.image.config.v1+json" diff --git a/omlmd/helpers.py b/omlmd/helpers.py index d3b2bd2..b365e9c 100644 --- a/omlmd/helpers.py +++ b/omlmd/helpers.py @@ -1,7 +1,10 @@ from __future__ import annotations +import json import logging import os +import platform +import tarfile import urllib.request from collections.abc import Sequence from dataclasses import dataclass, field @@ -10,8 +13,10 @@ from .constants import ( FILENAME_METADATA_JSON, - MIME_APPLICATION_CONFIG, + MIME_APPLICATION_MLMETADATA, MIME_APPLICATION_MLMODEL, + MIME_BLOB, + MIME_MANIFEST_CONFIG, ) from .listener import Event, Listener, PushEvent from .model_metadata import ModelMetadata @@ -20,6 +25,18 @@ logger = logging.getLogger(__name__) +def get_arch() -> str: + mac = platform.machine() + if mac == "x86_64": + return "amd64" + if mac == "arm64": + return "arm64" + if mac == "aarch64": + return "arm64" + msg = f"Unsupported architecture: {platform.machine()}" + raise NotImplementedError(msg) + + def download_file(uri: str): file_name = os.path.basename(uri) urllib.request.urlretrieve(uri, file_name) @@ -41,6 +58,7 @@ def push( self, target: str, path: Path | str, + as_artifact: bool = False, **kwargs, ): owns_meta = True @@ -52,8 +70,7 @@ def push( owns_meta = False logger.warning("Reusing intermediate metadata files.") logger.debug(f"{meta_path}") - with open(meta_path, "r") as f: - model_metadata = ModelMetadata.from_json(f.read()) + model_metadata = ModelMetadata.from_dict(json.loads(meta_path.read_bytes())) elif meta_path.exists(): err = dedent(f""" OMLMD intermediate metadata files found at '{meta_path}'. 
@@ -65,13 +82,51 @@ def push( raise RuntimeError(err) else: model_metadata = ModelMetadata.from_dict(kwargs) - meta_path.write_text(model_metadata.to_json()) + meta_path.write_text(json.dumps(model_metadata.to_dict())) + + owns_model_tar = False + owns_md_tar = False + manifest_path = path.parent / "manifest.json" + model_tar = None + meta_tar = None + if not as_artifact: + manifest_path.write_text( + json.dumps( + { + "architecture": get_arch(), + "os": "linux", + } + ) + ) + config = f"{manifest_path}:{MIME_MANIFEST_CONFIG}" + model_tar = path.parent / f"{path.stem}.tar" + meta_tar = path.parent / f"{meta_path.stem}.tar" + if not model_tar.exists(): + owns_model_tar = True + with tarfile.open(model_tar, "w") as tf: + tf.add(path, arcname=path.name) + if not meta_tar.exists(): + owns_md_tar = True + with tarfile.open(meta_tar, "w:gz") as tf: + tf.add(meta_path, arcname=meta_path.name) + files = [ + f"{model_tar}:{MIME_BLOB}", + f"{meta_tar}:{MIME_BLOB}+gzip", + ] + else: + manifest_path.write_text( + json.dumps( + { + "artifactType": MIME_APPLICATION_MLMODEL, + } + ) + ) + config = f"{manifest_path}:{MIME_APPLICATION_MLMODEL}" + files = [ + f"{path}:{MIME_APPLICATION_MLMODEL}", + f"{meta_path}:{MIME_APPLICATION_MLMETADATA}", + ] - config = f"{meta_path}:{MIME_APPLICATION_CONFIG}" - files = [ - f"{path}:{MIME_APPLICATION_MLMODEL}", - config, - ] try: # print(target, files, model_metadata.to_annotations_dict()) result = self._registry.push( @@ -88,6 +143,12 @@ def push( finally: if owns_meta: meta_path.unlink() + if owns_model_tar: + assert isinstance(model_tar, Path) + model_tar.unlink() + if owns_md_tar: + assert isinstance(meta_tar, Path) + meta_tar.unlink() def pull( self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None diff --git a/omlmd/listener.py b/omlmd/listener.py index 8122fde..455edab 100644 --- a/omlmd/listener.py +++ b/omlmd/listener.py @@ -10,9 +10,7 @@ class Listener(ABC): - """ - TODO: not yet settled for multi-method or current single update
method. - """ + # TODO: not yet settled for multi-method or current single update method. @abstractmethod def update(self, source: t.Any, event: Event) -> None: diff --git a/omlmd/model_metadata.py b/omlmd/model_metadata.py index 4ff929a..ad96697 100644 --- a/omlmd/model_metadata.py +++ b/omlmd/model_metadata.py @@ -17,9 +17,6 @@ class ModelMetadata: model_format_name: str | None = None model_format_version: str | None = None - def to_json(self) -> str: - return json.dumps(self.to_dict(), indent=4) - def to_dict(self) -> dict[str, t.Any]: return asdict(self) @@ -38,16 +35,6 @@ def to_annotations_dict(self) -> dict[str, str]: ) # post-fix "+json" for OCI annotation which is a str representing a json return result - @staticmethod - def from_json(json_str: str) -> "ModelMetadata": - data = json.loads(json_str) - return ModelMetadata(**data) - - @staticmethod - def from_yaml(yaml_str: str) -> "ModelMetadata": - data = yaml.safe_load(yaml_str) - return ModelMetadata(**data) - @staticmethod def from_dict(data: dict[str, t.Any]) -> "ModelMetadata": known_keys = {f.name for f in fields(ModelMetadata)} diff --git a/tests/test_e2e_model_registry.py b/tests/test_e2e_model_registry.py index 972b0ac..8629814 100644 --- a/tests/test_e2e_model_registry.py +++ b/tests/test_e2e_model_registry.py @@ -77,7 +77,9 @@ def update(self, source: Helper, event: Event) -> None: assert mv assert mv.description == "Lorem ipsum" assert mv.author == "John Doe" - assert mv.custom_properties == {"accuracy": 0.987} + assert mv.custom_properties == { + "accuracy": accuracy_value, + } ma = model_registry.get_model_artifact("mnist", v) assert ma diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 00bc193..20a3c30 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,7 @@ +import io import json import subprocess +import tarfile import tempfile import typing as t from hashlib import sha256 @@ -7,13 +9,19 @@ import pytest -from omlmd.constants import 
MIME_APPLICATION_MLMODEL +from omlmd.constants import MIME_BLOB from omlmd.helpers import Helper from omlmd.listener import Event, Listener from omlmd.model_metadata import ModelMetadata, deserialize_mdfile from omlmd.provider import OMLMDRegistry +def untar(tar: Path, out: Path): + out.write_bytes( + t.cast(io.BufferedReader, tarfile.open(tar, "r").extractfile(tar.stem)).read() + ) + + def test_call_push_using_md_from_file(mocker): helper = Helper() mocker.patch.object(helper, "push", return_value=None) @@ -100,12 +108,33 @@ def test_push_pull_chunked(tmp_path, target): omlmd.push(target, temp, **md) omlmd.pull(target, tmp_path) - assert len(list(tmp_path.iterdir())) == 3 - assert tmp_path.joinpath(temp.name).stat().st_size == base_size + files = list(tmp_path.iterdir()) + print(files) + assert len(files) == 2 + print(tmp_path) + out = tmp_path.joinpath(temp.name) + untar(out.with_suffix(".tar"), out) + assert temp.stat().st_size == base_size finally: temp.unlink() +@pytest.mark.e2e +def test_e2e_push_pull_as_artifact(tmp_path, target): + omlmd = Helper() + omlmd.push( + target, + Path(__file__).parent / ".." 
/ "README.md", + as_artifact=True, + name="mnist", + description="Lorem ipsum", + author="John Doe", + accuracy=0.987, + ) + omlmd.pull(target, tmp_path) + assert len(list(tmp_path.iterdir())) == 2 + + @pytest.mark.e2e def test_e2e_push_pull(tmp_path, target): omlmd = Helper() @@ -118,7 +147,7 @@ def test_e2e_push_pull(tmp_path, target): accuracy=0.987, ) omlmd.pull(target, tmp_path) - assert len(list(tmp_path.iterdir())) == 3 + assert len(list(tmp_path.iterdir())) == 2 @pytest.mark.e2e @@ -132,7 +161,7 @@ def test_e2e_push_pull_with_filters(tmp_path, target): author="John Doe", accuracy=0.987, ) - omlmd.pull(target, tmp_path, media_types=[MIME_APPLICATION_MLMODEL]) + omlmd.pull(target, tmp_path, media_types=[MIME_BLOB]) assert len(list(tmp_path.iterdir())) == 1 @@ -155,10 +184,11 @@ def test_e2e_push_pull_column(tmp_path, target): omlmd.push(target, temp, **md) omlmd.pull(target, tmp_path) - with open(tmp_path.joinpath(temp.name), "r") as f: - pulled = f.read() - assert pulled == content - pulled_sha = sha256(pulled.encode("utf-8")).hexdigest() - assert pulled_sha == content_sha + out = tmp_path.joinpath(temp.name) + untar(out.with_suffix(".tar"), out) + pulled = out.read_text() + assert pulled == content + pulled_sha = sha256(pulled.encode("utf-8")).hexdigest() + assert pulled_sha == content_sha finally: temp.unlink() diff --git a/tests/test_omlmd.py b/tests/test_omlmd.py index 9c800b1..7028c31 100644 --- a/tests/test_omlmd.py +++ b/tests/test_omlmd.py @@ -8,14 +8,14 @@ def test_dry_run_model_metadata_json_yaml_conversions(): metadata = ModelMetadata(name="Example Model", author="John Doe") - json_str = metadata.to_json() + json_str = json.dumps(metadata.to_dict(), indent=4) yaml_str = yaml.dump(metadata.to_dict(), default_flow_style=False) print("JSON representation:\n", json_str) print("YAML representation:\n", yaml_str) - metadata_from_json = ModelMetadata.from_json(json_str) - metadata_from_yaml = ModelMetadata.from_yaml(yaml_str) + metadata_from_json = 
ModelMetadata(**json.loads(json_str)) + metadata_from_yaml = ModelMetadata(**yaml.safe_load(yaml_str)) print("Metadata from JSON:\n", metadata_from_json) print("Metadata from YAML:\n", metadata_from_yaml)