Skip to content

Commit

Permalink
upload as image
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Oct 31, 2024
1 parent 91c65b6 commit a001104
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test:

.PHONY: test-e2e
test-e2e:
poetry run pytest --e2e -s -x -rA
poetry run pytest --e2e -s -x -rA -v

.PHONY: test-e2e-model-registry
test-e2e-model-registry:
Expand Down
10 changes: 9 additions & 1 deletion omlmd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def crawl(plain_http: bool, targets: tuple[str]):
required=True,
type=click.Path(path_type=Path, exists=True, resolve_path=True),
)
@click.option(
"--as-artifact",
is_flag=True,
help="Push as an artifact (default is as a blob)",
)
@cloup.option_group(
"Metadata options",
cloup.option(
Expand All @@ -88,6 +93,7 @@ def push(
plain_http: bool,
target: str,
path: Path,
as_artifact: bool,
metadata: Path | None,
empty_metadata: bool,
):
Expand All @@ -96,4 +102,6 @@ def push(
if empty_metadata:
logger.warning(f"Pushing to {target} with empty metadata.")
md = deserialize_mdfile(metadata) if metadata else {}
click.echo(Helper.from_default_registry(plain_http).push(target, path, **md))
click.echo(
Helper.from_default_registry(plain_http).push(target, path, as_artifact, **md)
)
5 changes: 4 additions & 1 deletion omlmd/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from oras.defaults import default_blob_media_type

# File name of the JSON metadata document that push() writes next to the model file.
FILENAME_METADATA_JSON = "model_metadata.omlmd.json"
# Media type paired with the metadata file when it is passed as the push config.
MIME_APPLICATION_CONFIG = "application/x-config"
# Media type paired with the model file when it is pushed as an artifact.
MIME_APPLICATION_MLMODEL = "application/x-mlmodel"
# Media type paired with the tar archive when the model is pushed as a blob;
# reuses the default blob media type defined by oras.
MIME_BLOB = default_blob_media_type
# OCI image config media type, paired with the metadata file as manifest config.
MIME_MANIFEST_CONFIG = "application/vnd.oci.image.config.v1+json"
48 changes: 42 additions & 6 deletions omlmd/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import logging
import os
import platform
import tarfile
import urllib.request
from collections.abc import Sequence
from dataclasses import dataclass, field
Expand All @@ -10,8 +12,9 @@

from .constants import (
FILENAME_METADATA_JSON,
MIME_APPLICATION_CONFIG,
MIME_APPLICATION_MLMODEL,
MIME_BLOB,
MIME_MANIFEST_CONFIG,
)
from .listener import Event, Listener, PushEvent
from .model_metadata import ModelMetadata
Expand All @@ -20,6 +23,19 @@
logger = logging.getLogger(__name__)


def get_arch() -> str:
match platform.machine():
case "x86_64":
return "amd64"
case "arm64":
return "arm64"
case "aarch64":
return "arm64"
case _:
msg = f"Unsupported architecture: {platform.machine()}"
raise NotImplementedError(msg)


def download_file(uri: str):
file_name = os.path.basename(uri)
urllib.request.urlretrieve(uri, file_name)
Expand All @@ -41,12 +57,16 @@ def push(
self,
target: str,
path: Path | str,
as_artifact: bool = False,
**kwargs,
):
owns_meta = True
if isinstance(path, str):
path = Path(path)

manifest_cfg = {"architecture": get_arch(), "os": "linux"}

kwargs.update(manifest_cfg)
meta_path = path.parent / FILENAME_METADATA_JSON
if not kwargs and meta_path.exists():
owns_meta = False
Expand All @@ -67,11 +87,24 @@ def push(
model_metadata = ModelMetadata.from_dict(kwargs)
meta_path.write_text(model_metadata.to_json())

config = f"{meta_path}:{MIME_APPLICATION_CONFIG}"
files = [
f"{path}:{MIME_APPLICATION_MLMODEL}",
config,
]
config = f"{meta_path}:{MIME_MANIFEST_CONFIG}"

owns_tar = False
tar = None
if not as_artifact:
tar = path.parent / f"{path.stem}.tar"
if not tar.exists():
owns_tar = True
with tarfile.open(tar, "w") as tf:
tf.add(path, arcname=path.name)
files = [
f"{tar}:{MIME_BLOB}",
]
else:
files = [
f"{path}:{MIME_APPLICATION_MLMODEL}",
]

try:
# print(target, files, model_metadata.to_annotations_dict())
result = self._registry.push(
Expand All @@ -88,6 +121,9 @@ def push(
finally:
if owns_meta:
meta_path.unlink()
if owns_tar:
assert isinstance(tar, Path)
tar.unlink()

def pull(
self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None
Expand Down
4 changes: 1 addition & 3 deletions omlmd/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@


class Listener(ABC):
"""
TODO: not yet settled for multi-method or current single update method.
"""
# TODO: decide between a multi-method interface and the current single `update` method.

@abstractmethod
def update(self, source: t.Any, event: Event) -> None:
Expand Down
8 changes: 6 additions & 2 deletions tests/test_e2e_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from model_registry import ModelRegistry
from model_registry.types import RegisteredModel

from omlmd.helpers import Helper, download_file
from omlmd.helpers import Helper, download_file, get_arch
from omlmd.listener import Event, Listener, PushEvent


Expand Down Expand Up @@ -77,7 +77,11 @@ def update(self, source: Helper, event: Event) -> None:
assert mv
assert mv.description == "Lorem ipsum"
assert mv.author == "John Doe"
assert mv.custom_properties == {"accuracy": 0.987}
assert mv.custom_properties == {
"accuracy": accuracy_value,
"os": "linux",
"architecture": get_arch(),
}

ma = model_registry.get_model_artifact("mnist", v)
assert ma
Expand Down
50 changes: 40 additions & 10 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import io
import json
import subprocess
import tarfile
import tempfile
import typing as t
from hashlib import sha256
from pathlib import Path

import pytest

from omlmd.constants import MIME_APPLICATION_MLMODEL
from omlmd.constants import MIME_BLOB
from omlmd.helpers import Helper
from omlmd.listener import Event, Listener
from omlmd.model_metadata import ModelMetadata, deserialize_mdfile
from omlmd.provider import OMLMDRegistry


def untar(tar: Path, out: Path):
    """Write the archive member named after the archive's stem to *out*.

    Args:
        tar: Path to the tar archive; the member looked up is ``tar.stem``.
        out: Destination file that receives the member's bytes.

    Raises:
        KeyError: If the archive has no member named ``tar.stem``.
        ValueError: If that member exists but is not a regular file.
    """
    # Context managers close the archive and the member handle even when
    # reading or writing fails, so no file descriptors are leaked.
    with tarfile.open(tar, "r") as archive:
        member = archive.extractfile(tar.stem)
        if member is None:
            # extractfile() returns None for non-file members (e.g. directories).
            raise ValueError(f"{tar.stem!r} in {tar} is not a regular file")
        with member:
            out.write_bytes(member.read())


def test_call_push_using_md_from_file(mocker):
helper = Helper()
mocker.patch.object(helper, "push", return_value=None)
Expand Down Expand Up @@ -100,12 +108,33 @@ def test_push_pull_chunked(tmp_path, target):

omlmd.push(target, temp, **md)
omlmd.pull(target, tmp_path)
assert len(list(tmp_path.iterdir())) == 3
assert tmp_path.joinpath(temp.name).stat().st_size == base_size
files = list(tmp_path.iterdir())
print(files)
assert len(files) == 1
print(tmp_path)
out = tmp_path.joinpath(temp.name)
untar(out.with_suffix(".tar"), out)
assert temp.stat().st_size == base_size
finally:
temp.unlink()


@pytest.mark.e2e
def test_e2e_push_pull_as_artifact(tmp_path, target):
    """Push the README as an artifact, pull it back, and expect one entry."""
    helper = Helper()
    readme = Path(__file__).parent / ".." / "README.md"
    metadata = {
        "name": "mnist",
        "description": "Lorem ipsum",
        "author": "John Doe",
        "accuracy": 0.987,
    }

    helper.push(target, readme, as_artifact=True, **metadata)
    helper.pull(target, tmp_path)

    # Exactly one entry is expected in the output directory after the pull.
    pulled_entries = list(tmp_path.iterdir())
    assert len(pulled_entries) == 1


@pytest.mark.e2e
def test_e2e_push_pull(tmp_path, target):
omlmd = Helper()
Expand All @@ -118,7 +147,7 @@ def test_e2e_push_pull(tmp_path, target):
accuracy=0.987,
)
omlmd.pull(target, tmp_path)
assert len(list(tmp_path.iterdir())) == 3
assert len(list(tmp_path.iterdir())) == 1


@pytest.mark.e2e
Expand All @@ -132,7 +161,7 @@ def test_e2e_push_pull_with_filters(tmp_path, target):
author="John Doe",
accuracy=0.987,
)
omlmd.pull(target, tmp_path, media_types=[MIME_APPLICATION_MLMODEL])
omlmd.pull(target, tmp_path, media_types=[MIME_BLOB])
assert len(list(tmp_path.iterdir())) == 1


Expand All @@ -155,10 +184,11 @@ def test_e2e_push_pull_column(tmp_path, target):

omlmd.push(target, temp, **md)
omlmd.pull(target, tmp_path)
with open(tmp_path.joinpath(temp.name), "r") as f:
pulled = f.read()
assert pulled == content
pulled_sha = sha256(pulled.encode("utf-8")).hexdigest()
assert pulled_sha == content_sha
out = tmp_path.joinpath(temp.name)
untar(out.with_suffix(".tar"), out)
pulled = out.read_text()
assert pulled == content
pulled_sha = sha256(pulled.encode("utf-8")).hexdigest()
assert pulled_sha == content_sha
finally:
temp.unlink()

0 comments on commit a001104

Please sign in to comment.