diff --git a/omlmd/cli.py b/omlmd/cli.py index 0f7a1fe..8d66eca 100644 --- a/omlmd/cli.py +++ b/omlmd/cli.py @@ -1,15 +1,15 @@ """Command line interface for OMLMD.""" + from __future__ import annotations + +import logging from pathlib import Path import click import cloup -import logging - -from omlmd.helpers import Helper -from omlmd.model_metadata import deserialize_mdfile -from omlmd.provider import OMLMDRegistry +from .helpers import Helper +from .model_metadata import deserialize_mdfile logger = logging.getLogger(__name__) @@ -23,10 +23,6 @@ ) -def get_OMLMDRegistry(plain_http: bool) -> OMLMDRegistry: - return OMLMDRegistry(insecure=plain_http) - - @cloup.group() def cli(): logging.basicConfig(level=logging.INFO) @@ -45,7 +41,7 @@ def cli(): @click.option("--media-types", "-m", multiple=True, default=[]) def pull(plain_http: bool, target: str, output: Path, media_types: tuple[str]): """Pulls an OCI Artifact containing ML model and metadata, filtering if necessary.""" - Helper(get_OMLMDRegistry(plain_http)).pull(target, output, media_types) + Helper.from_default_registry(plain_http).pull(target, output, media_types) @cli.group() @@ -58,7 +54,7 @@ def get(): @click.argument("target", required=True) def config(plain_http: bool, target: str): """Outputs configuration of the given OCI Artifact for ML model and metadata.""" - click.echo(Helper(get_OMLMDRegistry(plain_http)).get_config(target)) + click.echo(Helper.from_default_registry(plain_http).get_config(target)) @cli.command() @@ -66,7 +62,7 @@ def config(plain_http: bool, target: str): @click.argument("targets", required=True, nargs=-1) def crawl(plain_http: bool, targets: tuple[str]): """Crawls configuration for the given list of OCI Artifact for ML model and metadata.""" - click.echo(Helper(get_OMLMDRegistry(plain_http)).crawl(targets)) + click.echo(Helper.from_default_registry(plain_http).crawl(targets)) @cli.command() @@ -83,15 +79,21 @@ def crawl(plain_http: bool, targets: tuple[str]): "-m", "--metadata", type=click.Path(path_type=Path, exists=True, resolve_path=True), - help="Metadata file in JSON or YAML format" + help="Metadata file in JSON or YAML format", ), - cloup.option('--empty-metadata', help='Push with empty metadata', is_flag=True), + cloup.option("--empty-metadata", help="Push with empty metadata", is_flag=True), constraint=cloup.constraints.require_one, ) -def push(plain_http: bool, target: str, path: Path, metadata: Path | None, empty_metadata: bool): +def push( + plain_http: bool, + target: str, + path: Path, + metadata: Path | None, + empty_metadata: bool, +): """Pushes an OCI Artifact containing ML model and metadata, supplying metadata from file as necessary""" if empty_metadata: logger.warning(f"Pushing to {target} with empty metadata.") md = deserialize_mdfile(metadata) if metadata else {} - click.echo(Helper(get_OMLMDRegistry(plain_http)).push(target, path, **md)) + click.echo(Helper.from_default_registry(plain_http).push(target, path, **md)) diff --git a/omlmd/helpers.py b/omlmd/helpers.py index 4e38e36..4a254f3 100644 --- a/omlmd/helpers.py +++ b/omlmd/helpers.py @@ -4,19 +4,18 @@ import os import urllib.request from collections.abc import Sequence -from dataclasses import fields +from dataclasses import dataclass, field, fields from pathlib import Path -from omlmd.constants import ( +from .constants import ( FILENAME_METADATA_JSON, FILENAME_METADATA_YAML, MIME_APPLICATION_CONFIG, MIME_APPLICATION_MLMODEL, ) -from omlmd.listener import Event, Listener, PushEvent -from omlmd.model_metadata import ModelMetadata -from omlmd.provider import OMLMDRegistry - +from .listener import Event, Listener, PushEvent +from .model_metadata import ModelMetadata +from .provider import OMLMDRegistry logger = logging.getLogger(__name__) @@ -27,20 +26,16 @@ def download_file(uri: str): return file_name +@dataclass class Helper: - _listeners: list[Listener] = [] - - def __init__(self, registry: OMLMDRegistry | None = None): - if registry is None: - self._registry = OMLMDRegistry( - insecure=True - ) # TODO: this is a bit limiting when used from CLI, to be refactored - else: - self._registry = registry + _registry: OMLMDRegistry = field( + default_factory=lambda: OMLMDRegistry(insecure=True) + ) + _listeners: list[Listener] = field(default_factory=list) - @property - def registry(self): - return self._registry + @classmethod + def from_default_registry(cls, insecure: bool): + return cls(OMLMDRegistry(insecure=insecure)) def push( self, @@ -102,7 +97,9 @@ def push( manifest_config=manifest_cfg, do_chunked=True, ) - self.notify_listeners(PushEvent(target, model_metadata)) + self.notify_listeners( + PushEvent.from_response(result, target, model_metadata) + ) return result finally: if owns_meta_files: diff --git a/omlmd/listener.py b/omlmd/listener.py index 7c5795a..8122fde 100644 --- a/omlmd/listener.py +++ b/omlmd/listener.py @@ -1,9 +1,12 @@ from __future__ import annotations +import typing as t from abc import ABC, abstractmethod -from typing import Any +from dataclasses import dataclass -from omlmd.model_metadata import ModelMetadata +import requests + +from .model_metadata import ModelMetadata class Listener(ABC): @@ -12,19 +15,25 @@ class Listener(ABC): """ @abstractmethod - def update(self, source: Any, event: Event) -> None: + def update(self, source: t.Any, event: Event) -> None: """ Receive update event. """ pass -class Event: +class Event(ABC): pass +@dataclass class PushEvent(Event): - def __init__(self, target: str, metadata: ModelMetadata): - # TODO: cannot just receive yet the push sha, waiting for: https://github.com/oras-project/oras-py/pull/146 in a release. - self.target = target - self.metadata = metadata + digest: str + target: str + metadata: ModelMetadata + + @classmethod + def from_response( + cls, response: requests.Response, target: str, metadata: ModelMetadata + ) -> "PushEvent": + return cls(response.headers["Docker-Content-Digest"], target, metadata) diff --git a/omlmd/provider.py b/omlmd/provider.py index f521b29..1ec6a81 100644 --- a/omlmd/provider.py +++ b/omlmd/provider.py @@ -4,18 +4,15 @@ import os import tempfile -import oras.defaults -import oras.oci -import oras.provider -import oras.schemas -import oras.utils +from oras import provider from oras.decorator import ensure_container -from oras.provider import container_type +from oras.defaults import annotation_title as ANNOTATION_TITLE +from oras.utils import sanitize_path logger = logging.getLogger(__name__) -class OMLMDRegistry(oras.provider.Registry): +class OMLMDRegistry(provider.Registry): @ensure_container def download_layers(self, package, download_dir, media_types): """ @@ -33,8 +30,8 @@ def download_layers(self, package, download_dir, media_types): or len(media_types) == 0 or layer["mediaType"] in media_types ): - artifact = layer["annotations"]["org.opencontainers.image.title"] - outfile = oras.utils.sanitize_path( + artifact = layer["annotations"][ANNOTATION_TITLE] + outfile = sanitize_path( download_dir, os.path.join(download_dir, artifact) ) path = self.download_blob(package, layer["digest"], outfile) @@ -74,26 +71,3 @@ def get_config(self, package) -> str: os.rmdir(temp_dir) # print("Temporary directory and its contents have been removed.") raise RuntimeError("Unable to locate config layer") - - @ensure_container - def get_manifest_response( - self, - container: container_type, - allowed_media_type: list | None = None, - refresh_headers: bool = True, - ) -> dict: - """ - like get_manifest but return response, - temporary until https://github.com/oras-project/oras-py/pull/146 in a release. - """ - if not allowed_media_type: - allowed_media_type = [oras.defaults.default_manifest_media_type] - headers = {"Accept": ";".join(allowed_media_type)} - - if not refresh_headers: - headers.update(self.headers) - - get_manifest = f"{self.prefix}://{container.manifest_url()}" # type: ignore - response = self.do_request(get_manifest, "GET", headers=headers) - self._check_200_response(response) - return response diff --git a/poetry.lock b/poetry.lock index c9fbf99..31ce5b5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1130,13 +1130,13 @@ files = [ [[package]] name = "model-registry" -version = "0.2.4a1" +version = "0.2.9" description = "Client for Kubeflow Model Registry" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "model_registry-0.2.4a1-py3-none-any.whl", hash = "sha256:727300fbd2eb1ec54230507b7f79666a7ddfad50d62912d9832e1c9ed7f5c015"}, - {file = "model_registry-0.2.4a1.tar.gz", hash = "sha256:7fcab7cef0006462bbeb161a526d8057c66a6c9a2898554ef92088d0c1af3d91"}, + {file = "model_registry-0.2.9-py3-none-any.whl", hash = "sha256:fa32689fdc0afa3499d241d990928299efc2e177830349662ab7c76ac7d81ebb"}, + {file = "model_registry-0.2.9.tar.gz", hash = "sha256:def1841cb9b17361785625cc956a43b57c738cb013289334f1284abfcbf86a68"}, ] [package.dependencies] @@ -1149,7 +1149,7 @@ python-dateutil = ">=2.9.0.post0,<3.0.0" typing-extensions = ">=4.8,<5.0" [package.extras] -hf = ["huggingface-hub (>=0.20.1,<0.25.0)"] +hf = ["huggingface-hub (>=0.20.1,<0.26.0)"] [[package]] name = "multidict" @@ -1873,6 +1873,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2396,6 +2397,20 @@ files = [ {file = "types_PyYAML-6.0.12.20240808-py3-none-any.whl", hash = "sha256:deda34c5c655265fc517b546c902aa6eed2ef8d3e921e4765fe606fe2afe8d35"}, ] +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, + {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "typing-extensions" version = "4.12.2" @@ -2567,4 +2582,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "be66211b505ebaf3654f303b14158bfb27065852edc71d6e3592a58f65280138" +content-hash = "09148e1d60c1b8c92d31467b68e406f6dbf2566f3d0b0cff7f7f01067276d726" diff --git a/pyproject.toml b/pyproject.toml index 6255327..de386f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,10 +27,11 @@ scikit-learn = "^1.5.0" ipykernel = "^6.29.4" nbconvert = "^7.16.4" markdown-it-py = "^3.0.0" -model-registry = "^0.2.4a1" +model-registry = ">=0.2.9,<0.3.0" ruff = "^0.6.1" mypy = "^1.11.1" types-pyyaml = "^6.0.12.20240808" +types-requests = "^2.32.0.20241016" [tool.poetry.scripts] omlmd = "omlmd.cli:cli" @@ -50,7 +51,9 @@ target-version = "py39" respect-gitignore = true [tool.ruff.lint.per-file-ignores] -"*.ipynb" = ["E402"] # exclude https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/#notebook-behavior from linting, especially for demos. +"*.ipynb" = [ + "E402", +] # exclude https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/#notebook-behavior from linting, especially for demos. [tool.mypy] python_version = "3.9" diff --git a/tests/test_e2e_model_registry.py b/tests/test_e2e_model_registry.py index b060f41..972b0ac 100644 --- a/tests/test_e2e_model_registry.py +++ b/tests/test_e2e_model_registry.py @@ -12,6 +12,9 @@ def from_oci_to_kfmr( model_registry: ModelRegistry, push_event: PushEvent, sha: str ) -> RegisteredModel: + assert push_event.metadata.name + assert push_event.metadata.model_format_name + assert push_event.metadata.model_format_version rm = model_registry.register_model( name=push_event.metadata.name, uri=f"oci-artifact://{push_event.target}", @@ -35,15 +38,12 @@ def test_e2e_model_registry_scenario1(tmp_path, target): ) class ListenerForModelRegistry(Listener): - sha = None - rm = None + sha: str + rm: RegisteredModel def update(self, source: Helper, event: Event) -> None: if isinstance(event, PushEvent): - self.sha = source.registry.get_manifest_response(event.target).headers[ - "Docker-Content-Digest" - ] - print(self.sha) + self.sha = event.digest self.rm = from_oci_to_kfmr(model_registry, event, self.sha) listener = ListenerForModelRegistry() @@ -61,21 +61,26 @@ def update(self, source: Helper, event: Event) -> None: name="mnist", description="Lorem ipsum", author="John Doe", + model_format_name="onnx", + model_format_version="1", accuracy=accuracy_value, ) v = quote(listener.sha) rm = model_registry.get_registered_model("mnist") + assert rm assert rm.id == listener.rm.id assert rm.name == "mnist" mv = model_registry.get_model_version("mnist", v) + assert mv assert mv.description == "Lorem ipsum" assert mv.author == "John Doe" assert mv.custom_properties == {"accuracy": 0.987} ma = model_registry.get_model_artifact("mnist", v) + assert ma assert ma.uri == f"oci-artifact://{target}" # curl http://localhost:5001/v2/testorgns/ml-model-artifact/manifests/v1 -H "Accept: application/vnd.oci.image.manifest.v1+json" --verbose @@ -112,7 +117,9 @@ def test_e2e_model_registry_scenario2(tmp_path, target): _ = model_registry.get_registered_model(lookup_name) model_version = model_registry.get_model_version(lookup_name, lookup_version) + assert model_version model_artifact = model_registry.get_model_artifact(lookup_name, lookup_version) + assert model_artifact file_from_mr = download_file(model_artifact.uri) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index e5efc83..00bc193 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,8 +1,9 @@ import json import subprocess import tempfile -from pathlib import Path +import typing as t from hashlib import sha256 +from pathlib import Path import pytest @@ -41,13 +42,19 @@ def test_call_push_using_md_from_file(mocker): def test_push_event(mocker): registry = OMLMDRegistry() - mocker.patch.object(registry, "push", return_value=None) + m = mocker.MagicMock() + m.headers = {"Docker-Content-Digest": "sha256:123"} + mocker.patch.object( + registry, + "push", + return_value=m, + ) omlmd = Helper(registry) events = [] class MyListener(Listener): - def update(self, _, event: Event) -> None: + def update(self, source: t.Any, event: Event) -> None: events.append(event) omlmd.add_listener(MyListener()) @@ -141,7 +148,7 @@ def test_e2e_push_pull_column(tmp_path, target): content = "Hello, World!" content_sha = sha256(content.encode("utf-8")).hexdigest() here = Path.cwd() - temp = here / ("sha256:"+content_sha) + temp = here / ("sha256:" + content_sha) try: with open(temp, "w") as f: f.write(content) @@ -155,4 +162,3 @@ def test_e2e_push_pull_column(tmp_path, target): assert pulled_sha == content_sha finally: temp.unlink() -