Skip to content

Commit

Permalink
Py: rename pager limit method to page_size (#379)
Browse files Browse the repository at this point in the history
* py: pager: rename limit to page_size

... to improve semantics.

Closes: #357
Signed-off-by: Isabella do Amaral <[email protected]>

* py: fix custom_props typing issue

Signed-off-by: Isabella do Amaral <[email protected]>

* py: test pager descending option

Signed-off-by: Isabella do Amaral <[email protected]>

* py: update README with suggestion

Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Sep 12, 2024
1 parent 7506eaf commit 548fb51
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 24 deletions.
15 changes: 10 additions & 5 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,23 @@ for version in registry.get_model_versions("my-model"):
...
```

To customize sorting order or query limits you can also use
You can also use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order

```py
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending().limit(20)
latest_updates = registry.get_model_versions("my-model").order_by_update_time().descending()
for version in latest_updates:
...
```

You can use `order_by_creation_time`, `order_by_update_time`, or `order_by_id` to change the sorting order.
By default, all queries will be `ascending`, but this method is also available for explicitness.

> Note that the `limit()` method only limits the query size, not the actual loop boundaries -- even if your limit is 1
> you will still get all the models, with one query each.
> Note: You can also set the `page_size()` that you want the Pager to use when invoking the Model Registry backend.
> When using it as an iterator, it will automatically manage pages for you.
#### Implementation notes

The pager will manage pages for you in order to prevent infinite looping.
Currently, the Model Registry backend treats model lists as a circular buffer, and **will not end iteration** for you.

## Development

Expand Down
3 changes: 2 additions & 1 deletion clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, TypeVar, Union, get_args
from warnings import warn
Expand Down Expand Up @@ -138,7 +139,7 @@ def register_model(
author: str | None = None,
owner: str | None = None,
description: str | None = None,
metadata: dict[str, SupportedTypes] | None = None,
metadata: Mapping[str, SupportedTypes] | None = None,
) -> RegisteredModel:
"""Register a model.
Expand Down
4 changes: 2 additions & 2 deletions clients/python/src/model_registry/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Union, get_args

from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -35,7 +35,7 @@ class BaseResourceModel(BaseModel, ABC):
external_id: str | None = None
create_time_since_epoch: str | None = None
last_update_time_since_epoch: str | None = None
custom_properties: dict[str, SupportedTypes] | None = None
custom_properties: Mapping[str, SupportedTypes] | None = None

@abstractmethod
def create(self, **kwargs) -> Any:
Expand Down
9 changes: 6 additions & 3 deletions clients/python/src/model_registry/types/pager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ def order_by_id(self) -> Pager[T]:
self.options.order_by = OrderByField.ID
return self.restart()

def limit(self, limit: int) -> Pager[T]:
"""Limit the number of items to return.
def page_size(self, n: int) -> Pager[T]:
"""Set the page size for each request.
This resets the pager.
"""
self.options.limit = limit
if n < 1:
msg = f"Page size must be at least 1, got {n}"
raise ValueError(msg)
self.options.limit = n
return self.restart()

def ascending(self) -> Pager[T]:
Expand Down
49 changes: 40 additions & 9 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def test_register_existing_version(client: ModelRegistry):
"model_format_version": "test_version",
"version": "1.0.0",
}
client.register_model(**params)
client.register_model(**params, metadata=None)

with pytest.raises(StoreError):
client.register_model(**params)
client.register_model(**params, metadata=None)


@pytest.mark.e2e
Expand Down Expand Up @@ -124,8 +124,10 @@ async def test_update_logical_model_with_labels(client: ModelRegistry):
)
assert rm.id
mv = client.get_model_version(name, version)
assert mv
assert mv.id
ma = client.get_model_artifact(name, version)
assert ma
assert ma.id

rm_labels = {
Expand All @@ -149,9 +151,15 @@ async def test_update_logical_model_with_labels(client: ModelRegistry):
ma.custom_properties = ma_labels
client.update(ma)

assert client.get_registered_model(name).custom_properties == rm_labels
assert client.get_model_version(name, version).custom_properties == mv_labels
assert client.get_model_artifact(name, version).custom_properties == ma_labels
rm = client.get_registered_model(name)
assert rm
assert rm.custom_properties == rm_labels
mv = client.get_model_version(name, version)
assert mv
assert mv.custom_properties == mv_labels
ma = client.get_model_artifact(name, version)
assert ma
assert ma.custom_properties == ma_labels


@pytest.mark.e2e
Expand Down Expand Up @@ -232,7 +240,7 @@ def test_get_registered_models(client: ModelRegistry):
version="1.0.0",
)

rm_iter = client.get_registered_models().limit(10)
rm_iter = client.get_registered_models().page_size(10)
i = 0
prev_tok = None
changes = 0
Expand Down Expand Up @@ -315,6 +323,17 @@ def test_get_registered_models_order_by(client: ModelRegistry):

assert i == models

# or if descending is explicitly set
i = 0
for rm, by_update in zip(
rms,
client.get_registered_models().order_by_update_time().descending(),
):
assert rm.id == by_update.id
i += 1

assert i == models


@pytest.mark.e2e
def test_get_registered_models_and_reset(client: ModelRegistry):
Expand All @@ -330,7 +349,7 @@ def test_get_registered_models_and_reset(client: ModelRegistry):
version="1.0.0",
)

rm_iter = client.get_registered_models().limit(model_count - 1)
rm_iter = client.get_registered_models().page_size(model_count - 1)
models = []
for rm in islice(rm_iter, page):
models.append(rm)
Expand All @@ -355,7 +374,7 @@ def test_get_model_versions(client: ModelRegistry):
version=v,
)

mv_iter = client.get_model_versions(name).limit(10)
mv_iter = client.get_model_versions(name).page_size(10)
i = 0
prev_tok = None
changes = 0
Expand Down Expand Up @@ -430,6 +449,18 @@ def test_get_model_versions_order_by(client: ModelRegistry):
assert mv.id == by_update.id
i += 1

assert i == models

i = 0
for mv, by_update in zip(
mvs,
client.get_model_versions(name).order_by_update_time().descending(),
):
assert mv.id == by_update.id
i += 1

assert i == models


@pytest.mark.e2e
def test_get_model_versions_and_reset(client: ModelRegistry):
Expand All @@ -447,7 +478,7 @@ def test_get_model_versions_and_reset(client: ModelRegistry):
version=v,
)

mv_iter = client.get_model_versions(name).limit(model_count - 1)
mv_iter = client.get_model_versions(name).page_size(model_count - 1)
models = []
for rm in islice(mv_iter, page):
models.append(rm)
Expand Down
10 changes: 6 additions & 4 deletions clients/python/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async def test_get_registered_model_by_external_id(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
):
assert registered_model.external_id
assert (
rm := await client.get_registered_model_by_params(
external_id=registered_model.external_id
Expand All @@ -99,7 +100,7 @@ async def test_page_through_registered_models(client: ModelRegistryAPIClient):
models = 6
for i in range(models):
await client.upsert_registered_model(RegisteredModel(name=f"rm{i}"))
pager = Pager(client.get_registered_models).limit(5)
pager = Pager(client.get_registered_models).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand Down Expand Up @@ -205,7 +206,7 @@ async def test_page_through_model_versions(
)
pager = Pager(
lambda o: client.get_model_versions(str(registered_model.id), o)
).limit(5)
).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand All @@ -227,7 +228,8 @@ async def test_insert_model_artifact(
"service_account_name": "test service account",
}
ma = await client.upsert_model_artifact(
ModelArtifact(**props), str(model_version.id)
ModelArtifact(**props), # type: ignore
str(model_version.id),
)
assert ma.id
assert ma.name == "test model"
Expand Down Expand Up @@ -340,7 +342,7 @@ async def test_page_through_model_version_artifacts(
await client.create_model_version_artifact(art, str(model_version.id))
pager = Pager(
lambda o: client.get_model_version_artifacts(str(model_version.id), o)
).limit(5)
).page_size(5)
total = 0
async for _ in pager:
total += 1
Expand Down

0 comments on commit 548fb51

Please sign in to comment.