Skip to content

Commit

Permalink
Py: Extend high-level API with paging (#178)
Browse files Browse the repository at this point in the history
* py: core: sync generic artifact bindings

There's no way to update them yet, so upsert is renamed as `create`.

Signed-off-by: Isabella Basso do Amaral <[email protected]>

* py: expose listing methods

Signed-off-by: Isabella Basso do Amaral <[email protected]>

---------

Signed-off-by: Isabella Basso do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Jul 18, 2024
1 parent 3f49945 commit ae4006f
Show file tree
Hide file tree
Showing 9 changed files with 582 additions and 41 deletions.
25 changes: 25 additions & 0 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ There are caveats to be noted when using this method:
)
```

### Listing models

To list models you can use
```py
for model in registry.get_registered_models():
...

# and versions associated with a model
for version in registry.get_model_versions("my-model"):
...
```

To customize sorting order or query limits you can also use

```py
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20)
for version in latest_updates:
...
```

You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order.

> Note that the `limit()` method only limits the query size, not the actual loop boundaries -- even if your limit is 1
> you will still get all the models, with one query each.
## Development

Common tasks, such as building documentation and running tests, can be executed using [`nox`](https://github.com/wntrblm/nox) sessions.
Expand Down
99 changes: 78 additions & 21 deletions clients/python/docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ To create a client you should use the {py:meth}`model_registry.core.ModelRegistr
```py
from model_registry.core import ModelRegistryAPIClient

insecure_registry = ModelRegistryAPIClient.insecure_connection(
insecure_mr_client = ModelRegistryAPIClient.insecure_connection(
"server-address", "port",
# optionally, you can identify yourself
# user_token=os.environ["MY_TOKEN"]
)

insecure_registry = ModelRegistryAPIClient.insecure_connection(
mr_client = ModelRegistryAPIClient.secure_connection(
"server-address", "port",
user_token=os.environ["MY_TOKEN"] # this is necessary on a secure connection
# optionally, use a custom_ca
Expand All @@ -42,7 +42,7 @@ from model_registry.types import RegisteredModel, ModelVersion, ModelArtifact
from model_registry.utils import s3_uri_from

async def register_a_model():
model = await registry.upsert_registered_model(
model = await mr_client.upsert_registered_model(
RegisteredModel(
name="HAL",
owner="me <[email protected]>",
Expand All @@ -51,7 +51,7 @@ async def register_a_model():
assert model.id # this should be valid now

# we need a registered model to associate the version to
version = await registry.upsert_model_version(
version = await mr_client.upsert_model_version(
ModelVersion(
name="9000",
author="Mr. Tom A.I.",
Expand All @@ -62,7 +62,7 @@ async def register_a_model():
assert version.id

# we need a version to associate a trained model to
trained_model = await registry.upsert_model_artifact(
trained_model = await mr_client.upsert_model_artifact(
ModelArtifact(
name="HAL-core",
uri=s3_uri_from("build/onnx/hal.onnx", "cool-bucket"),
Expand All @@ -80,6 +80,23 @@ async def register_a_model():
As objects are only assigned IDs upon creation, you can use this property to verify whether an object exists.

You can associate multiple artifacts with the same version as well:

```py
from model_registry.types import DocArtifact

readme = await mr_client.upsert_model_version_artifact(
DocArtifact(
name="README",
uri="https://github.com/my-org/my-model/blob/main/README.md",
description="Model information"
), version.id
)
```

> Note: document artifacts currently have no `storage_*` attributes, so you have to keep track of any credentials
> necessary to access it manually.
### Query objects

There are several ways to get registered objects from the registry.
Expand All @@ -89,9 +106,9 @@ There are several ways to get registered objects from the registry.
After upserting an object you can use its `id` to fetch it again.

```py
new_model = await registry.upsert_registered_model(RegisteredModel("new_model"))
new_model = await mr_client.upsert_registered_model(RegisteredModel("new_model"))

maybe_new_model = await registry.get_registered_model_by_id(new_model.id)
maybe_new_model = await mr_client.get_registered_model_by_id(new_model.id)

assert maybe_new_model == new_model # True
```
Expand Down Expand Up @@ -132,52 +149,88 @@ You can also perform queries by parameters:

```py
# We can get the model artifact associated to a version
another_trained_model = await registry.get_model_artifact_by_params(name="my_model_name", model_version_id=another_version.id)
another_trained_model = await mr_client.get_model_artifact_by_params(name="my_model_name", model_version_id=another_version.id)

# Or by its unique identifier
trained_model = await registry.get_model_artifact_by_params(external_id="unique_reference")
trained_model = await mr_client.get_model_artifact_by_params(external_id="unique_reference")

# Same thing for a version
version = await registry.get_model_version_by_params(external_id="unique_reference")
version = await mr_client.get_model_version_by_params(external_id="unique_reference")

