Skip to content

Commit

Permalink
[MMM-14942] Support toggling registered model global flag (#356)
Browse files Browse the repository at this point in the history
* Return everything from get_registered_model_by_name

* Fix test that breaks on my machine

* Set registered model public

* Set global in model controller, functional test

* Fix lint

* Rename to global

* Remove functional test
  • Loading branch information
baekdahl authored Oct 31, 2023
1 parent c961695 commit f3aa0a1
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 7 deletions.
59 changes: 53 additions & 6 deletions src/dr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class DrClient:
DEPLOYMENT_ACTUALS_UPDATE_ROUTE = DEPLOYMENT_ROUTE + "actuals/fromDataset/"
ENVIRONMENT_DROP_IN_ROUTE = "executionEnvironments/"
REGISTERED_MODELS_LIST_ROUTE = "registeredModels/"
REGISTERED_MODEL_ROUTE = "registeredModels/{registered_model_id}/"
REGISTERED_MODELS_VERSIONS_ROUTE = "registeredModels/{registered_model_id}/versions/"

DEFAULT_MAX_WAIT_SEC = 600
Expand Down Expand Up @@ -472,7 +473,8 @@ def create_or_update_registered_model(self, custom_model_version_id, registered_
str,
Registered model version id of existing or newly created version.
"""
registered_model_id = self.get_registered_model_by_name(registered_model_name)
registered_model = self.get_registered_model_by_name(registered_model_name)
registered_model_id = registered_model["id"] if registered_model else None
if registered_model_id:
existing_registered_versions = self._get_registered_model_versions(registered_model_id)
existing_version_id = next(
Expand All @@ -493,9 +495,54 @@ def create_or_update_registered_model(self, custom_model_version_id, registered_
return existing_version_id
registered_model_name = None

return self.create_model_package_from_custom_model_version(
model_package = self.create_model_package_from_custom_model_version(
custom_model_version_id, registered_model_name, registered_model_id
)["id"]
)

return model_package["id"]

def set_registered_model_global(self, registered_model_name, is_global):
"""
Set the global property for a registered model.
This is also known as the global property
Parameters
----------
registered_model_name : str
Name of registered model.
is_global : bool
True if model should be global, False if not.
"""
registered_model = self.get_registered_model_by_name(registered_model_name)
if not registered_model:
raise DataRobotClientError(
f"Failed to find registered model by name: {registered_model_name}"
)

if registered_model.get("isGlobal", None) == is_global:
logger.info(
"Registered model '%s' global flag is already: %s", registered_model_name, is_global
)
return

response = self._http_requester.patch(
self.REGISTERED_MODEL_ROUTE.format(registered_model_id=registered_model["id"]),
json={"isGlobal": is_global},
)
if response.status_code != 200:
raise DataRobotClientError(
"Failed to set registered global property "
f"Registered model name: {registered_model_name}, "
f"Response status: {response.status_code}, "
f"Response body: {response.text}",
code=response.status_code,
)

logger.info(
"Registered model '%s' global flag has been set to: %s",
registered_model_name,
is_global,
)

def get_registered_model_by_name(self, registered_model_name):
"""
Expand All @@ -508,14 +555,14 @@ def get_registered_model_by_name(self, registered_model_name):
Returns
-------
str or None,
Registered model id if found, otherwise None.
dict or None,
Registered model if found, otherwise None.
"""
items = self._paginated_fetch(
self.REGISTERED_MODELS_LIST_ROUTE,
params={"search": registered_model_name},
)
return next((item["id"] for item in items if item["name"] == registered_model_name), None)
return next((item for item in items if item["name"] == registered_model_name), None)

def _get_registered_model_versions(self, registered_model_id):
return self._paginated_fetch(
Expand Down
3 changes: 3 additions & 0 deletions src/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,9 @@ def handle_model_changes(self):
self._dr_client.create_or_update_registered_model(
latest_version["id"], model_info.registered_model_name
)
self._dr_client.set_registered_model_global(
model_info.registered_model_name, model_info.registered_model_global
)

@staticmethod
def _was_new_version_created(previous_latest_version, latest_version):
Expand Down
6 changes: 6 additions & 0 deletions src/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
logger = logging.getLogger()


# pylint: disable=too-many-public-methods
class InfoBase(ABC):
"""An abstract base class for models and deployments information classes."""

Expand Down Expand Up @@ -365,6 +366,11 @@ def registered_model_name(self):
"""The registered model name to use or None if model should not be registered."""
return self.get_value(ModelSchema.MODEL_REGISTRY_KEY, ModelSchema.MODEL_NAME)

@property
def registered_model_global(self):
"""Wheter the registered model should be global or not."""
return self.get_value(ModelSchema.MODEL_REGISTRY_KEY, ModelSchema.GLOBAL)

@property
def should_run_test(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions src/schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class ModelSchema(SharedSchema):

MODEL_REGISTRY_KEY = "model_registry"
MODEL_NAME = "model_name"
GLOBAL = "global"

MODEL_SCHEMA = Schema(
{
Expand Down Expand Up @@ -381,6 +382,7 @@ class ModelSchema(SharedSchema):
},
Optional(MODEL_REGISTRY_KEY): {
Optional(MODEL_NAME): And(str, len),
Optional(GLOBAL, default=False): bool,
},
}
)
Expand Down
66 changes: 65 additions & 1 deletion tests/unit/test_dr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test_delete_custom_model_success(
)
expected_model = expected_model[0][0]
delete_url = f"{custom_models_url}{expected_model['id']}/"
responses.add(responses.DELETE, delete_url, json={}, status=204)
responses.add(responses.DELETE, delete_url, status=204)
dr_client.delete_custom_model_by_user_provided_id(expected_model["userProvidedId"])

@responses.activate
Expand Down Expand Up @@ -1605,6 +1605,70 @@ def test_version_already_registered(

assert registered_model_version == registered_model_version_id

@pytest.mark.parametrize("is_already_global", [True, False])
@responses.activate
def test_set_global(self, dr_client, paginated_url_factory, is_already_global):
"""Test setting registered model as global"""

registered_model = {
"id": "registered_model_id",
"name": "registered_model_name",
"isGlobal": is_already_global,
}

params = {"search": registered_model["name"]}
mock_single_page_response(
paginated_url_factory(DrClient.REGISTERED_MODELS_LIST_ROUTE),
entities=[registered_model],
match=[matchers.query_param_matcher(params)],
)

patch_mock = responses.patch(
url=paginated_url_factory(
DrClient.REGISTERED_MODEL_ROUTE.format(registered_model_id=registered_model["id"])
),
status=200,
)

dr_client.set_registered_model_global(registered_model["name"], True)

assert patch_mock.call_count == 0 if is_already_global else 1

@responses.activate
def test_set_global_non_existent(self, dr_client, paginated_url_factory):
"""Test that non existent registered model raises error"""

mock_single_page_response(
paginated_url_factory(DrClient.REGISTERED_MODELS_LIST_ROUTE),
entities=[],
)

with pytest.raises(DataRobotClientError):
dr_client.set_registered_model_global("non_existent_registered_model", True)

@responses.activate
def test_set_global_error(self, dr_client, paginated_url_factory):
"""Test setting global fails"""

registered_model = {"id": "registered_model_id", "name": "registered_model_name"}

params = {"search": registered_model["name"]}
mock_single_page_response(
paginated_url_factory(DrClient.REGISTERED_MODELS_LIST_ROUTE),
entities=[registered_model],
match=[matchers.query_param_matcher(params)],
)

responses.patch(
url=paginated_url_factory(
DrClient.REGISTERED_MODEL_ROUTE.format(registered_model_id=registered_model["id"])
),
status=500,
)

with pytest.raises(DataRobotClientError):
dr_client.set_registered_model_global(registered_model["name"], True)


class TestDeploymentRoutes(SharedRouteTests):
"""Contains unit-tests to test the DataRobot deployment routes."""
Expand Down

0 comments on commit f3aa0a1

Please sign in to comment.