Skip to content

Commit

Permalink
Minor improvements (#9)
Browse files Browse the repository at this point in the history
* formatting

Signed-off-by: Isabella do Amaral <[email protected]>

* simplify type annotations

Signed-off-by: Isabella do Amaral <[email protected]>

* provider: skip mode-setting before download_blob

Signed-off-by: Isabella do Amaral <[email protected]>

* mlmd: don't hold the file open after reading its content

Signed-off-by: Isabella do Amaral <[email protected]>

* mlmd: simplify serialization methods

Signed-off-by: Isabella do Amaral <[email protected]>

* cli: simplify click cmd declarations

Signed-off-by: Isabella do Amaral <[email protected]>

* improve typing

Signed-off-by: Isabella do Amaral <[email protected]>

* helpers: save json and yaml md on path

Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Aug 14, 2024
1 parent c5e9455 commit c5c2797
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 189 deletions.
80 changes: 52 additions & 28 deletions omlmd/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
# Using this to scope CLI targets
"""Command line interface for OMLMD."""

from pathlib import Path

import click

from omlmd.helpers import Helper
from omlmd.provider import OMLMDRegistry
from omlmd.model_metadata import deserialize_mdfile
from omlmd.provider import OMLMDRegistry


plain_http = click.option('--plain-http', help="allow insecure connections to registry without SSL check", is_flag=True, default=False, show_default=True)
plain_http = click.option(
"--plain-http",
help="allow insecure connections to registry without SSL check",
is_flag=True,
default=False,
show_default=True,
)


def get_OMLMDRegistry(plain_http: bool) -> OMLMDRegistry:
Expand All @@ -16,47 +25,62 @@ def get_OMLMDRegistry(plain_http: bool) -> OMLMDRegistry:
def cli():
pass

@click.command()

@cli.command()
@plain_http
@click.argument('target', required=True)
@click.option('-o', '--output', default='.', show_default=True)
@click.option('--media-types', '-m', multiple=True, default=[])
def pull(plain_http, target, output, media_types):
@click.argument("target", required=True)
@click.option(
"-o",
"--output",
default=Path.cwd(),
show_default=True,
type=click.Path(path_type=Path, resolve_path=True),
)
@click.option("--media-types", "-m", multiple=True, default=[])
def pull(plain_http: bool, target: str, output: Path, media_types: tuple[str]):
"""Pulls an OCI Artifact containing ML model and metadata, filtering if necessary."""
Helper(get_OMLMDRegistry(plain_http)).pull(target, output, media_types)

@click.group()

@cli.group()
def get():
pass

@click.command()

@get.command()
@plain_http
@click.argument('target', required=True)
def config(plain_http, target):
@click.argument("target", required=True)
def config(plain_http: bool, target: str):
"""Outputs configuration of the given OCI Artifact for ML model and metadata."""
click.echo(Helper(get_OMLMDRegistry(plain_http)).get_config(target))

@click.command()

@cli.command()
@plain_http
@click.argument('targets', required=True, nargs=-1)
def crawl(plain_http, targets):
@click.argument("targets", required=True, nargs=-1)
def crawl(plain_http: bool, targets: tuple[str]):
"""Crawls configuration for the given list of OCI Artifact for ML model and metadata."""
click.echo(Helper(get_OMLMDRegistry(plain_http)).crawl(targets))

@click.command()


@cli.command()
@plain_http
@click.argument('target', required=True)
@click.argument('path', required=True, type=click.Path())
@click.option('-m', '--metadata', required=True, type=click.Path())
def push(plain_http, target, path, metadata):
@click.argument("target", required=True)
@click.argument(
"path",
required=True,
type=click.Path(path_type=Path, exists=True, resolve_path=True),
)
@click.option(
"-m",
"--metadata",
required=True,
type=click.Path(path_type=Path, exists=True, resolve_path=True),
)
def push(plain_http: bool, target: str, path: Path, metadata: Path):
"""Pushes an OCI Artifact containing ML model and metadata, supplying metadata from file as necessary"""
import logging

logging.basicConfig(level=logging.DEBUG)
md = deserialize_mdfile(metadata)
click.echo(Helper(get_OMLMDRegistry(plain_http)).push(target, path, **md))

cli.add_command(pull)
cli.add_command(get)
get.add_command(config)
cli.add_command(crawl)
cli.add_command(push)
99 changes: 49 additions & 50 deletions omlmd/helpers.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,30 @@
from __future__ import annotations

import os
import urllib.request
from collections.abc import Sequence
from dataclasses import fields
from typing import Optional, List
from pathlib import Path

from omlmd.listener import Event, Listener, PushEvent
from omlmd.model_metadata import ModelMetadata
from omlmd.provider import OMLMDRegistry
import os
import urllib.request

def write_content_to_file(filename: str, content_fn):
try:
with open(filename, 'x') as f:
content = content_fn()
f.write(content)
except FileExistsError:
raise RuntimeError(f"File '{filename}' already exists. Aborting TODO: demonstrator.")


def download_file(uri):
def download_file(uri: str):
file_name = os.path.basename(uri)
urllib.request.urlretrieve(uri, file_name)
return file_name


class Helper:
_listeners: list[Listener] = []

_listeners: List[Listener] = []

def __init__(self, registry: Optional[OMLMDRegistry] = None):
def __init__(self, registry: OMLMDRegistry | None = None):
if registry is None:
self._registry = OMLMDRegistry(insecure=True) # TODO: this is a bit limiting when used from CLI, to be refactored
self._registry = OMLMDRegistry(
insecure=True
) # TODO: this is a bit limiting when used from CLI, to be refactored
else:
self._registry = registry

Expand All @@ -38,71 +35,73 @@ def registry(self):
def push(
self,
target: str,
path: str,
name: Optional[str] = None,
description: Optional[str] = None,
author: Optional[str] = None,
model_format_name: Optional[str] = None,
model_format_version: Optional[str] = None,
**kwargs
path: Path | str,
name: str | None = None,
description: str | None = None,
author: str | None = None,
model_format_name: str | None = None,
model_format_version: str | None = None,
**kwargs,
):
dataclass_fields = {f.name for f in fields(ModelMetadata)} # avoid anything specified in kwargs which would collide
custom_properties = {k: v for k, v in kwargs.items() if k not in dataclass_fields}
dataclass_fields = {
f.name for f in fields(ModelMetadata)
} # avoid anything specified in kwargs which would collide
custom_properties = {
k: v for k, v in kwargs.items() if k not in dataclass_fields
}
model_metadata = ModelMetadata(
name=name,
description=description,
author=author,
customProperties=custom_properties,
model_format_name=model_format_name,
model_format_version=model_format_version
model_format_version=model_format_version,
)
write_content_to_file("model_metadata.omlmd.json", lambda: model_metadata.to_json())
write_content_to_file("model_metadata.omlmd.yaml", lambda: model_metadata.to_yaml())
if isinstance(path, str):
path = Path(path)

json_meta = path.parent / "model_metadata.omlmd.json"
yaml_meta = path.parent / "model_metadata.omlmd.yaml"
if (p := json_meta).exists() or (p := yaml_meta).exists():
raise RuntimeError(
f"File '{p}' already exists. Aborting TODO: demonstrator."
)
json_meta.write_text(model_metadata.to_json())
yaml_meta.write_text(model_metadata.to_yaml())

manifest_cfg = f"{json_meta}:application/x-config"
files = [
f"{path}:application/x-mlmodel",
"model_metadata.omlmd.json:application/x-config",
"model_metadata.omlmd.yaml:application/x-config",
manifest_cfg,
f"{yaml_meta}:application/x-config",
]
try:
# print(target, files, model_metadata.to_annotations_dict())
result = self._registry.push(
target=target,
files=files,
manifest_annotations=model_metadata.to_annotations_dict(),
manifest_config="model_metadata.omlmd.json:application/x-config"
manifest_config=manifest_cfg,
)
self.notify_listeners(PushEvent(target, model_metadata))
return result
finally:
os.remove("model_metadata.omlmd.json")
os.remove("model_metadata.omlmd.yaml")

json_meta.unlink()
yaml_meta.unlink()

def pull(
self,
target: str,
outdir: str,
media_types: Optional[List[str]] = None
self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None
):
self._registry.download_layers(target, outdir, media_types)

def get_config(self, target: str) -> str:
return f'{{"reference":"{target}", "config": {self._registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec)

def get_config(
self,
target: str
) -> str:
return f'{{"reference":"{target}", "config": {self._registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec)


def crawl(
self,
targets: List[str]
) -> str:
def crawl(self, targets: Sequence[str]) -> str:
configs = map(self.get_config, targets)
joined = "[" + ", ".join(configs) + "]"
return joined


def add_listener(self, listener: Listener) -> None:
self._listeners.append(listener)

Expand Down
5 changes: 4 additions & 1 deletion omlmd/listener.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from omlmd.model_metadata import ModelMetadata


class Listener(ABC):
"""
TODO: not yet settled for multi-method or current single update method.
"""

@abstractmethod
def update(self, source: Any, event: Event) -> None:
"""
Expand All @@ -24,4 +28,3 @@ def __init__(self, target: str, metadata: ModelMetadata):
# TODO: the push sha cannot be received here yet; waiting for https://github.com/oras-project/oras-py/pull/146 to land in a release.
self.target = target
self.metadata = metadata

Loading

0 comments on commit c5c2797

Please sign in to comment.