Skip to content

Commit

Permalink
[MMM-14420] Add Registered model support (#355)
Browse files Browse the repository at this point in the history
* Registered model support

* Fix lint

* Remove duplicate fixture
  • Loading branch information
baekdahl authored Oct 2, 2023
1 parent 0767932 commit c961695
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 2 deletions.
87 changes: 86 additions & 1 deletion src/dr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class DrClient:
DEPLOYMENT_MODEL_CHALLENGER_ROUTE = DEPLOYMENT_ROUTE + "challengers/"
DEPLOYMENT_ACTUALS_UPDATE_ROUTE = DEPLOYMENT_ROUTE + "actuals/fromDataset/"
ENVIRONMENT_DROP_IN_ROUTE = "executionEnvironments/"
REGISTERED_MODELS_LIST_ROUTE = "registeredModels/"
REGISTERED_MODELS_VERSIONS_ROUTE = "registeredModels/{registered_model_id}/versions/"

DEFAULT_MAX_WAIT_SEC = 600

Expand Down Expand Up @@ -452,6 +454,74 @@ def create_custom_model_version(
logger.info("Custom model version created successfully (id: %s)", model_version["id"])
return model_version

def create_or_update_registered_model(self, custom_model_version_id, registered_model_name):
"""
Creates or updates a registered model from custom model version.
If a registered model named registered_model_name exists, it is updated with a new
version if needed. If it does not exist, it is created.
Parameters
----------
custom_model_version_id : str
Custom model version id to create registered model version from.
registered_model_name : str
Registered model name to create or update.
Returns
-------
str,
Registered model version id of existing or newly created version.
"""
registered_model_id = self.get_registered_model_by_name(registered_model_name)
if registered_model_id:
existing_registered_versions = self._get_registered_model_versions(registered_model_id)
existing_version_id = next(
(
v["id"]
for v in existing_registered_versions
if v["modelId"] == custom_model_version_id
),
None,
)
if existing_version_id:
logger.info(
"Custom model version is already registered. Registered model name: %s, "
"custom model version id: %s",
registered_model_name,
custom_model_version_id,
)
return existing_version_id
registered_model_name = None

return self.create_model_package_from_custom_model_version(
custom_model_version_id, registered_model_name, registered_model_id
)["id"]

def get_registered_model_by_name(self, registered_model_name):
"""
Retrieves a registered model by name.
Parameters
----------
registered_model_name : str
The name of the registered model to get.
Returns
-------
str or None,
Registered model id if found, otherwise None.
"""
items = self._paginated_fetch(
self.REGISTERED_MODELS_LIST_ROUTE,
params={"search": registered_model_name},
)
return next((item["id"] for item in items if item["name"] == registered_model_name), None)

def _get_registered_model_versions(self, registered_model_id):
return self._paginated_fetch(
self.REGISTERED_MODELS_VERSIONS_ROUTE.format(registered_model_id=registered_model_id),
)

@classmethod
def _setup_payload_for_custom_model_version_creation(
cls,
Expand Down Expand Up @@ -1153,14 +1223,24 @@ def create_deployment(self, custom_model_version, deployment_info):
deployment, _ = self.update_deployment_settings(deployment, deployment_info)
return deployment

def create_model_package_from_custom_model_version(self, custom_model_version_id):
def create_model_package_from_custom_model_version(
self, custom_model_version_id, registered_model_name=None, registered_model_id=None
):
"""
Creates a model package in the model's registry from a custom model version.
Parameters
----------
custom_model_version_id : str
A custom model version ID
registered_model_name : str
Registered model name. This will add the model package as a registered model version
of a new registered model by this name.
If None, will be left out of request.
registered_model_id : str
Registered model id. This will add the model package as a registered model version
of an existing registered model by this id.
IF None, will be left out of request.
Returns
-------
Expand All @@ -1169,6 +1249,11 @@ def create_model_package_from_custom_model_version(self, custom_model_version_id
"""

payload = {"customModelVersionId": custom_model_version_id}
if registered_model_name:
payload["registeredModelName"] = registered_model_name
if registered_model_id:
payload["registeredModelId"] = registered_model_id

response = self._http_requester.post(self.MODEL_PACKAGES_CREATE_ROUTE, json=payload)
if response.status_code != 201:
raise DataRobotClientError(
Expand Down
5 changes: 5 additions & 0 deletions src/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,11 @@ def handle_model_changes(self):
custom_model["id"], latest_version["id"], model_info
)

if model_info.should_register_model:
self._dr_client.create_or_update_registered_model(
latest_version["id"], model_info.registered_model_name
)

@staticmethod
def _was_new_version_created(previous_latest_version, latest_version):
return not previous_latest_version or latest_version["id"] != previous_latest_version["id"]
Expand Down
10 changes: 10 additions & 0 deletions src/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,16 @@ def is_there_a_change_in_training_or_holdout_data_at_version_level(

return False

@property
def should_register_model(self):
"""Wheter this model should be added as a registered model."""
return self.registered_model_name is not None

@property
def registered_model_name(self):
"""The registered model name to use or None if model should not be registered."""
return self.get_value(ModelSchema.MODEL_REGISTRY_KEY, ModelSchema.MODEL_NAME)

@property
def should_run_test(self):
"""
Expand Down
6 changes: 6 additions & 0 deletions src/schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,9 @@ class ModelSchema(SharedSchema):
MINIMUM_PAYLOAD_SIZE_KEY = "minimum_payload_size"
MAXIMUM_PAYLOAD_SIZE_KEY = "maximum_payload_size"

MODEL_REGISTRY_KEY = "model_registry"
MODEL_NAME = "model_name"

MODEL_SCHEMA = Schema(
{
SharedSchema.MODEL_ID_KEY: And(str, len, Use(Namespace.namespaced)),
Expand Down Expand Up @@ -376,6 +379,9 @@ class ModelSchema(SharedSchema):
},
},
},
Optional(MODEL_REGISTRY_KEY): {
Optional(MODEL_NAME): And(str, len),
},
}
)
MULTI_MODELS_SCHEMA = Schema(
Expand Down
17 changes: 17 additions & 0 deletions tests/functional/test_model_github_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_e2e_pull_request_event_with_multiple_changes( # pylint: disable=too-ma
self._rename_model,
self._add_file_check,
self._remove_file_check,
self._add_registered_model,
]
self._run_checks(
checks,
Expand Down Expand Up @@ -333,6 +334,22 @@ def _add_remove_file_check(
for filepath in files_to_add_and_remove:
assert (filepath.name in cm_version_files) == is_add

@classmethod
@contextlib.contextmanager
def _add_registered_model(cls, git_repo, dr_client, model_metadata, model_metadata_yaml_file):
printout("Create new registered model ...")
model_metadata.update(
{ModelSchema.MODEL_REGISTRY_KEY: {ModelSchema.MODEL_NAME: "registered_model"}}
)

save_new_metadata_and_commit(
model_metadata, model_metadata_yaml_file, git_repo, "Add registered model"
)

yield cls.ExpectedChange(settings_updated=False, version_created=False)

assert dr_client.get_registered_model_by_name("registered_model") is not None

@staticmethod
def _merge_changes_into_the_main_branch(git_repo, merge_branch):
# Merge changes from the merge branch into the main branch
Expand Down
130 changes: 129 additions & 1 deletion tests/unit/test_dr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from bson import ObjectId
from mock import Mock
from mock import patch
from responses import matchers

from common.exceptions import DataRobotClientError
from common.http_requester import HttpRequester
Expand Down Expand Up @@ -126,7 +127,7 @@ def fixture_regression_model_info():


def mock_paginated_responses(
total_num_entities, num_entities_in_page, url_factory, entity_response_factory_fn
total_num_entities, num_entities_in_page, url_factory, entity_response_factory_fn, match=None
):
"""A method to mock paginated responses from DataRobot."""

Expand All @@ -144,6 +145,7 @@ def _generate_for_single_page(page_index, num_entities, has_next):
"data": entities_in_page,
},
status=200,
match=match or [],
)
return entities_in_page

Expand All @@ -162,6 +164,22 @@ def _generate_for_single_page(page_index, num_entities, has_next):
return total_entities


def mock_single_page_response(url, entities, match=None):
"""Mock single page paginated response."""

def url_factory(_):
return url

entities_iter = iter(entities)

def entity_response_factory_fn(_):
return next(entities_iter)

return mock_paginated_responses(
len(entities), max(len(entities), 1), url_factory, entity_response_factory_fn, match
)


class TestPaginator:
"""
A class to test the DrClient when it fetched entities from DataRobot from URLs that support
Expand Down Expand Up @@ -1478,6 +1496,116 @@ def test_dependency_environment_build_started_and_failed(
assert response_obj.call_count == 1


class TestRegisteredModels:
"""Registered models tests."""

@pytest.fixture
def registered_model_response_mock(self, paginated_url_factory):
"""Return existing registered model"""
with responses.RequestsMock():
registered_model = {
"id": "existing_registered_model_id",
"name": "existing_registered_model",
}

params = {"search": registered_model["name"]}
mock_single_page_response(
paginated_url_factory(DrClient.REGISTERED_MODELS_LIST_ROUTE),
entities=[registered_model],
match=[matchers.query_param_matcher(params)],
)

yield registered_model

@responses.activate
def test_create_new_registered_model(self, dr_client, paginated_url_factory):
"""Test creating new registered model"""
params = {"search": "non_existent_registered_model"}
mock_single_page_response(
paginated_url_factory(DrClient.REGISTERED_MODELS_LIST_ROUTE),
entities=[],
match=[matchers.query_param_matcher(params)],
)

responses.post(
url=paginated_url_factory(DrClient.MODEL_PACKAGES_CREATE_ROUTE),
json={"id": "new_registered_model_id"},
status=201,
)

registered_model_version = dr_client.create_or_update_registered_model(
"custom_model_version_id", "non_existent_registered_model"
)

assert registered_model_version == "new_registered_model_id"

@responses.activate
def test_update_existing_registered_model(
self,
dr_client,
paginated_url_factory,
custom_model_version_id,
registered_model_response_mock,
):
"""Update existing registered model by creating new version."""

mock_single_page_response(
paginated_url_factory(
DrClient.REGISTERED_MODELS_VERSIONS_ROUTE.format(
registered_model_id=registered_model_response_mock["id"]
)
),
entities=[],
)

create_model_package_payload = {
"customModelVersionId": custom_model_version_id,
"registeredModelId": registered_model_response_mock["id"],
}
new_registered_model_id = "new_registered_model_id"
responses.post(
url=paginated_url_factory(DrClient.MODEL_PACKAGES_CREATE_ROUTE),
match=[matchers.json_params_matcher(create_model_package_payload)],
json={"id": new_registered_model_id},
status=201,
)
registered_model_version = dr_client.create_or_update_registered_model(
custom_model_version_id, registered_model_response_mock["name"]
)

assert registered_model_version == new_registered_model_id

@responses.activate
def test_version_already_registered(
self,
dr_client,
paginated_url_factory,
custom_model_version_id,
registered_model_response_mock,
):
"""Existing registered model that already contains this version should do nothing."""
registered_model_version_id = "registered_model_version_id"
mock_single_page_response(
paginated_url_factory(
DrClient.REGISTERED_MODELS_VERSIONS_ROUTE.format(
registered_model_id=registered_model_response_mock["id"]
)
),
entities=[
{
"id": registered_model_version_id,
"modelId": custom_model_version_id,
}
],
)

registered_model_version = dr_client.create_or_update_registered_model(
custom_model_version_id, registered_model_response_mock["name"]
)

assert registered_model_version == registered_model_version_id


class TestDeploymentRoutes(SharedRouteTests):
"""Contains unit-tests to test the DataRobot deployment routes."""

Expand Down

0 comments on commit c961695

Please sign in to comment.