-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Py: Extend high-level API with paging (#178)
* py: core: sync generic artifact bindings There's no way to update them yet, so upsert is renamed as `create`. Signed-off-by: Isabella Basso do Amaral <[email protected]> * py: expose listing methods Signed-off-by: Isabella Basso do Amaral <[email protected]> --------- Signed-off-by: Isabella Basso do Amaral <[email protected]>
- Loading branch information
Showing
9 changed files
with
582 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,13 +17,13 @@ To create a client you should use the {py:meth}`model_registry.core.ModelRegistr | |
```py | ||
from model_registry.core import ModelRegistryAPIClient | ||
|
||
insecure_registry = ModelRegistryAPIClient.insecure_connection( | ||
insecure_mr_client = ModelRegistryAPIClient.insecure_connection( | ||
"server-address", "port", | ||
# optionally, you can identify yourself | ||
# user_token=os.environ["MY_TOKEN"] | ||
) | ||
|
||
insecure_registry = ModelRegistryAPIClient.insecure_connection( | ||
mr_client = ModelRegistryAPIClient.secure_connection( | ||
"server-address", "port", | ||
user_token=os.environ["MY_TOKEN"] # this is necessary on a secure connection | ||
# optionally, use a custom_ca | ||
|
@@ -42,7 +42,7 @@ from model_registry.types import RegisteredModel, ModelVersion, ModelArtifact | |
from model_registry.utils import s3_uri_from | ||
|
||
async def register_a_model(): | ||
model = await registry.upsert_registered_model( | ||
model = await mr_client.upsert_registered_model( | ||
RegisteredModel( | ||
name="HAL", | ||
owner="me <[email protected]>", | ||
|
@@ -51,7 +51,7 @@ async def register_a_model(): | |
assert model.id # this should be valid now | ||
|
||
# we need a registered model to associate the version to | ||
version = await registry.upsert_model_version( | ||
version = await mr_client.upsert_model_version( | ||
ModelVersion( | ||
name="9000", | ||
author="Mr. Tom A.I.", | ||
|
@@ -62,7 +62,7 @@ async def register_a_model(): | |
assert version.id | ||
|
||
# we need a version to associate a trained model to | ||
trained_model = await registry.upsert_model_artifact( | ||
trained_model = await mr_client.upsert_model_artifact( | ||
ModelArtifact( | ||
name="HAL-core", | ||
uri=s3_uri_from("build/onnx/hal.onnx", "cool-bucket"), | ||
|
@@ -80,6 +80,23 @@ async def register_a_model(): | |
As objects are only assigned IDs upon creation, you can use this property to verify whether an object exists. | ||
|
||
You can associate multiple artifacts with the same version as well: | ||
|
||
```py | ||
from model_registry.types import DocArtifact | ||
|
||
readme = await mr_client.upsert_model_version_artifact( | ||
DocArtifact( | ||
name="README", | ||
uri="https://github.com/my-org/my-model/blob/main/README.md", | ||
description="Model information" | ||
), version.id | ||
) | ||
``` | ||
|
||
> Note: document artifacts currently have no `storage_*` attributes, so you have to keep track of any credentials | ||
> necessary to access it manually. | ||
### Query objects | ||
|
||
There are several ways to get registered objects from the registry. | ||
|
@@ -89,9 +106,9 @@ There are several ways to get registered objects from the registry. | |
After upserting an object you can use its `id` to fetch it again. | ||
|
||
```py | ||
new_model = await registry.upsert_registered_model(RegisteredModel("new_model")) | ||
new_model = await mr_client.upsert_registered_model(RegisteredModel("new_model")) | ||
|
||
maybe_new_model = await registry.get_registered_model_by_id(new_model.id) | ||
maybe_new_model = await mr_client.get_registered_model_by_id(new_model.id) | ||
|
||
assert maybe_new_model == new_model # True | ||
``` | ||
|
@@ -132,52 +149,88 @@ You can also perform queries by parameters: | |
|
||
```py | ||
# We can get the model artifact associated to a version | ||
another_trained_model = await registry.get_model_artifact_by_params(name="my_model_name", model_version_id=another_version.id) | ||
another_trained_model = await mr_client.get_model_artifact_by_params(name="my_model_name", model_version_id=another_version.id) | ||
|
||
# Or by its unique identifier | ||
trained_model = await registry.get_model_artifact_by_params(external_id="unique_reference") | ||
trained_model = await mr_client.get_model_artifact_by_params(external_id="unique_reference") | ||
|
||
# Same thing for a version | ||
version = await registry.get_model_version_by_params(external_id="unique_reference") | ||
version = await mr_client.get_model_version_by_params(external_id="unique_reference") | ||
|
||
# Or for a model | ||
model = await registry.get_registered_model_by_params(external_id="another_unique_reference") | ||
model = await mr_client.get_registered_model_by_params(external_id="another_unique_reference") | ||
|
||
# We can also get a version by its name and associated model id | ||
version = await registry.get_model_version_by_params(version="v1.0", registered_model_id="x") | ||
version = await mr_client.get_model_version_by_params(version="v1.0", registered_model_id="x") | ||
|
||
# And we can get a model by simply calling its name | ||
model = await registry.get_registered_model_by_params(name="my_model_name") | ||
model = await mr_client.get_registered_model_by_params(name="my_model_name") | ||
``` | ||
|
||
### Query multiple objects | ||
|
||
We can query all objects of a type | ||
|
||
```py | ||
models = await registry.get_registered_models() | ||
models = await mr_client.get_registered_models() | ||
|
||
versions = await registry.get_model_versions("registered_model_id") | ||
versions = await mr_client.get_model_versions("registered_model_id") | ||
|
||
# We can get a list of all model artifacts | ||
all_model_artifacts = await registry.get_model_artifacts() | ||
# We can get a list of the first 20 model artifacts | ||
all_model_artifacts = await mr_client.get_model_artifacts() | ||
``` | ||
|
||
To limit or order the query, provide a {py:class}`model_registry.types.ListOptions` object. | ||
To limit or sort the query by another parameter, provide a {py:class}`model_registry.types.ListOptions` object. | ||
|
||
```py | ||
from model_registry import ListOptions | ||
from model_registry.types import ListOptions | ||
|
||
options = ListOptions(limit=50) | ||
|
||
first_50_models = await registry.get_registered_models(options) | ||
first_50_models = await mr_client.get_registered_models(options) | ||
|
||
# By default we get ascending order | ||
options = ListOptions.order_by_creation_time(is_asc=False) | ||
|
||
last_50_models = await registry.get_registered_models(options) | ||
last_50_models = await mr_client.get_registered_models(options) | ||
``` | ||
|
||
You can also use the high-level {py:class}`model_registry.types.Pager` to get an iterator. | ||
|
||
```py | ||
from model_registry.types import Pager | ||
|
||
models = Pager(mr_client.get_registered_models) | ||
|
||
async for model in models: | ||
... | ||
``` | ||
|
||
Note that the iterator currently only works with methods that take a `ListOptions` argument, so if you want to use a | ||
method that needs additional arguments, you'll need to provide a partial application like in the example below. | ||
|
||
```py | ||
model_version_artifacts = Pager(lambda o: mr_client.get_model_version_artifacts(mv.id, o)) | ||
``` | ||
|
||
> ⚠️ Also note that a [`partial`](https://docs.python.org/3/library/functools.html#functools.partial) definition won't work as the `options` argument is optional, and thus has to be overriden as a positional argument. | ||
The iterator provides methods for setting up the {py:class}`model_registry.types.ListOptions` that will be used in each | ||
call. | ||
|
||
```py | ||
reverse_model_version_artifacts = model_version_artifacts.order_by_creation_time().descending().limit(100) | ||
``` | ||
|
||
You can also get each page separately and iterate yourself: | ||
|
||
```py | ||
page = await reverse_model_version_artifacts.next_page() | ||
``` | ||
|
||
> Note: the iterator will be automagically sync or async depending on the paging function passed in for initialization. | ||
|
||
```{eval-rst} | ||
.. automodule:: model_registry.core | ||
``` | ||
|
@@ -188,13 +241,17 @@ last_50_models = await registry.get_registered_models(options) | |
|
||
Registry objects can be created by doing | ||
|
||
<!-- TODO: be explicit about possible ways to create MA that allow for serving --> | ||
|
||
```py | ||
from model_registry.types import ModelArtifact, ModelVersion, RegisteredModel | ||
|
||
trained_model = ModelArtifact( | ||
name="model-exec", | ||
uri="resource_URI", | ||
description="Model description", | ||
model_format_name="onnx", | ||
model_format_version="1", | ||
) | ||
|
||
version = ModelVersion( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.