From 3e1259a223efe7c51a7c8afa8ecdd0a04f3394ed Mon Sep 17 00:00:00 2001 From: Isabella Basso Date: Fri, 30 Aug 2024 13:47:05 -0300 Subject: [PATCH] Test Python client against latest MR (#326) * py: tests: rename basic to REST bindings Signed-off-by: Isabella do Amaral * py: create e2e test mode Signed-off-by: Isabella do Amaral --------- Signed-off-by: Isabella do Amaral --- .github/workflows/python-tests.yml | 1 + clients/python/Makefile | 8 +++++++ clients/python/noxfile.py | 7 +++++- clients/python/pyproject.toml | 1 + .../{basic_test.py => REST_bindings_test.py} | 3 +++ clients/python/tests/conftest.py | 20 +++++++++++++++-- clients/python/tests/test_client.py | 10 +++++++++ clients/python/tests/test_core.py | 22 +++++++++++++++++++ 8 files changed, 69 insertions(+), 3 deletions(-) rename clients/python/tests/{basic_test.py => REST_bindings_test.py} (98%) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index ca00c84d..eecd5b45 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -65,6 +65,7 @@ jobs: working-directory: clients/python run: | if [[ ${{ matrix.session }} == "tests" ]]; then + make build-mr nox --python=${{ matrix.python }} -- --cov-report=xml elif [[ ${{ matrix.session }} == "mypy" ]]; then nox --python=${{ matrix.python }} ||\ diff --git a/clients/python/Makefile b/clients/python/Makefile index c16c0491..60cb21b8 100644 --- a/clients/python/Makefile +++ b/clients/python/Makefile @@ -11,6 +11,14 @@ install: clean: rm -rf src/mr_openapi +.PHONY: build-mr +build-mr: + cd ../../ && make image/build + +.PHONY: test-e2e +test-e2e: build-mr + poetry run pytest --e2e -s + .PHONY: test test: poetry run pytest -s diff --git a/clients/python/noxfile.py b/clients/python/noxfile.py index c92fe9b0..7edbc2c8 100644 --- a/clients/python/noxfile.py +++ b/clients/python/noxfile.py @@ -21,7 +21,7 @@ package = "model_registry" -python_versions = ["3.12", "3.11","3.10", "3.9"] +python_versions = ["3.12", "3.11", "3.10", "3.9"] nox.needs_version = ">= 2021.6.6" nox.options.sessions = ( "tests", @@ -63,6 +63,11 @@ def tests(session: Session) -> None: try: session.run( "pytest", + *session.posargs, + ) + session.run( + "pytest", + "--e2e", "--cov", "--cov-config=pyproject.toml", *session.posargs, diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 3e3c540e..30c6ddc4 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -70,6 +70,7 @@ line-length = 119 [tool.pytest.ini_options] asyncio_mode = "auto" +markers = ["e2e: end-to-end testing"] [tool.ruff] target-version = "py39" diff --git a/clients/python/tests/basic_test.py b/clients/python/tests/REST_bindings_test.py similarity index 98% rename from clients/python/tests/basic_test.py rename to clients/python/tests/REST_bindings_test.py index 96f6639d..649ffbad 100644 --- a/clients/python/tests/basic_test.py +++ b/clients/python/tests/REST_bindings_test.py @@ -43,6 +43,7 @@ async def mv_create(client, rm_create) -> ModelVersionCreate: ) +@pytest.mark.e2e async def test_registered_model(client, rm_create): rm_create.custom_properties = { "key1": MetadataValue.from_dict( @@ -64,6 +65,7 @@ async def test_registered_model(client, rm_create): assert new_rm.description == by_find.description +@pytest.mark.e2e async def test_model_version(client, mv_create): mv_create.custom_properties = { "key1": MetadataValue.from_dict( @@ -87,6 +89,7 @@ async def test_model_version(client, mv_create): assert new_mv.custom_properties == by_find.custom_properties +@pytest.mark.e2e async def test_model_artifact(client, mv_create): mv = await client.create_model_version(mv_create) assert mv is not None diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 02c9cb3e..5d73ae62 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -10,6 +10,22 @@ import pytest import requests + +def pytest_addoption(parser): + parser.addoption("--e2e", action="store_true", help="run end-to-end tests") + + +def pytest_collection_modifyitems(config, items): + for item in items: + skip_e2e = pytest.mark.skip( + reason="this is an end-to-end test, requires explicit opt-in --e2e option to run." + ) + if "e2e" in item.keywords: + if not config.getoption("--e2e"): + item.add_marker(skip_e2e) + continue + + REGISTRY_HOST = "http://localhost" REGISTRY_PORT = 8080 REGISTRY_URL = f"{REGISTRY_HOST}:{REGISTRY_PORT}" @@ -45,7 +61,7 @@ def poll_for_ready(): time.sleep(POLL_INTERVAL) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="session") def _compose_mr(root): print("Assuming this is the Model Registry root directory:", root) shared_volume = root / "test/config/ml-metadata" @@ -76,7 +92,7 @@ def _compose_mr(root): def cleanup(client): - async def yield_and_restart(root): + async def yield_and_restart(_compose_mr, root): poll_for_ready() if inspect.iscoroutinefunction(client) or inspect.isasyncgenfunction(client): async with asynccontextmanager(client)() as async_client: diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 8b459c09..aaeba370 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -24,6 +24,7 @@ def test_secure_client(): assert "user token" in str(e.value).lower() +@pytest.mark.e2e async def test_register_new(client: ModelRegistry): name = "test_model" version = "1.0.0" @@ -44,6 +45,7 @@ async def test_register_new(client: ModelRegistry): assert ma +@pytest.mark.e2e async def test_register_new_using_s3_uri_builder(client: ModelRegistry): name = "test_model" version = "1.0.0" @@ -68,6 +70,7 @@ async def test_register_new_using_s3_uri_builder(client: ModelRegistry): assert ma.uri == uri +@pytest.mark.e2e def test_register_existing_version(client: ModelRegistry): params = { "name": "test_model", @@ -82,6 +85,7 @@ def test_register_existing_version(client: ModelRegistry): client.register_model(**params) +@pytest.mark.e2e async def test_get(client: ModelRegistry): name = "test_model" version = "1.0.0" @@ -112,6 +116,7 @@ async def test_get(client: ModelRegistry): assert ma.id == _ma.id +@pytest.mark.e2e def test_get_registered_models(client: ModelRegistry): models = 21 @@ -142,6 +147,7 @@ def test_get_registered_models(client: ModelRegistry): assert i == models +@pytest.mark.e2e def test_get_registered_models_and_reset(client: ModelRegistry): model_count = 6 page = model_count // 2 @@ -166,6 +172,7 @@ def test_get_registered_models_and_reset(client: ModelRegistry): assert complete[:page] == models +@pytest.mark.e2e def test_get_model_versions(client: ModelRegistry): name = "test_model" models = 21 @@ -197,6 +204,7 @@ def test_get_model_versions(client: ModelRegistry): assert i == models +@pytest.mark.e2e def test_get_model_versions_and_reset(client: ModelRegistry): name = "test_model" @@ -223,6 +231,7 @@ def test_get_model_versions_and_reset(client: ModelRegistry): assert complete[:page] == models +@pytest.mark.e2e def test_hf_import(client: ModelRegistry): pytest.importorskip("huggingface_hub") name = "openai-community/gpt2" @@ -250,6 +259,7 @@ def test_hf_import(client: ModelRegistry): assert client.get_model_artifact(name, version) +@pytest.mark.e2e def test_hf_import_default_env(client: ModelRegistry): """Test setting environment variables, hence triggering defaults, does _not_ interfere with HF metadata""" pytest.importorskip("huggingface_hub") diff --git a/clients/python/tests/test_core.py b/clients/python/tests/test_core.py index 784e94ec..75c52f36 100644 --- a/clients/python/tests/test_core.py +++ b/clients/python/tests/test_core.py @@ -20,6 +20,7 @@ def client(): return ModelRegistryAPIClient.insecure_connection(REGISTRY_HOST, REGISTRY_PORT) +@pytest.mark.e2e async def test_insert_registered_model(client: ModelRegistryAPIClient): registered_model = RegisteredModel(name="test rm") rm = await client.upsert_registered_model(registered_model) @@ -31,6 +32,7 @@ async def test_insert_registered_model(client: ModelRegistryAPIClient): assert rm.last_update_time_since_epoch +@pytest.mark.e2e async def test_update_registered_model(client: ModelRegistryAPIClient): registered_model = RegisteredModel(name="updated rm") rm = await client.upsert_registered_model(registered_model) @@ -49,6 +51,7 @@ async def registered_model(client: ModelRegistryAPIClient) -> RegisteredModel: ) +@pytest.mark.e2e async def test_get_registered_model_by_id( client: ModelRegistryAPIClient, registered_model: RegisteredModel, @@ -57,6 +60,7 @@ async def test_get_registered_model_by_id( assert rm == registered_model +@pytest.mark.e2e async def test_get_registered_model_by_name( client: ModelRegistryAPIClient, registered_model: RegisteredModel, @@ -67,6 +71,7 @@ async def test_get_registered_model_by_name( assert rm == registered_model +@pytest.mark.e2e async def test_get_registered_model_by_external_id( client: ModelRegistryAPIClient, registered_model: RegisteredModel, @@ -79,6 +84,7 @@ async def test_get_registered_model_by_external_id( assert rm == registered_model +@pytest.mark.e2e async def test_get_registered_models( client: ModelRegistryAPIClient, registered_model: RegisteredModel ): @@ -88,6 +94,7 @@ async def test_get_registered_models( assert [registered_model, rm2] == rms +@pytest.mark.e2e async def test_page_through_registered_models(client: ModelRegistryAPIClient): models = 6 for i in range(models): @@ -99,6 +106,7 @@ async def test_page_through_registered_models(client: ModelRegistryAPIClient): assert total == models +@pytest.mark.e2e async def test_insert_model_version( client: ModelRegistryAPIClient, registered_model: RegisteredModel, @@ -114,6 +122,7 @@ async def test_insert_model_version( assert mv.author == model_version.author +@pytest.mark.e2e async def test_update_model_version( client: ModelRegistryAPIClient, registered_model: RegisteredModel ): @@ -137,6 +146,7 @@ async def model_version( ) +@pytest.mark.e2e async def test_get_model_version_by_id( client: ModelRegistryAPIClient, model_version: ModelVersion ): @@ -144,6 +154,7 @@ async def test_get_model_version_by_id( assert mv == model_version +@pytest.mark.e2e async def test_get_model_version_by_name( client: ModelRegistryAPIClient, registered_model: RegisteredModel, @@ -157,6 +168,7 @@ async def test_get_model_version_by_name( assert mv == model_version +@pytest.mark.e2e async def test_get_model_version_by_external_id( client: ModelRegistryAPIClient, model_version: ModelVersion ): @@ -168,6 +180,7 @@ async def test_get_model_version_by_external_id( assert mv == model_version +@pytest.mark.e2e async def test_get_model_versions( client: ModelRegistryAPIClient, registered_model: RegisteredModel, @@ -181,6 +194,7 @@ async def test_get_model_versions( assert [model_version, mv2] == mvs +@pytest.mark.e2e async def test_page_through_model_versions( client: ModelRegistryAPIClient, registered_model: RegisteredModel ): @@ -198,6 +212,7 @@ async def test_page_through_model_versions( assert total == models +@pytest.mark.e2e async def test_insert_model_artifact( client: ModelRegistryAPIClient, model_version: ModelVersion, @@ -228,6 +243,7 @@ async def test_insert_model_artifact( assert ma.service_account_name +@pytest.mark.e2e async def test_update_model_artifact( client: ModelRegistryAPIClient, model_version: ModelVersion ): @@ -252,6 +268,7 @@ async def model( ) +@pytest.mark.e2e async def test_get_model_artifact_by_id( client: ModelRegistryAPIClient, model: ModelArtifact ): @@ -259,6 +276,7 @@ async def test_get_model_artifact_by_id( assert ma == model +@pytest.mark.e2e async def test_get_model_artifact_by_name( client: ModelRegistryAPIClient, model_version: ModelVersion, model: ModelArtifact ): @@ -270,6 +288,7 @@ async def test_get_model_artifact_by_name( assert ma == model +@pytest.mark.e2e async def test_get_model_artifact_by_external_id( client: ModelRegistryAPIClient, model: ModelArtifact ): @@ -281,6 +300,7 @@ async def test_get_model_artifact_by_external_id( assert ma == model +@pytest.mark.e2e async def test_get_all_model_artifacts( client: ModelRegistryAPIClient, model_version: ModelVersion, model: ModelArtifact ): @@ -292,6 +312,7 @@ async def test_get_all_model_artifacts( assert [model, ma2] == mas +@pytest.mark.e2e async def test_get_model_artifacts_by_mv_id( client: ModelRegistryAPIClient, model_version: ModelVersion, model: ModelArtifact ): @@ -303,6 +324,7 @@ async def test_get_model_artifacts_by_mv_id( assert [model, ma2] == mas +@pytest.mark.e2e async def test_page_through_model_version_artifacts( client: ModelRegistryAPIClient, registered_model: RegisteredModel,