# Or for a model
model = await registry.get_registered_model_by_params(external_id="another_unique_reference")
model = await mr_client.get_registered_model_by_params(external_id="another_unique_reference")

# We can also get a version by its name and associated model id
version = await registry.get_model_version_by_params(version="v1.0", registered_model_id="x")
version = await mr_client.get_model_version_by_params(version="v1.0", registered_model_id="x")

# And we can get a model by simply calling its name
model = await registry.get_registered_model_by_params(name="my_model_name")
model = await mr_client.get_registered_model_by_params(name="my_model_name")
```

### Query multiple objects

We can query all objects of a type

```py
models = await registry.get_registered_models()
models = await mr_client.get_registered_models()

versions = await registry.get_model_versions("registered_model_id")
versions = await mr_client.get_model_versions("registered_model_id")

# We can get a list of all model artifacts
all_model_artifacts = await registry.get_model_artifacts()
# We can get a list of the first 20 model artifacts
all_model_artifacts = await mr_client.get_model_artifacts()
```

To limit or order the query, provide a {py:class}`model_registry.types.ListOptions` object.
To limit or sort the query by another parameter, provide a {py:class}`model_registry.types.ListOptions` object.

```py
from model_registry import ListOptions
from model_registry.types import ListOptions

options = ListOptions(limit=50)

first_50_models = await registry.get_registered_models(options)
first_50_models = await mr_client.get_registered_models(options)

# By default we get ascending order
options = ListOptions.order_by_creation_time(is_asc=False)

last_50_models = await registry.get_registered_models(options)
last_50_models = await mr_client.get_registered_models(options)
```

You can also use the high-level {py:class}`model_registry.types.Pager` to get an iterator.

```py
from model_registry.types import Pager

models = Pager(mr_client.get_registered_models)

async for model in models:
...
```

Note that the iterator currently only works with methods that take a `ListOptions` argument, so if you want to use a
method that needs additional arguments, you'll need to provide a partial application like in the example below.

```py
model_version_artifacts = Pager(lambda o: mr_client.get_model_version_artifacts(mv.id, o))
```

> ⚠️ Also note that a [`partial`](https://docs.python.org/3/library/functools.html#functools.partial) definition won't work as the `options` argument is optional, and thus has to be overriden as a positional argument.
The iterator provides methods for setting up the {py:class}`model_registry.types.ListOptions` that will be used in each
call.

```py
reverse_model_version_artifacts = model_version_artifacts.order_by_creation_time().descending().limit(100)
```

You can also get each page separately and iterate yourself:

```py
page = await reverse_model_version_artifacts.next_page()
```

> Note: the iterator will be automagically sync or async depending on the paging function passed in for initialization.

```{eval-rst}
.. automodule:: model_registry.core
```
Expand All @@ -188,13 +241,17 @@ last_50_models = await registry.get_registered_models(options)

Registry objects can be created by doing

<!-- TODO: be explicit about possible ways to create MA that allow for serving -->

```py
from model_registry.types import ModelArtifact, ModelVersion, RegisteredModel

trained_model = ModelArtifact(
name="model-exec",
uri="resource_URI",
description="Model description",
model_format_name="onnx",
model_format_version="1",
)

version = ModelVersion(
Expand Down
44 changes: 43 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@

from .core import ModelRegistryAPIClient
from .exceptions import StoreError
from .types import ModelArtifact, ModelVersion, RegisteredModel, SupportedTypes
from .types import (
ListOptions,
ModelArtifact,
ModelVersion,
Pager,
RegisteredModel,
SupportedTypes,
)


class ModelRegistry:
Expand Down Expand Up @@ -327,3 +334,38 @@ def get_model_artifact(self, name: str, version: str) -> ModelArtifact | None:
raise StoreError(msg)
assert mv.id
return self.async_runner(self._api.get_model_artifact_by_params(name, mv.id))

def get_registered_models(self) -> Pager[RegisteredModel]:
"""Get a pager for registered models.
Returns:
Iterable pager for registered models.
"""

def rm_list(options: ListOptions) -> list[RegisteredModel]:
return self.async_runner(self._api.get_registered_models(options))

return Pager[RegisteredModel](rm_list)

def get_model_versions(self, name: str) -> Pager[ModelVersion]:
"""Get a pager for model versions.
Args:
name: Name of the model.
Returns:
Iterable pager for model versions.
Raises:
StoreException: If the model does not exist.
"""
if not (rm := self.get_registered_model(name)):
msg = f"Model {name} does not exist"
raise StoreError(msg)

def rm_versions(options: ListOptions) -> list[ModelVersion]:
# type checkers can't restrict the type inside a nested function: https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
assert rm.id
return self.async_runner(self._api.get_model_versions(rm.id, options))

return Pager[ModelVersion](rm_versions)
Loading

0 comments on commit ae4006f

Please sign in to comment.