diff --git a/omlmd/cli.py b/omlmd/cli.py index 47772df..8d66eca 100644 --- a/omlmd/cli.py +++ b/omlmd/cli.py @@ -41,7 +41,7 @@ def cli(): @click.option("--media-types", "-m", multiple=True, default=[]) def pull(plain_http: bool, target: str, output: Path, media_types: tuple[str]): """Pulls an OCI Artifact containing ML model and metadata, filtering if necessary.""" - Helper.from_plain(plain_http).pull(target, output, media_types) + Helper.from_default_registry(plain_http).pull(target, output, media_types) @cli.group() @@ -54,7 +54,7 @@ def get(): @click.argument("target", required=True) def config(plain_http: bool, target: str): """Outputs configuration of the given OCI Artifact for ML model and metadata.""" - click.echo(Helper.from_plain(plain_http).get_config(target)) + click.echo(Helper.from_default_registry(plain_http).get_config(target)) @cli.command() @@ -62,7 +62,7 @@ def config(plain_http: bool, target: str): @click.argument("targets", required=True, nargs=-1) def crawl(plain_http: bool, targets: tuple[str]): """Crawls configuration for the given list of OCI Artifact for ML model and metadata.""" - click.echo(Helper.from_plain(plain_http).crawl(targets)) + click.echo(Helper.from_default_registry(plain_http).crawl(targets)) @cli.command() @@ -96,4 +96,4 @@ def push( if empty_metadata: logger.warning(f"Pushing to {target} with empty metadata.") md = deserialize_mdfile(metadata) if metadata else {} - click.echo(Helper.from_plain(plain_http).push(target, path, **md)) + click.echo(Helper.from_default_registry(plain_http).push(target, path, **md)) diff --git a/omlmd/helpers.py b/omlmd/helpers.py index e2f9044..7f4aea6 100644 --- a/omlmd/helpers.py +++ b/omlmd/helpers.py @@ -28,7 +28,7 @@ def download_file(uri: str): @dataclass class Helper: - registry: OMLMDRegistry = ( + _registry: OMLMDRegistry = ( field( # TODO: this is a bit limiting when used from CLI, to be refactored default_factory=lambda: OMLMDRegistry(insecure=True) ) @@ -36,7 +36,7 @@ class Helper: _listeners: list[Listener] = field(default_factory=list) @classmethod - def from_plain(cls, insecure: bool): + def from_default_registry(cls, insecure: bool): return cls(OMLMDRegistry(insecure=insecure)) def push( @@ -92,7 +92,7 @@ def push( ] try: # print(target, files, model_metadata.to_annotations_dict()) - result = self.registry.push( + result = self._registry.push( target=target, files=files, manifest_annotations=model_metadata.to_annotations_dict(), @@ -115,10 +115,10 @@ def push( def pull( self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None ): - self.registry.download_layers(target, outdir, media_types) + self._registry.download_layers(target, outdir, media_types) def get_config(self, target: str) -> str: - return f'{{"reference":"{target}", "config": {self.registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec) + return f'{{"reference":"{target}", "config": {self._registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec) def crawl(self, targets: Sequence[str]) -> str: configs = map(self.get_config, targets) diff --git a/omlmd/listener.py b/omlmd/listener.py index 3c59d25..51583fa 100644 --- a/omlmd/listener.py +++ b/omlmd/listener.py @@ -26,6 +26,6 @@ class Event(ABC): @dataclass class PushEvent(Event): - sha: str + digest: str target: str metadata: ModelMetadata diff --git a/tests/test_e2e_model_registry.py b/tests/test_e2e_model_registry.py index f26f5f6..748e2e1 100644 --- a/tests/test_e2e_model_registry.py +++ b/tests/test_e2e_model_registry.py @@ -43,8 +43,8 @@ class ListenerForModelRegistry(Listener): def update(self, source: Helper, event: Event) -> None: if isinstance(event, PushEvent): - self.sha = event.sha - self.rm = from_oci_to_kfmr(model_registry, event, event.sha) + self.sha = event.digest + self.rm = from_oci_to_kfmr(model_registry, event, event.digest) listener = ListenerForModelRegistry() omlmd = Helper()