Skip to content

Commit

Permalink
update with suggestions
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa committed Oct 24, 2024
1 parent 3f36bd3 commit c2ed213
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions omlmd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def cli():
@click.option("--media-types", "-m", multiple=True, default=[])
def pull(plain_http: bool, target: str, output: Path, media_types: tuple[str]):
"""Pulls an OCI Artifact containing ML model and metadata, filtering if necessary."""
Helper.from_plain(plain_http).pull(target, output, media_types)
Helper.from_default_registry(plain_http).pull(target, output, media_types)


@cli.group()
Expand All @@ -54,15 +54,15 @@ def get():
@click.argument("target", required=True)
def config(plain_http: bool, target: str):
"""Outputs configuration of the given OCI Artifact for ML model and metadata."""
click.echo(Helper.from_plain(plain_http).get_config(target))
click.echo(Helper.from_default_registry(plain_http).get_config(target))


@cli.command()
@plain_http
@click.argument("targets", required=True, nargs=-1)
def crawl(plain_http: bool, targets: tuple[str]):
"""Crawls configuration for the given list of OCI Artifact for ML model and metadata."""
click.echo(Helper.from_plain(plain_http).crawl(targets))
click.echo(Helper.from_default_registry(plain_http).crawl(targets))


@cli.command()
Expand Down Expand Up @@ -96,4 +96,4 @@ def push(
if empty_metadata:
logger.warning(f"Pushing to {target} with empty metadata.")
md = deserialize_mdfile(metadata) if metadata else {}
click.echo(Helper.from_plain(plain_http).push(target, path, **md))
click.echo(Helper.from_default_registry(plain_http).push(target, path, **md))
10 changes: 5 additions & 5 deletions omlmd/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def download_file(uri: str):

@dataclass
class Helper:
registry: OMLMDRegistry = (
_registry: OMLMDRegistry = (
field( # TODO: this is a bit limiting when used from CLI, to be refactored
default_factory=lambda: OMLMDRegistry(insecure=True)
)
)
_listeners: list[Listener] = field(default_factory=list)

@classmethod
def from_plain(cls, insecure: bool):
def from_default_registry(cls, insecure: bool):
return cls(OMLMDRegistry(insecure=insecure))

def push(
Expand Down Expand Up @@ -92,7 +92,7 @@ def push(
]
try:
# print(target, files, model_metadata.to_annotations_dict())
result = self.registry.push(
result = self._registry.push(
target=target,
files=files,
manifest_annotations=model_metadata.to_annotations_dict(),
Expand All @@ -115,10 +115,10 @@ def push(
def pull(
self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None
):
self.registry.download_layers(target, outdir, media_types)
self._registry.download_layers(target, outdir, media_types)

def get_config(self, target: str) -> str:
return f'{{"reference":"{target}", "config": {self.registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec)
return f'{{"reference":"{target}", "config": {self._registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec)

def crawl(self, targets: Sequence[str]) -> str:
configs = map(self.get_config, targets)
Expand Down
2 changes: 1 addition & 1 deletion omlmd/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ class Event(ABC):

@dataclass
class PushEvent(Event):
sha: str
digest: str
target: str
metadata: ModelMetadata
4 changes: 2 additions & 2 deletions tests/test_e2e_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class ListenerForModelRegistry(Listener):

def update(self, source: Helper, event: Event) -> None:
if isinstance(event, PushEvent):
self.sha = event.sha
self.rm = from_oci_to_kfmr(model_registry, event, event.sha)
self.sha = event.digest
self.rm = from_oci_to_kfmr(model_registry, event, event.digest)

listener = ListenerForModelRegistry()
omlmd = Helper()
Expand Down

0 comments on commit c2ed213

Please sign in to comment.