diff --git a/omlmd/cli.py b/omlmd/cli.py index b546c8b..5c9c911 100644 --- a/omlmd/cli.py +++ b/omlmd/cli.py @@ -1,11 +1,20 @@ -# Using this to scope CLI targets +"""Command line interface for OMLMD.""" + +from pathlib import Path + import click + from omlmd.helpers import Helper -from omlmd.provider import OMLMDRegistry from omlmd.model_metadata import deserialize_mdfile +from omlmd.provider import OMLMDRegistry - -plain_http = click.option('--plain-http', help="allow insecure connections to registry without SSL check", is_flag=True, default=False, show_default=True) +plain_http = click.option( + "--plain-http", + help="allow insecure connections to registry without SSL check", + is_flag=True, + default=False, + show_default=True, +) def get_OMLMDRegistry(plain_http: bool) -> OMLMDRegistry: @@ -16,47 +25,62 @@ def get_OMLMDRegistry(plain_http: bool) -> OMLMDRegistry: def cli(): pass -@click.command() + +@cli.command() @plain_http -@click.argument('target', required=True) -@click.option('-o', '--output', default='.', show_default=True) -@click.option('--media-types', '-m', multiple=True, default=[]) -def pull(plain_http, target, output, media_types): +@click.argument("target", required=True) +@click.option( + "-o", + "--output", + default=Path.cwd(), + show_default=True, + type=click.Path(path_type=Path, resolve_path=True), +) +@click.option("--media-types", "-m", multiple=True, default=[]) +def pull(plain_http: bool, target: str, output: Path, media_types: tuple[str]): """Pulls an OCI Artifact containing ML model and metadata, filtering if necessary.""" Helper(get_OMLMDRegistry(plain_http)).pull(target, output, media_types) -@click.group() + +@cli.group() def get(): pass -@click.command() + +@get.command() @plain_http -@click.argument('target', required=True) -def config(plain_http, target): +@click.argument("target", required=True) +def config(plain_http: bool, target: str): """Outputs configuration of the given OCI Artifact for ML model and metadata.""" click.echo(Helper(get_OMLMDRegistry(plain_http)).get_config(target)) -@click.command() + +@cli.command() @plain_http -@click.argument('targets', required=True, nargs=-1) -def crawl(plain_http, targets): +@click.argument("targets", required=True, nargs=-1) +def crawl(plain_http: bool, targets: tuple[str]): """Crawls configuration for the given list of OCI Artifact for ML model and metadata.""" click.echo(Helper(get_OMLMDRegistry(plain_http)).crawl(targets)) - -@click.command() + + +@cli.command() @plain_http -@click.argument('target', required=True) -@click.argument('path', required=True, type=click.Path()) -@click.option('-m', '--metadata', required=True, type=click.Path()) -def push(plain_http, target, path, metadata): +@click.argument("target", required=True) +@click.argument( + "path", + required=True, + type=click.Path(path_type=Path, exists=True, resolve_path=True), +) +@click.option( + "-m", + "--metadata", + required=True, + type=click.Path(path_type=Path, exists=True, resolve_path=True), +) +def push(plain_http: bool, target: str, path: Path, metadata: Path): """Pushes an OCI Artifact containing ML model and metadata, supplying metadata from file as necessary""" import logging + logging.basicConfig(level=logging.DEBUG) md = deserialize_mdfile(metadata) click.echo(Helper(get_OMLMDRegistry(plain_http)).push(target, path, **md)) - -cli.add_command(pull) -cli.add_command(get) -get.add_command(config) -cli.add_command(crawl) -cli.add_command(push) diff --git a/omlmd/helpers.py b/omlmd/helpers.py index bb86111..1ffa893 100644 --- a/omlmd/helpers.py +++ b/omlmd/helpers.py @@ -1,33 +1,30 @@ +from __future__ import annotations + +import os +import urllib.request +from collections.abc import Sequence from dataclasses import fields -from typing import Optional, List +from pathlib import Path + from omlmd.listener import Event, Listener, PushEvent from omlmd.model_metadata import ModelMetadata from omlmd.provider import OMLMDRegistry -import os -import urllib.request - -def write_content_to_file(filename: str, content_fn): - try: - with open(filename, 'x') as f: - content = content_fn() - f.write(content) - except FileExistsError: - raise RuntimeError(f"File '{filename}' already exists. Aborting TODO: demonstrator.") -def download_file(uri): +def download_file(uri: str): file_name = os.path.basename(uri) urllib.request.urlretrieve(uri, file_name) return file_name class Helper: + _listeners: list[Listener] = [] - _listeners: List[Listener] = [] - - def __init__(self, registry: Optional[OMLMDRegistry] = None): + def __init__(self, registry: OMLMDRegistry | None = None): if registry is None: - self._registry = OMLMDRegistry(insecure=True) # TODO: this is a bit limiting when used from CLI, to be refactored + self._registry = OMLMDRegistry( + insecure=True + ) # TODO: this is a bit limiting when used from CLI, to be refactored else: self._registry = registry @@ -38,30 +35,45 @@ def registry(self): def push( self, target: str, - path: str, - name: Optional[str] = None, - description: Optional[str] = None, - author: Optional[str] = None, - model_format_name: Optional[str] = None, - model_format_version: Optional[str] = None, - **kwargs + path: Path | str, + name: str | None = None, + description: str | None = None, + author: str | None = None, + model_format_name: str | None = None, + model_format_version: str | None = None, + **kwargs, ): - dataclass_fields = {f.name for f in fields(ModelMetadata)} # avoid anything specified in kwargs which would collide - custom_properties = {k: v for k, v in kwargs.items() if k not in dataclass_fields} + dataclass_fields = { + f.name for f in fields(ModelMetadata) + } # avoid anything specified in kwargs which would collide + custom_properties = { + k: v for k, v in kwargs.items() if k not in dataclass_fields + } model_metadata = ModelMetadata( name=name, description=description, author=author, customProperties=custom_properties, model_format_name=model_format_name, - model_format_version=model_format_version + model_format_version=model_format_version, ) - write_content_to_file("model_metadata.omlmd.json", lambda: model_metadata.to_json()) - write_content_to_file("model_metadata.omlmd.yaml", lambda: model_metadata.to_yaml()) + if isinstance(path, str): + path = Path(path) + + json_meta = path.parent / "model_metadata.omlmd.json" + yaml_meta = path.parent / "model_metadata.omlmd.yaml" + if (p := json_meta).exists() or (p := yaml_meta).exists(): + raise RuntimeError( + f"File '{p}' already exists. Aborting TODO: demonstrator." + ) + json_meta.write_text(model_metadata.to_json()) + yaml_meta.write_text(model_metadata.to_yaml()) + + manifest_cfg = f"{json_meta}:application/x-config" files = [ f"{path}:application/x-mlmodel", - "model_metadata.omlmd.json:application/x-config", - "model_metadata.omlmd.yaml:application/x-config", + manifest_cfg, + f"{yaml_meta}:application/x-config", ] try: # print(target, files, model_metadata.to_annotations_dict()) @@ -69,40 +81,27 @@ def push( target=target, files=files, manifest_annotations=model_metadata.to_annotations_dict(), - manifest_config="model_metadata.omlmd.json:application/x-config" + manifest_config=manifest_cfg, ) self.notify_listeners(PushEvent(target, model_metadata)) return result finally: - os.remove("model_metadata.omlmd.json") - os.remove("model_metadata.omlmd.yaml") - + json_meta.unlink() + yaml_meta.unlink() def pull( - self, - target: str, - outdir: str, - media_types: Optional[List[str]] = None + self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None ): self._registry.download_layers(target, outdir, media_types) + def get_config(self, target: str) -> str: + return f'{{"reference":"{target}", "config": {self._registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec) - def get_config( - self, - target: str - ) -> str: - return f'{{"reference":"{target}", "config": {self._registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec) - - - def crawl( - self, - targets: List[str] - ) -> str: + def crawl(self, targets: Sequence[str]) -> str: configs = map(self.get_config, targets) joined = "[" + ", ".join(configs) + "]" return joined - def add_listener(self, listener: Listener) -> None: self._listeners.append(listener) diff --git a/omlmd/listener.py b/omlmd/listener.py index b7d4c63..7c5795a 100644 --- a/omlmd/listener.py +++ b/omlmd/listener.py @@ -1,12 +1,16 @@ from __future__ import annotations + from abc import ABC, abstractmethod from typing import Any + from omlmd.model_metadata import ModelMetadata + class Listener(ABC): """ TODO: not yet settled for multi-method or current single update method. """ + @abstractmethod def update(self, source: Any, event: Event) -> None: """ @@ -24,4 +28,3 @@ def __init__(self, target: str, metadata: ModelMetadata): # TODO: cannot just receive yet the push sha, waiting for: https://github.com/oras-project/oras-py/pull/146 in a release. self.target = target self.metadata = metadata - diff --git a/omlmd/model_metadata.py b/omlmd/model_metadata.py index 3b1a4bd..dc26239 100644 --- a/omlmd/model_metadata.py +++ b/omlmd/model_metadata.py @@ -1,25 +1,28 @@ -from dataclasses import dataclass, field, asdict, fields -from typing import Optional, Dict, Any +from __future__ import annotations + import json +from dataclasses import asdict, dataclass, field, fields +from typing import Any + import yaml + @dataclass class ModelMetadata: - name: Optional[str] = None - description: Optional[str] = None - author: Optional[str] = None - customProperties: Optional[Dict[str, Any]] = field(default_factory=dict) - uri: Optional[str] = None - model_format_name: Optional[str] = None - model_format_version: Optional[str] = None + name: str | None = None + description: str | None = None + author: str | None = None + customProperties: dict[str, Any] | None = field(default_factory=dict) + uri: str | None = None + model_format_name: str | None = None + model_format_version: str | None = None def to_json(self) -> str: - return json.dumps(asdict(self), indent=4) - + return json.dumps(self.to_dict(), indent=4) + def to_dict(self) -> dict[str, Any]: - as_json = self.to_json() - return json.loads(as_json) - + return asdict(self) + def to_annotations_dict(self) -> dict[str, str]: as_dict = self.to_dict() result = {} @@ -30,45 +33,49 @@ def to_annotations_dict(self) -> dict[str, str]: elif v is None: continue else: - result[f"{k}+json"] = json.dumps(v) # post-fix "+json" for OCI annotation which is a str representing a json + result[f"{k}+json"] = json.dumps( + v + ) # post-fix "+json" for OCI annotation which is a str representing a json return result @staticmethod - def from_json(json_str: str) -> 'ModelMetadata': + def from_json(json_str: str) -> "ModelMetadata": data = json.loads(json_str) return ModelMetadata(**data) def to_yaml(self) -> str: - return yaml.dump(asdict(self), default_flow_style=False) + return yaml.dump(self.to_dict(), default_flow_style=False) @staticmethod - def from_yaml(yaml_str: str) -> 'ModelMetadata': + def from_yaml(yaml_str: str) -> "ModelMetadata": data = yaml.safe_load(yaml_str) return ModelMetadata(**data) - + @staticmethod - def from_dict(data: Dict[str, Any]) -> 'ModelMetadata': + def from_dict(data: dict[str, Any]) -> "ModelMetadata": known_keys = {f.name for f in fields(ModelMetadata)} known_properties = {key: data.get(key) for key in known_keys if key in data} - custom_properties = {key: value for key, value in data.items() if key not in known_keys} - - return ModelMetadata( - **known_properties, - customProperties=custom_properties - ) + custom_properties = { + key: value for key, value in data.items() if key not in known_keys + } + + return ModelMetadata(**known_properties, customProperties=custom_properties) def deserialize_mdfile(file): - with open(file, 'r') as file: + with open(file, "r") as file: content = file.read() - try: - return json.loads(content) - except json.JSONDecodeError: - pass - - try: - return yaml.safe_load(content) - except yaml.YAMLError: - pass - - raise ValueError(f"The file at {file} is neither a valid JSON nor a valid YAML file.") + + try: + return json.loads(content) + except json.JSONDecodeError: + pass + + try: + return yaml.safe_load(content) + except yaml.YAMLError: + pass + + raise ValueError( + f"The file at {file} is neither a valid JSON nor a valid YAML file." + ) diff --git a/omlmd/provider.py b/omlmd/provider.py index 4e696d0..f521b29 100644 --- a/omlmd/provider.py +++ b/omlmd/provider.py @@ -1,21 +1,21 @@ +from __future__ import annotations + +import logging import os +import tempfile import oras.defaults import oras.oci import oras.provider +import oras.schemas import oras.utils from oras.decorator import ensure_container from oras.provider import container_type -import oras.schemas -import logging -import tempfile -from typing import Optional logger = logging.getLogger(__name__) class OMLMDRegistry(oras.provider.Registry): - @ensure_container def download_layers(self, package, download_dir, media_types): """ @@ -27,15 +27,21 @@ def download_layers(self, package, download_dir, media_types): paths = [] - for layer in manifest.get('layers', []): - if media_types is None or len(media_types) == 0 or layer['mediaType'] in media_types: - artifact = layer['annotations']['org.opencontainers.image.title'] - outfile = oras.utils.sanitize_path(download_dir, os.path.join(download_dir, artifact)) - path = self.download_blob(package, layer['digest'], outfile) + for layer in manifest.get("layers", []): + if ( + media_types is None + or len(media_types) == 0 + or layer["mediaType"] in media_types + ): + artifact = layer["annotations"]["org.opencontainers.image.title"] + outfile = oras.utils.sanitize_path( + download_dir, os.path.join(download_dir, artifact) + ) + path = self.download_blob(package, layer["digest"], outfile) paths.append(path) return paths - + @ensure_container def get_config(self, package) -> str: """ @@ -44,16 +50,18 @@ def get_config(self, package) -> str: # If you intend to call this function again, you might cache this response # for the package of interest. manifest = self.get_manifest(package) - - manifest_config = manifest.get('config', {}) - for layer in manifest.get('layers', []): + manifest_config = manifest.get("config", {}) + + for layer in manifest.get("layers", []): if layer["digest"] == manifest_config["digest"]: temp_dir = tempfile.mkdtemp() try: - with tempfile.NamedTemporaryFile(dir=temp_dir, mode='w', delete=False) as temp_file: - self.download_blob(package, layer['digest'], temp_file.name) - with open(temp_file.name, 'r') as temp_file_read: + with tempfile.NamedTemporaryFile( + dir=temp_dir, delete=False + ) as temp_file: + self.download_blob(package, layer["digest"], temp_file.name) + with open(temp_file.name, "r") as temp_file_read: file_content = temp_file_read.read() return file_content finally: @@ -66,12 +74,12 @@ def get_config(self, package) -> str: os.rmdir(temp_dir) # print("Temporary directory and its contents have been removed.") raise RuntimeError("Unable to locate config layer") - + @ensure_container def get_manifest_response( self, container: container_type, - allowed_media_type: Optional[list] = None, + allowed_media_type: list | None = None, refresh_headers: bool = True, ) -> dict: """ @@ -88,4 +96,4 @@ def get_manifest_response( get_manifest = f"{self.prefix}://{container.manifest_url()}" # type: ignore response = self.do_request(get_manifest, "GET", headers=headers) self._check_200_response(response) - return response \ No newline at end of file + return response diff --git a/pyproject.toml b/pyproject.toml index 3f4d2bb..0f7502d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] markers = [ "e2e: end-to-end testing with localhost:5001", - "e2e_model_registry: end-to-end testing with localhost:5001 and Kubeflow Model Registry" + "e2e_model_registry: end-to-end testing with localhost:5001 and Kubeflow Model Registry", ] [tool.ruff] diff --git a/tests/conftest.py b/tests/conftest.py index 7075826..1fa05d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,15 @@ def pytest_collection_modifyitems(config, items): for item in items: - skip_e2e = pytest.mark.skip(reason="this is an end-to-end test, requires explicit opt-in --e2e option to run.") - skip_e2e_model_registry = pytest.mark.skip(reason="this is an end-to-end test, requires explicit opt-in --e2e-model-registry option to run.") - skip_not_e2e = pytest.mark.skip(reason="skipping non-e2e tests; opt-out of --e2e -like options to run.") + skip_e2e = pytest.mark.skip( + reason="this is an end-to-end test, requires explicit opt-in --e2e option to run." + ) + skip_e2e_model_registry = pytest.mark.skip( + reason="this is an end-to-end test, requires explicit opt-in --e2e-model-registry option to run." + ) + skip_not_e2e = pytest.mark.skip( + reason="skipping non-e2e tests; opt-out of --e2e -like options to run." + ) if "e2e" in item.keywords: if not config.getoption("--e2e"): item.add_marker(skip_e2e) @@ -14,21 +20,26 @@ def pytest_collection_modifyitems(config, items): if not config.getoption("--e2e-model-registry"): item.add_marker(skip_e2e_model_registry) continue - + if config.getoption("--e2e") or config.getoption("--e2e-model-registry"): item.add_marker(skip_not_e2e) def pytest_addoption(parser): parser.addoption( - "--e2e", action="store_true", default=False, help="opt-in to run tests marked with e2e" + "--e2e", + action="store_true", + default=False, + help="opt-in to run tests marked with e2e", ) parser.addoption( - "--e2e-model-registry", action="store_true", default=False, help="opt-in to run tests marked with e2e_model_registry" + "--e2e-model-registry", + action="store_true", + default=False, + help="opt-in to run tests marked with e2e_model_registry", ) @pytest.fixture def target() -> str: return "localhost:5001/testorgns/ml-model-artifact:v1" - diff --git a/tests/test_e2e_model_registry.py b/tests/test_e2e_model_registry.py index 79c1b78..b060f41 100644 --- a/tests/test_e2e_model_registry.py +++ b/tests/test_e2e_model_registry.py @@ -1,14 +1,17 @@ -from omlmd.helpers import Helper -from omlmd.listener import Event, Listener, PushEvent -import pytest from pathlib import Path +from urllib.parse import quote + +import pytest from model_registry import ModelRegistry from model_registry.types import RegisteredModel -from urllib.parse import quote -from omlmd.helpers import download_file + +from omlmd.helpers import Helper, download_file +from omlmd.listener import Event, Listener, PushEvent -def from_oci_to_kfmr(model_registry: ModelRegistry, push_event: PushEvent, sha: str) -> RegisteredModel: +def from_oci_to_kfmr( + model_registry: ModelRegistry, push_event: PushEvent, sha: str +) -> RegisteredModel: rm = model_registry.register_model( name=push_event.metadata.name, uri=f"oci-artifact://{push_event.target}", @@ -27,22 +30,26 @@ def test_e2e_model_registry_scenario1(tmp_path, target): """ Given a ML model and some metadata, to OCI registry, and then to KF Model Registry (at once) """ - model_registry = ModelRegistry("http://localhost", 8081, author="mmortari", is_secure=False) + model_registry = ModelRegistry( + "http://localhost", 8081, author="mmortari", is_secure=False + ) class ListenerForModelRegistry(Listener): sha = None - rm = None + rm = None def update(self, source: Helper, event: Event) -> None: if isinstance(event, PushEvent): - self.sha = source.registry.get_manifest_response(event.target).headers["Docker-Content-Digest"] + self.sha = source.registry.get_manifest_response(event.target).headers[ + "Docker-Content-Digest" + ] print(self.sha) self.rm = from_oci_to_kfmr(model_registry, event, self.sha) - + listener = ListenerForModelRegistry() omlmd = Helper() omlmd.add_listener(listener) - + # assuming a model ... model_file = Path(__file__).parent / ".." / "README.md" # ...with some additional characteristics @@ -54,7 +61,7 @@ def update(self, source: Helper, event: Event) -> None: name="mnist", description="Lorem ipsum", author="John Doe", - accuracy=accuracy_value + accuracy=accuracy_value, ) v = quote(listener.sha) @@ -66,7 +73,7 @@ def update(self, source: Helper, event: Event) -> None: mv = model_registry.get_model_version("mnist", v) assert mv.description == "Lorem ipsum" assert mv.author == "John Doe" - assert mv.custom_properties == {'accuracy': 0.987} + assert mv.custom_properties == {"accuracy": 0.987} ma = model_registry.get_model_artifact("mnist", v) assert ma.uri == f"oci-artifact://{target}" @@ -80,7 +87,9 @@ def test_e2e_model_registry_scenario2(tmp_path, target): """ Given some metadata entry in KF model registry, attempt retrieve pointed ML model file asset, then OCI registry """ - model_registry = ModelRegistry("http://localhost", 8081, author="mmortari", is_secure=False) + model_registry = ModelRegistry( + "http://localhost", 8081, author="mmortari", is_secure=False + ) # assuming a model indexed on KF Model Registry ... registeredmodel_name = "mnist" @@ -95,11 +104,11 @@ def test_e2e_model_registry_scenario2(tmp_path, target): metadata={ "accuracy": 3.14, "license": "apache-2.0", - } + }, ) lookup_name = "mnist" - lookup_version = "v0.1" + lookup_version = "v0.1" _ = model_registry.get_registered_model(lookup_name) model_version = model_registry.get_model_version(lookup_name, lookup_version) @@ -120,7 +129,7 @@ def test_e2e_model_registry_scenario2(tmp_path, target): author=model_version.author, model_format_name=model_artifact.model_format_name, model_format_version=model_artifact.model_format_version, - **model_version.custom_properties + **model_version.custom_properties, ) # curl http://localhost:5001/v2/testorgns/ml-model-artifact/manifests/v0.1 -H "Accept: application/vnd.oci.image.manifest.v1+json" --verbose # tag v0.1 is defined in this test scenario. diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 6350a48..a4b4f9d 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,11 +1,14 @@ +import json +import tempfile +from pathlib import Path + +import pytest + from omlmd.helpers import Helper from omlmd.listener import Event, Listener from omlmd.model_metadata import ModelMetadata, deserialize_mdfile -import tempfile -import json from omlmd.provider import OMLMDRegistry -import pytest -from pathlib import Path + def test_call_push_using_md_from_file(mocker): helper = Helper() @@ -15,7 +18,7 @@ def test_call_push_using_md_from_file(mocker): "name": "mnist", "description": "Lorem ipsum", "author": "John Doe", - "accuracy": .987 + "accuracy": 0.987, } with tempfile.NamedTemporaryFile(delete=True, mode="w") as f: f.write(json.dumps(md)) @@ -29,7 +32,7 @@ def test_call_push_using_md_from_file(mocker): name="mnist", description="Lorem ipsum", author="John Doe", - accuracy=0.987 + accuracy=0.987, ) @@ -39,16 +42,18 @@ def test_push_event(mocker): omlmd = Helper(registry) events = [] + class MyListener(Listener): def update(self, _, event: Event) -> None: events.append(event) + omlmd.add_listener(MyListener()) md = { "name": "mnist", "description": "Lorem ipsum", "author": "John Doe", - "accuracy": .987 + "accuracy": 0.987, } omlmd.push("unexistent:8080/testorgns/ml-iris:v1", "README.md", **md) @@ -67,12 +72,9 @@ def test_e2e_push_pull(tmp_path, target): name="mnist", description="Lorem ipsum", author="John Doe", - accuracy=0.987 - ) - omlmd.pull( - target, - tmp_path + accuracy=0.987, ) + omlmd.pull(target, tmp_path) assert len(list(tmp_path.iterdir())) == 3 @@ -85,11 +87,7 @@ def test_e2e_push_pull_with_filters(tmp_path, target): name="mnist", description="Lorem ipsum", author="John Doe", - accuracy=0.987 - ) - omlmd.pull( - target, - tmp_path, - media_types=["application/x-mlmodel"] + accuracy=0.987, ) + omlmd.pull(target, tmp_path, media_types=["application/x-mlmodel"]) assert len(list(tmp_path.iterdir())) == 1 diff --git a/tests/test_omlmd.py b/tests/test_omlmd.py index 03f6235..0d171dd 100644 --- a/tests/test_omlmd.py +++ b/tests/test_omlmd.py @@ -1,9 +1,11 @@ -from omlmd.model_metadata import ModelMetadata -from omlmd.model_metadata import deserialize_mdfile -import tempfile import json +import tempfile + import yaml +from omlmd.model_metadata import ModelMetadata, deserialize_mdfile + + def test_dry_run_model_metadata_json_yaml_conversions(): metadata = ModelMetadata(name="Example Model", author="John Doe") json_str = metadata.to_json() @@ -23,7 +25,13 @@ def test_dry_run_model_metadata_json_yaml_conversions(): def test_deserialize_file_json(): - md_dict = ModelMetadata(name="Example Model", author="John Doe", model_format_name="onnx", model_format_version="1", customProperties={"accuracy": .987}).to_dict() + md_dict = ModelMetadata( + name="Example Model", + author="John Doe", + model_format_name="onnx", + model_format_version="1", + customProperties={"accuracy": 0.987}, + ).to_dict() json_str = json.dumps(md_dict) with tempfile.NamedTemporaryFile(delete=True, mode="w") as f: @@ -34,7 +42,13 @@ def test_deserialize_file_json(): def test_deserialize_file_yaml(): - md_dict = ModelMetadata(name="Example Model", author="John Doe", model_format_name="onnx", model_format_version="1", customProperties={"accuracy": .987}).to_dict() + md_dict = ModelMetadata( + name="Example Model", + author="John Doe", + model_format_name="onnx", + model_format_version="1", + customProperties={"accuracy": 0.987}, + ).to_dict() yaml_str = yaml.dump(md_dict) with tempfile.NamedTemporaryFile(delete=True, mode="w") as f: @@ -49,14 +63,12 @@ def test_from_dict(): "name": "mnist", "description": "Lorem ipsum", "author": "John Doe", - "accuracy": .987 + "accuracy": 0.987, } md = ModelMetadata( name="mnist", description="Lorem ipsum", author="John Doe", - customProperties={ - "accuracy": .987 - } + customProperties={"accuracy": 0.987}, ) assert ModelMetadata.from_dict(data) == md