Implement: multiple save and load for mlflow registry #416

Merged
merged 7 commits into from Sep 26, 2024
Changes from 1 commit
5 changes: 5 additions & 0 deletions .gitignore
@@ -78,6 +78,8 @@ target/
#mlflow
/.mlruns
*.db
mlruns/
mlartifacts/

# Jupyter Notebook
.ipynb_checkpoints
@@ -169,4 +171,7 @@ cython_debug/
# Mac related
*.DS_Store

# vscode
.vscode/

.python-version
133 changes: 131 additions & 2 deletions numalogic/registry/mlflow_registry.py
@@ -18,15 +18,17 @@
import mlflow.pyfunc
import mlflow.pytorch
import mlflow.sklearn
import mlflow
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient
from sortedcontainers import SortedSet

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
from numalogic.tools.exceptions import ModelVersionError
from numalogic.tools.types import artifact_t, KEYS, META_VT
from numalogic.tools.types import KeyedArtifact, artifact_t, KEYS, META_VT

_LOGGER = logging.getLogger(__name__)

@@ -187,6 +189,39 @@
self._save_in_cache(model_key, artifact_data)
return artifact_data

def load_multiple(
self,
skeys: KEYS,
dkeys_list: list[list[str]],
) -> Optional[dict[str, ArtifactData]]:
"""
Load multiple artifacts from the registry for pyfunc models.
Args:
skeys (KEYS): The source keys of the artifacts to load.
dkeys_list (list[list[str]]):
A list of lists containing the dkeys of the artifacts to load.

Returns
-------
Optional[dict[str, ArtifactData]]: A dictionary mapping joined dynamic keys
to the loaded artifacts, or None if no artifacts were found.
"""
dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
if loaded_model is not None:
metadata = loaded_model.artifact.unwrap_python_model().metadata
dict_artifacts = loaded_model.artifact.unwrap_python_model().dict_artifacts
artifacts_dict = {}
for artifact in dict_artifacts.values():
artifact_data = ArtifactData(
artifact=artifact.artifact, metadata=metadata, extras=None
)
dynamic_key = ":".join(artifact.dkeys)
artifacts_dict[dynamic_key] = artifact_data
else:
artifacts_dict = None

return artifacts_dict
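
For reviewers, a minimal usage sketch of load_multiple (not part of this diff; the tracking URI and key names are hypothetical):

from numalogic.registry import MLflowRegistry

registry = MLflowRegistry("http://localhost:5000")  # hypothetical tracking server

# Artifacts previously stored together via save_multiple() come back in a
# dict keyed by each artifact's dkeys joined with ":".
artifacts = registry.load_multiple(
    skeys=["service", "pipeline"],
    dkeys_list=[["AE", "infer"], ["scaler", "infer"]],
)
if artifacts is not None:
    ae_model = artifacts["AE:infer"].artifact
    scaler = artifacts["scaler:infer"].artifact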

@staticmethod
def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
@@ -225,7 +260,10 @@
handler = self.handler_from_type(artifact_type)
try:
mlflow.start_run(run_id=run_id)
handler.log_model(artifact, "model", registered_model_name=model_key)
if artifact_type == "pyfunc":
handler.log_model("model", python_model=artifact, registered_model_name=model_key)
else:
handler.log_model(artifact, "model", registered_model_name=model_key)
if metadata:
mlflow.log_params(metadata)
model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
@@ -238,6 +276,37 @@
finally:
mlflow.end_run()
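
The branch added above reflects the difference between the flavor APIs: mlflow.pyfunc.log_model takes the artifact path as its first argument and receives the model through the python_model keyword, while the pytorch and sklearn flavors take the model object as the first positional argument. A sketch of the two call shapes (composite_model, torch_model, and key are illustrative names):

# pyfunc flavor: the model is passed via the python_model keyword
mlflow.pyfunc.log_model("model", python_model=composite_model, registered_model_name=key)

# pytorch / sklearn flavors: the model object is passed positionally
mlflow.pytorch.log_model(torch_model, "model", registered_model_name=key)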

def save_multiple(
self,
skeys: KEYS,
dict_artifacts: dict[str, KeyedArtifact],
**metadata: META_VT,
) -> Optional[ModelVersion]:
"""
Saves multiple artifacts into the mlflow registry. The last save stores all the
artifact versions in the metadata.

Args:
----
skeys: static key fields as list/tuple of strings
dict_artifacts: dict of artifacts to save
metadata: additional metadata surrounding the artifact that needs to be saved.

Returns
-------
mlflow ModelVersion instance
"""
multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
dkeys_list = multiple_artifacts.get_dkeys_list()
sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
return self.save(
skeys=multiple_artifacts.skeys,
dkeys=sorted_dkeys,
artifact=multiple_artifacts,
artifact_type="pyfunc",
metadata=multiple_artifacts.metadata,
)
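
A companion sketch for save_multiple, mirroring the unit test added below (not part of this diff; the registry URI is hypothetical):

from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.registry import MLflowRegistry
from numalogic.tools.types import KeyedArtifact

registry = MLflowRegistry("http://localhost:5000")
version = registry.save_multiple(
    skeys=["service", "pipeline"],
    dict_artifacts={
        "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
        "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
    },
    learning_rate=0.01,  # extra kwargs are stored as metadata
)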

@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""Returns whether the given artifact is stale or not, i.e. if
@@ -338,3 +407,63 @@
version_info.version,
)
return model, metadata

def __get_sorted_unique_dkeys(self, dkeys_list: list[list]) -> list[str]:
"""
Returns a sorted list of the unique dkeys across the given lists of keys.

Args:
----
dkeys_list: A list of lists containing the dynamic keys of the artifacts.

Returns
-------
list[str]
A list of all unique dkeys, sorted in ascending order.
"""
return list(SortedSet([dkey for dkeys in dkeys_list for dkey in dkeys]))
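
To make the dedup-and-sort behavior concrete, a quick illustration (note that uppercase sorts before lowercase in ASCII order):

from sortedcontainers import SortedSet

dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
flattened = list(SortedSet([dkey for dkeys in dkeys_list for dkey in dkeys]))
assert flattened == ["AE", "infer", "scaler"]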


class CompositeModels(mlflow.pyfunc.PythonModel):
"""A composite model that represents multiple artifacts.

This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load
multiple artifacts in the MLflow registry. It provides a convenient way to manage and
organize multiple artifacts associated with a single model.

Args:
skeys (KEYS): The static keys of the artifacts.
dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
`KeyedArtifact` objects.
**metadata (META_VT): Additional metadata associated with the artifacts.

Methods
-------
get_dkeys_list(): Returns a list of all dynamic keys in the stored artifacts.

Attributes
----------
skeys (KEYS): The static keys of the artifacts.
dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
`KeyedArtifact` objects.
metadata (META_VT): Additional metadata associated with the artifacts.
"""

def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT):
self.skeys = skeys
self.dict_artifacts = dict_artifacts
self.metadata = metadata

def get_dkeys_list(self):
"""
Returns a list of all dynamic keys in the stored artifacts.

Returns
-------
list[list[str]]: A list of all dynamic keys in the stored artifacts.
"""
dkeys_list = []
artifacts = self.dict_artifacts.values()
for artifact in artifacts:
dkeys_list.append(artifact.dkeys)
return dkeys_list
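
A small sketch of how CompositeModels fits together (an editor's example, not part of this diff; VanillaAE and StandardScaler stand in for any artifacts):

from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.registry.mlflow_registry import CompositeModels
from numalogic.tools.types import KeyedArtifact

composite = CompositeModels(
    skeys=["service", "pipeline"],
    dict_artifacts={
        "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
        "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
    },
    learning_rate=0.01,
)
# dict insertion order is preserved in Python 3.7+, so:
assert composite.get_dkeys_list() == [["AE", "infer"], ["scaler", "infer"]]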
91 changes: 91 additions & 0 deletions tests/registry/_mlflow_utils.py
@@ -9,8 +9,12 @@
from mlflow.store.entities import PagedList
from sklearn.preprocessing import StandardScaler
from torch import tensor
from mlflow.models import Model

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.models.threshold import StdDevThreshold
from numalogic.registry.mlflow_registry import CompositeModels
from numalogic.tools.types import KeyedArtifact


def create_model():
@@ -135,6 +139,71 @@ def mock_log_model_sklearn(*_, **__):
)


def mock_log_model_pyfunc(*_, **__):
return ModelInfo(
artifact_path="model",
flavors={
"pyfunc": {"model_data": "data", "pyfunc_version": "1.11.0", "code": None},
"python_function": {
"pickle_module_name": "mlflow.pyfunc.pickle_module",
"loader_module": "mlflow.pyfunc",
"python_version": "3.8.5",
"data": "data",
"env": "conda.yaml",
},
},
model_uri="runs:/a7c0b376530b40d7b23e6ce2081c899c/model",
model_uuid="a7c0b376530b40d7b23e6ce2081c899c",
run_id="a7c0b376530b40d7b23e6ce2081c899c",
saved_input_example_info=None,
signature_dict=None,
utc_time_created="2022-05-23 22:35:59.557372",
mlflow_version="2.0.1",
signature=None,
)


def mock_load_model_pyfunc(*_, **__):
artifact_path = "model"
flavors = {
"python_function": {
"cloudpickle_version": "3.0.0",
"code": None,
"env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
"loader_module": "mlflow.pyfunc.model",
"python_model": "python_model.pkl",
"python_version": "3.10.14",
"streamable": False,
}
}
model_size_bytes = 8912
model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc"
run_id = "7e85a3fa46d44e668c840f3dddc909c3"
utc_time_created = "2024-09-18 17:12:41.501209"
model = Model(
artifact_path=artifact_path,
flavors=flavors,
model_size_bytes=model_size_bytes,
model_uuid=model_uuid,
run_id=run_id,
utc_time_created=utc_time_created,
mlflow_version="2.16.0",
)
return mlflow.pyfunc.PyFuncModel(
model_meta=model,
model_impl=TestObject(
python_model=CompositeModels(
skeys=["model"],
dict_artifacts={
"AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
"scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
},
**{"learning_rate": 0.01},
)
),
)


def mock_transition_stage(*_, **__):
return ModelVersion(
creation_timestamp=1653402941169,
@@ -303,6 +372,23 @@ def return_sklearn_rundata():
)


def return_pyfunc_rundata():
return Run(
run_info=RunInfo(
artifact_uri="mlflow-artifacts:/0/a7c0b376530b40d7b23e6ce2081c899c/artifacts/model",
end_time=None,
experiment_id="0",
lifecycle_stage="active",
run_id="a7c0b376530b40d7b23e6ce2081c899c",
run_uuid="a7c0b376530b40d7b23e6ce2081c899c",
start_time=1658788772612,
status="RUNNING",
user_id="lol",
),
run_data=RunData(metrics={}, tags={}, params={}),
)


def return_pytorch_rundata_dict():
return Run(
run_info=RunInfo(
@@ -318,3 +404,8 @@ def return_pytorch_rundata_dict():
),
run_data=RunData(metrics={}, tags={}, params=[mlflow.entities.Param("lr", "0.001")]),
)


class TestObject(mlflow.pyfunc.PythonModel):
def __init__(self, python_model):
self.python_model = python_model
56 changes: 53 additions & 3 deletions tests/registry/test_mlflow_registry.py
@@ -15,14 +15,18 @@


from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.types import KeyedArtifact
from tests.registry._mlflow_utils import (
mock_load_model_pyfunc,
mock_log_model_pyfunc,
model_sklearn,
create_model,
mock_log_model_pytorch,
mock_log_state_dict,
mock_get_model_version,
mock_transition_stage,
mock_log_model_sklearn,
return_pyfunc_rundata,
return_pytorch_rundata_dict,
return_empty_rundata,
mock_list_of_model_version,
@@ -56,22 +60,68 @@ def test_construct_key(self):
self.assertEqual("model_:nnet::error1", key)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
def test_save_model(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
status = ml.save(
skeys=skeys, dkeys=dkeys, artifact=self.model, run_id="1234", artifact_type="pytorch"
skeys=skeys,
dkeys=dkeys,
artifact=self.model,
run_id="1234",
artifact_type="pytorch",
**{"lr": 0.01},
)
mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
def test_save_multiple_models_pyfunc(self):
ml = MLflowRegistry(TRACKING_URI)
status = ml.save_multiple(
skeys=self.skeys,
dict_artifacts={
"AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
"scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
},
**{"learning_rate": 0.01},
)
self.assertIsNotNone(status)
mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc)
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
def test_load_multiple_models_when_pyfunc_model_exist(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list)
self.assertIsNotNone(data["AE:infer"].metadata)
self.assertIsNotNone(data["scaler:infer"].metadata)
self.assertIsInstance(data, dict)
self.assertIsInstance(data["AE:infer"].artifact, VanillaAE)
self.assertIsInstance(data["scaler:infer"].artifact, StandardScaler)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))