implement: multiple save and load for mlflow registry

yleilawang committed Sep 19, 2024
1 parent b18b2d2 commit e8b190a
Showing 4 changed files with 280 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -78,6 +78,8 @@ target/
#mlflow
/.mlruns
*.db
mlruns/
mlartifacts/

# Jupyter Notebook
.ipynb_checkpoints
@@ -169,4 +171,7 @@ cython_debug/
# Mac related
*.DS_Store

# vscode
.vscode/

.python-version
133 changes: 131 additions & 2 deletions numalogic/registry/mlflow_registry.py
@@ -18,15 +18,17 @@
import mlflow.pyfunc
import mlflow.pytorch
import mlflow.sklearn
import mlflow
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
from mlflow.tracking import MlflowClient
from sortedcontainers import SortedSet

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
from numalogic.tools.exceptions import ModelVersionError
from numalogic.tools.types import KeyedArtifact, artifact_t, KEYS, META_VT

_LOGGER = logging.getLogger(__name__)

@@ -187,6 +189,39 @@ def load(
self._save_in_cache(model_key, artifact_data)
return artifact_data

def load_multiple(
self,
skeys: KEYS,
dkeys_list: list[list[str]],
) -> Optional[dict[str, ArtifactData]]:
"""
Load multiple artifacts from the registry for pyfunc models.
Args:
skeys (KEYS): The source keys of the artifacts to load.
dkeys_list (list[list[str]]):
A list of lists containing the dkeys of the artifacts to load.
Returns
-------
Optional[dict[str, ArtifactData]]: A dictionary mapping joined dynamic keys
to the loaded artifacts, or None if no artifacts were found.
"""
dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
loaded_model = self.load(skeys=skeys, dkeys=dkeys, artifact_type="pyfunc")
if loaded_model is not None:
metadata = loaded_model.artifact.unwrap_python_model().metadata
dict_artifacts = loaded_model.artifact.unwrap_python_model().dict_artifacts
artifacts_dict = {}
for artifact in dict_artifacts.values():
artifact_data = ArtifactData(
artifact=artifact.artifact, metadata=metadata, extras=None
)
dynamic_key = ":".join(artifact.dkeys)
artifacts_dict[dynamic_key] = artifact_data
        else:
            artifacts_dict = None
        return artifacts_dict

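For orientation, a minimal usage sketch of `load_multiple`, assuming a reachable tracking server and a composite model previously stored via `save_multiple`; the URI and key names here are illustrative:

```python
from numalogic.registry import MLflowRegistry

registry = MLflowRegistry("http://localhost:5000")  # illustrative tracking URI

# dkeys_list mirrors the per-artifact dkeys used at save time
artifacts = registry.load_multiple(
    skeys=["mymodel"],
    dkeys_list=[["AE", "infer"], ["scaler", "infer"]],
)
if artifacts is not None:
    ae_model = artifacts["AE:infer"].artifact      # e.g. a VanillaAE instance
    scaler = artifacts["scaler:infer"].artifact    # e.g. a StandardScaler instance
```

Each returned `ArtifactData` carries the shared composite metadata, as the new tests below verify.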
@staticmethod
def __log_mlflow_err(mlflow_err: RestException, model_key: str) -> None:
if ErrorCode.Value(mlflow_err.error_code) == RESOURCE_DOES_NOT_EXIST:
@@ -225,7 +260,10 @@ def save(
handler = self.handler_from_type(artifact_type)
try:
mlflow.start_run(run_id=run_id)
handler.log_model(artifact, "model", registered_model_name=model_key)
if artifact_type == "pyfunc":
handler.log_model("model", python_model=artifact, registered_model_name=model_key)
else:
handler.log_model(artifact, "model", registered_model_name=model_key)
if metadata:
mlflow.log_params(metadata)
model_version = self.transition_stage(skeys=skeys, dkeys=dkeys)
@@ -238,6 +276,37 @@
finally:
mlflow.end_run()

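The `artifact_type == "pyfunc"` branch above is needed because the flavors pass the model differently: `mlflow.pyfunc.log_model` takes the model object through the `python_model` keyword, while flavors such as `mlflow.pytorch.log_model` take it as the first positional argument. A minimal sketch of the two call shapes, assuming MLflow 2.x and a throwaway `_Demo` model defined here purely for illustration:

```python
import mlflow
import mlflow.pyfunc
import mlflow.pytorch
from torch import nn


class _Demo(mlflow.pyfunc.PythonModel):
    """Trivial PythonModel, defined only to show the call shape."""

    def predict(self, context, model_input):
        return model_input


with mlflow.start_run():
    # pyfunc flavor: the model object goes in the python_model keyword
    mlflow.pyfunc.log_model("model", python_model=_Demo())
    # pytorch flavor: the model object is the first positional argument
    mlflow.pytorch.log_model(nn.Linear(4, 2), "model")
```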
def save_multiple(
self,
skeys: KEYS,
dict_artifacts: dict[str, KeyedArtifact],
**metadata: META_VT,
) -> Optional[ModelVersion]:
"""
Saves multiple artifacts into mlflow registry. The last save stores all the
artifact versions in the metadata.
Args:
----
skeys: static key fields as list/tuple of strings
dict_artifacts: dict of artifacts to save
metadata: additional metadata surrounding the artifact that needs to be saved.
Returns
-------
mlflow ModelVersion instance
"""
multiple_artifacts = CompositeModels(skeys=skeys, dict_artifacts=dict_artifacts, **metadata)
dkeys_list = multiple_artifacts.get_dkeys_list()
sorted_dkeys = self.__get_sorted_unique_dkeys(dkeys_list)
return self.save(
skeys=multiple_artifacts.skeys,
dkeys=sorted_dkeys,
artifact=multiple_artifacts,
artifact_type="pyfunc",
metadata=multiple_artifacts.metadata,
)

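A minimal end-to-end sketch of `save_multiple`, mirroring the new test below; the tracking URI is illustrative and a reachable MLflow server is assumed:

```python
from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.registry import MLflowRegistry
from numalogic.tools.types import KeyedArtifact

registry = MLflowRegistry("http://localhost:5000")  # illustrative tracking URI
version = registry.save_multiple(
    skeys=["mymodel"],
    dict_artifacts={
        "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
        "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
    },
    learning_rate=0.01,
)
```

All artifacts land in one registered model version, so a single `transition_stage` call governs their lifecycle together.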
@staticmethod
def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
"""Returns whether the given artifact is stale or not, i.e. if
@@ -338,3 +407,63 @@ def __load_artifacts(
version_info.version,
)
return model, metadata

    def __get_sorted_unique_dkeys(self, dkeys_list: list[list[str]]) -> list[str]:
        """
        Returns a sorted list of the unique dkeys across all the given artifacts.

        Args:
        ----
            dkeys_list: a list of lists containing the dkeys of the artifacts.

        Returns
        -------
        list[str]
            All unique dkeys in the given list, sorted in ascending order.
        """
        return list(SortedSet(dkey for dkeys in dkeys_list for dkey in dkeys))
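For illustration, the helper's flatten-dedupe-sort behavior on the dkeys used in the new tests; note that `infer` is deduplicated and the ordering is plain string sort:

```python
from sortedcontainers import SortedSet

dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
print(list(SortedSet(dkey for dkeys in dkeys_list for dkey in dkeys)))
# -> ['AE', 'infer', 'scaler']
```

Since `save_multiple` and `load_multiple` both derive this same sorted key set, they address the same registered composite model.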


class CompositeModels(mlflow.pyfunc.PythonModel):
"""A composite model that represents multiple artifacts.
This class extends the `mlflow.pyfunc.PythonModel` class and is used to store and load
multiple artifacts in the MLflow registry. It provides a convenient way to manage and
organize multiple artifacts associated with a single model.
Args:
skeys (KEYS): The static keys of the artifacts.
dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
`KeyedArtifact` objects.
**metadata (META_VT): Additional metadata associated with the artifacts.
Methods
-------
get_dkeys_list(): Returns a list of all dynamic keys in the stored artifacts.
Attributes
----------
skeys (KEYS): The static keys of the artifacts.
dict_artifacts (dict[str, KeyedArtifact]): A dictionary mapping dynamic keys to
`KeyedArtifact` objects.
metadata (META_VT): Additional metadata associated with the artifacts.
"""

def __init__(self, skeys: KEYS, dict_artifacts: dict[str, KeyedArtifact], **metadata: META_VT):
self.skeys = skeys
self.dict_artifacts = dict_artifacts
self.metadata = metadata

    def get_dkeys_list(self) -> list[list[str]]:
        """
        Returns a list of all dynamic keys in the stored artifacts.

        Returns
        -------
        list[list[str]]: A list of all dynamic keys in the stored artifacts.
        """
        return [artifact.dkeys for artifact in self.dict_artifacts.values()]
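A small standalone sketch of the container, with values mirroring the test fixtures below:

```python
from sklearn.preprocessing import StandardScaler

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.registry.mlflow_registry import CompositeModels
from numalogic.tools.types import KeyedArtifact

composite = CompositeModels(
    skeys=["mymodel"],
    dict_artifacts={
        "AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
        "scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
    },
    learning_rate=0.01,
)
print(composite.get_dkeys_list())  # [['AE', 'infer'], ['scaler', 'infer']]
print(composite.metadata)          # {'learning_rate': 0.01}
```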
91 changes: 91 additions & 0 deletions tests/registry/_mlflow_utils.py
@@ -9,8 +9,12 @@
from mlflow.store.entities import PagedList
from sklearn.preprocessing import StandardScaler
from torch import tensor
from mlflow.models import Model

from numalogic.models.autoencoder.variants.vanilla import VanillaAE
from numalogic.models.threshold import StdDevThreshold
from numalogic.registry.mlflow_registry import CompositeModels
from numalogic.tools.types import KeyedArtifact


def create_model():
@@ -135,6 +139,71 @@ def mock_log_model_sklearn(*_, **__):
)


def mock_log_model_pyfunc(*_, **__):
return ModelInfo(
artifact_path="model",
flavors={
"pyfunc": {"model_data": "data", "pyfunc_version": "1.11.0", "code": None},
"python_function": {
"pickle_module_name": "mlflow.pyfunc.pickle_module",
"loader_module": "mlflow.pyfunc",
"python_version": "3.8.5",
"data": "data",
"env": "conda.yaml",
},
},
model_uri="runs:/a7c0b376530b40d7b23e6ce2081c899c/model",
model_uuid="a7c0b376530b40d7b23e6ce2081c899c",
run_id="a7c0b376530b40d7b23e6ce2081c899c",
saved_input_example_info=None,
signature_dict=None,
utc_time_created="2022-05-23 22:35:59.557372",
mlflow_version="2.0.1",
signature=None,
)


def mock_load_model_pyfunc(*_, **__):
artifact_path = "model"
flavors = {
"python_function": {
"cloudpickle_version": "3.0.0",
"code": None,
"env": {"conda": "conda.yaml", "virtualenv": "python_env.yaml"},
"loader_module": "mlflow.pyfunc.model",
"python_model": "python_model.pkl",
"python_version": "3.10.14",
"streamable": False,
}
}
model_size_bytes = 8912
model_uuid = "ae27ecc166c94c01a4f4dccaf84ca5dc"
run_id = "7e85a3fa46d44e668c840f3dddc909c3"
utc_time_created = "2024-09-18 17:12:41.501209"
model = Model(
artifact_path=artifact_path,
flavors=flavors,
model_size_bytes=model_size_bytes,
model_uuid=model_uuid,
run_id=run_id,
utc_time_created=utc_time_created,
mlflow_version="2.16.0",
)
return mlflow.pyfunc.PyFuncModel(
model_meta=model,
model_impl=TestObject(
python_model=CompositeModels(
skeys=["model"],
dict_artifacts={
"AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
"scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
},
**{"learning_rate": 0.01},
)
),
)


def mock_transition_stage(*_, **__):
return ModelVersion(
creation_timestamp=1653402941169,
@@ -303,6 +372,23 @@ def return_sklearn_rundata():
)


def return_pyfunc_rundata():
return Run(
run_info=RunInfo(
artifact_uri="mlflow-artifacts:/0/a7c0b376530b40d7b23e6ce2081c899c/artifacts/model",
end_time=None,
experiment_id="0",
lifecycle_stage="active",
run_id="a7c0b376530b40d7b23e6ce2081c899c",
run_uuid="a7c0b376530b40d7b23e6ce2081c899c",
start_time=1658788772612,
status="RUNNING",
user_id="lol",
),
run_data=RunData(metrics={}, tags={}, params={}),
)


def return_pytorch_rundata_dict():
return Run(
run_info=RunInfo(
@@ -318,3 +404,8 @@ def return_pytorch_rundata_dict():
),
run_data=RunData(metrics={}, tags={}, params=[mlflow.entities.Param("lr", "0.001")]),
)


class TestObject(mlflow.pyfunc.PythonModel):
def __init__(self, python_model):
self.python_model = python_model
56 changes: 53 additions & 3 deletions tests/registry/test_mlflow_registry.py
@@ -15,14 +15,18 @@


from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.types import KeyedArtifact
from tests.registry._mlflow_utils import (
mock_load_model_pyfunc,
mock_log_model_pyfunc,
model_sklearn,
create_model,
mock_log_model_pytorch,
mock_log_state_dict,
mock_get_model_version,
mock_transition_stage,
mock_log_model_sklearn,
return_pyfunc_rundata,
return_pytorch_rundata_dict,
return_empty_rundata,
mock_list_of_model_version,
@@ -56,22 +60,68 @@ def test_construct_key(self):
self.assertEqual("model_:nnet::error1", key)

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
def test_save_model(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys = self.dkeys
        status = ml.save(
            skeys=skeys,
            dkeys=dkeys,
            artifact=self.model,
            run_id="1234",
            artifact_type="pytorch",
            **{"lr": 0.01},
        )
mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
def test_save_multiple_models_pyfunc(self):
ml = MLflowRegistry(TRACKING_URI)
status = ml.save_multiple(
skeys=self.skeys,
dict_artifacts={
"AE": KeyedArtifact(dkeys=["AE", "infer"], artifact=VanillaAE(10)),
"scaler": KeyedArtifact(dkeys=["scaler", "infer"], artifact=StandardScaler()),
},
**{"learning_rate": 0.01},
)
self.assertIsNotNone(status)
mock_status = "READY"
self.assertEqual(mock_status, status.status)

@patch("mlflow.pyfunc.log_model", mock_log_model_pyfunc)
@patch("mlflow.log_params", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pyfunc_rundata())))
@patch("mlflow.active_run", Mock(return_value=return_pyfunc_rundata()))
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.tracking.MlflowClient.search_model_versions", mock_list_of_model_version2)
@patch("mlflow.pyfunc.load_model", mock_load_model_pyfunc)
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pyfunc_rundata()))
def test_load_multiple_models_when_pyfunc_model_exist(self):
ml = MLflowRegistry(TRACKING_URI)
skeys = self.skeys
dkeys_list = [["AE", "infer"], ["scaler", "infer"]]
data = ml.load_multiple(skeys=skeys, dkeys_list=dkeys_list)
self.assertIsNotNone(data["AE:infer"].metadata)
self.assertIsNotNone(data["scaler:infer"].metadata)
self.assertIsInstance(data, dict)
self.assertIsInstance(data["AE:infer"].artifact, VanillaAE)
self.assertIsInstance(data["scaler:infer"].artifact, StandardScaler)

@patch("mlflow.sklearn.log_model", mock_log_model_sklearn)
@patch("mlflow.log_param", mock_log_state_dict)
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_sklearn_rundata())))
