Skip to content

Commit

Permalink
Test Python client against latest MR (#326)
Browse files Browse the repository at this point in the history
* py: tests: rename basic to REST bindings

Signed-off-by: Isabella do Amaral <[email protected]>

* py: create e2e test mode

Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Aug 30, 2024
1 parent e5e6f73 commit 3e1259a
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ jobs:
working-directory: clients/python
run: |
if [[ ${{ matrix.session }} == "tests" ]]; then
make build-mr
nox --python=${{ matrix.python }} -- --cov-report=xml
elif [[ ${{ matrix.session }} == "mypy" ]]; then
nox --python=${{ matrix.python }} ||\
Expand Down
8 changes: 8 additions & 0 deletions clients/python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ install:
clean:
rm -rf src/mr_openapi

.PHONY: build-mr
build-mr:
cd ../../ && make image/build

.PHONY: test-e2e
test-e2e: build-mr
poetry run pytest --e2e -s

.PHONY: test
test:
poetry run pytest -s
Expand Down
7 changes: 6 additions & 1 deletion clients/python/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


package = "model_registry"
python_versions = ["3.12", "3.11","3.10", "3.9"]
python_versions = ["3.12", "3.11", "3.10", "3.9"]
nox.needs_version = ">= 2021.6.6"
nox.options.sessions = (
"tests",
Expand Down Expand Up @@ -63,6 +63,11 @@ def tests(session: Session) -> None:
try:
session.run(
"pytest",
*session.posargs,
)
session.run(
"pytest",
"--e2e",
"--cov",
"--cov-config=pyproject.toml",
*session.posargs,
Expand Down
1 change: 1 addition & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ line-length = 119

[tool.pytest.ini_options]
asyncio_mode = "auto"
markers = ["e2e: end-to-end testing"]

[tool.ruff]
target-version = "py39"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async def mv_create(client, rm_create) -> ModelVersionCreate:
)


@pytest.mark.e2e
async def test_registered_model(client, rm_create):
rm_create.custom_properties = {
"key1": MetadataValue.from_dict(
Expand All @@ -64,6 +65,7 @@ async def test_registered_model(client, rm_create):
assert new_rm.description == by_find.description


@pytest.mark.e2e
async def test_model_version(client, mv_create):
mv_create.custom_properties = {
"key1": MetadataValue.from_dict(
Expand All @@ -87,6 +89,7 @@ async def test_model_version(client, mv_create):
assert new_mv.custom_properties == by_find.custom_properties


@pytest.mark.e2e
async def test_model_artifact(client, mv_create):
mv = await client.create_model_version(mv_create)
assert mv is not None
Expand Down
20 changes: 18 additions & 2 deletions clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,22 @@
import pytest
import requests


def pytest_addoption(parser):
parser.addoption("--e2e", action="store_true", help="run end-to-end tests")


def pytest_collection_modifyitems(config, items):
for item in items:
skip_e2e = pytest.mark.skip(
reason="this is an end-to-end test, requires explicit opt-in --e2e option to run."
)
if "e2e" in item.keywords:
if not config.getoption("--e2e"):
item.add_marker(skip_e2e)
continue


REGISTRY_HOST = "http://localhost"
REGISTRY_PORT = 8080
REGISTRY_URL = f"{REGISTRY_HOST}:{REGISTRY_PORT}"
Expand Down Expand Up @@ -45,7 +61,7 @@ def poll_for_ready():
time.sleep(POLL_INTERVAL)


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="session")
def _compose_mr(root):
print("Assuming this is the Model Registry root directory:", root)
shared_volume = root / "test/config/ml-metadata"
Expand Down Expand Up @@ -76,7 +92,7 @@ def _compose_mr(root):


def cleanup(client):
async def yield_and_restart(root):
async def yield_and_restart(_compose_mr, root):
poll_for_ready()
if inspect.iscoroutinefunction(client) or inspect.isasyncgenfunction(client):
async with asynccontextmanager(client)() as async_client:
Expand Down
10 changes: 10 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_secure_client():
assert "user token" in str(e.value).lower()


@pytest.mark.e2e
async def test_register_new(client: ModelRegistry):
name = "test_model"
version = "1.0.0"
Expand All @@ -44,6 +45,7 @@ async def test_register_new(client: ModelRegistry):
assert ma


@pytest.mark.e2e
async def test_register_new_using_s3_uri_builder(client: ModelRegistry):
name = "test_model"
version = "1.0.0"
Expand All @@ -68,6 +70,7 @@ async def test_register_new_using_s3_uri_builder(client: ModelRegistry):
assert ma.uri == uri


@pytest.mark.e2e
def test_register_existing_version(client: ModelRegistry):
params = {
"name": "test_model",
Expand All @@ -82,6 +85,7 @@ def test_register_existing_version(client: ModelRegistry):
client.register_model(**params)


@pytest.mark.e2e
async def test_get(client: ModelRegistry):
name = "test_model"
version = "1.0.0"
Expand Down Expand Up @@ -112,6 +116,7 @@ async def test_get(client: ModelRegistry):
assert ma.id == _ma.id


@pytest.mark.e2e
def test_get_registered_models(client: ModelRegistry):
models = 21

Expand Down Expand Up @@ -142,6 +147,7 @@ def test_get_registered_models(client: ModelRegistry):
assert i == models


@pytest.mark.e2e
def test_get_registered_models_and_reset(client: ModelRegistry):
model_count = 6
page = model_count // 2
Expand All @@ -166,6 +172,7 @@ def test_get_registered_models_and_reset(client: ModelRegistry):
assert complete[:page] == models


@pytest.mark.e2e
def test_get_model_versions(client: ModelRegistry):
name = "test_model"
models = 21
Expand Down Expand Up @@ -197,6 +204,7 @@ def test_get_model_versions(client: ModelRegistry):
assert i == models


@pytest.mark.e2e
def test_get_model_versions_and_reset(client: ModelRegistry):
name = "test_model"

Expand All @@ -223,6 +231,7 @@ def test_get_model_versions_and_reset(client: ModelRegistry):
assert complete[:page] == models


@pytest.mark.e2e
def test_hf_import(client: ModelRegistry):
pytest.importorskip("huggingface_hub")
name = "openai-community/gpt2"
Expand Down Expand Up @@ -250,6 +259,7 @@ def test_hf_import(client: ModelRegistry):
assert client.get_model_artifact(name, version)


@pytest.mark.e2e
def test_hf_import_default_env(client: ModelRegistry):
"""Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata"""
pytest.importorskip("huggingface_hub")
Expand Down
22 changes: 22 additions & 0 deletions clients/python/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def client():
return ModelRegistryAPIClient.insecure_connection(REGISTRY_HOST, REGISTRY_PORT)


@pytest.mark.e2e
async def test_insert_registered_model(client: ModelRegistryAPIClient):
registered_model = RegisteredModel(name="test rm")
rm = await client.upsert_registered_model(registered_model)
Expand All @@ -31,6 +32,7 @@ async def test_insert_registered_model(client: ModelRegistryAPIClient):
assert rm.last_update_time_since_epoch


@pytest.mark.e2e
async def test_update_registered_model(client: ModelRegistryAPIClient):
registered_model = RegisteredModel(name="updated rm")
rm = await client.upsert_registered_model(registered_model)
Expand All @@ -49,6 +51,7 @@ async def registered_model(client: ModelRegistryAPIClient) -> RegisteredModel:
)


@pytest.mark.e2e
async def test_get_registered_model_by_id(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
Expand All @@ -57,6 +60,7 @@ async def test_get_registered_model_by_id(
assert rm == registered_model


@pytest.mark.e2e
async def test_get_registered_model_by_name(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
Expand All @@ -67,6 +71,7 @@ async def test_get_registered_model_by_name(
assert rm == registered_model


@pytest.mark.e2e
async def test_get_registered_model_by_external_id(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
Expand All @@ -79,6 +84,7 @@ async def test_get_registered_model_by_external_id(
assert rm == registered_model


@pytest.mark.e2e
async def test_get_registered_models(
client: ModelRegistryAPIClient, registered_model: RegisteredModel
):
Expand All @@ -88,6 +94,7 @@ async def test_get_registered_models(
assert [registered_model, rm2] == rms


@pytest.mark.e2e
async def test_page_through_registered_models(client: ModelRegistryAPIClient):
models = 6
for i in range(models):
Expand All @@ -99,6 +106,7 @@ async def test_page_through_registered_models(client: ModelRegistryAPIClient):
assert total == models


@pytest.mark.e2e
async def test_insert_model_version(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
Expand All @@ -114,6 +122,7 @@ async def test_insert_model_version(
assert mv.author == model_version.author


@pytest.mark.e2e
async def test_update_model_version(
client: ModelRegistryAPIClient, registered_model: RegisteredModel
):
Expand All @@ -137,13 +146,15 @@ async def model_version(
)


@pytest.mark.e2e
async def test_get_model_version_by_id(
client: ModelRegistryAPIClient, model_version: ModelVersion
):
assert (mv := await client.get_model_version_by_id(str(model_version.id)))
assert mv == model_version


@pytest.mark.e2e
async def test_get_model_version_by_name(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
Expand All @@ -157,6 +168,7 @@ async def test_get_model_version_by_name(
assert mv == model_version


@pytest.mark.e2e
async def test_get_model_version_by_external_id(
client: ModelRegistryAPIClient, model_version: ModelVersion
):
Expand All @@ -168,6 +180,7 @@ async def test_get_model_version_by_external_id(
assert mv == model_version


@pytest.mark.e2e
async def test_get_model_versions(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
Expand All @@ -181,6 +194,7 @@ async def test_get_model_versions(
assert [model_version, mv2] == mvs


@pytest.mark.e2e
async def test_page_through_model_versions(
client: ModelRegistryAPIClient, registered_model: RegisteredModel
):
Expand All @@ -198,6 +212,7 @@ async def test_page_through_model_versions(
assert total == models


@pytest.mark.e2e
async def test_insert_model_artifact(
client: ModelRegistryAPIClient,
model_version: ModelVersion,
Expand Down Expand Up @@ -228,6 +243,7 @@ async def test_insert_model_artifact(
assert ma.service_account_name


@pytest.mark.e2e
async def test_update_model_artifact(
client: ModelRegistryAPIClient, model_version: ModelVersion
):
Expand All @@ -252,13 +268,15 @@ async def model(
)


@pytest.mark.e2e
async def test_get_model_artifact_by_id(
client: ModelRegistryAPIClient, model: ModelArtifact
):
assert (ma := await client.get_model_artifact_by_id(str(model.id)))
assert ma == model


@pytest.mark.e2e
async def test_get_model_artifact_by_name(
client: ModelRegistryAPIClient, model_version: ModelVersion, model: ModelArtifact
):
Expand All @@ -270,6 +288,7 @@ async def test_get_model_artifact_by_name(
assert ma == model


@pytest.mark.e2e
async def test_get_model_artifact_by_external_id(
client: ModelRegistryAPIClient, model: ModelArtifact
):
Expand All @@ -281,6 +300,7 @@ async def test_get_model_artifact_by_external_id(
assert ma == model


@pytest.mark.e2e
async def test_get_all_model_artifacts(
client: ModelRegistryAPIClient, model_version: ModelVersion, model: ModelArtifact
):
Expand All @@ -292,6 +312,7 @@ async def test_get_all_model_artifacts(
assert [model, ma2] == mas


@pytest.mark.e2e
async def test_get_model_artifacts_by_mv_id(
client: ModelRegistryAPIClient, model_version: ModelVersion, model: ModelArtifact
):
Expand All @@ -303,6 +324,7 @@ async def test_get_model_artifacts_by_mv_id(
assert [model, ma2] == mas


@pytest.mark.e2e
async def test_page_through_model_version_artifacts(
client: ModelRegistryAPIClient,
registered_model: RegisteredModel,
Expand Down

0 comments on commit 3e1259a

Please sign in to comment.