Skip to content

Commit

Permalink
data set logging
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-meyer-1986 committed Nov 22, 2024
1 parent c07f301 commit e0fc867
Showing 1 changed file with 39 additions and 13 deletions.
52 changes: 39 additions & 13 deletions mlflow/getml/autologging.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import json
import logging
import threading
from dataclasses import dataclass, field
from typing import Any

import mlflow
from mlflow.data.pandas_dataset import PandasDataset
from mlflow.entities.dataset_input import DatasetInput
from mlflow.entities.input_tag import InputTag
from mlflow.utils.autologging_utils import safe_patch
from mlflow.utils.autologging_utils.client import MlflowAutologgingQueueingClient
from mlflow.utils.mlflow_tags import MLFLOW_DATASET_CONTEXT

_logger = logging.getLogger(__name__)

@dataclass
class LogInfo:
Expand Down Expand Up @@ -152,20 +158,8 @@ def patched_fit_mlflow(original, self: getml.Pipeline, *args, **kwargs):
autologging_client = MlflowAutologgingQueueingClient()
assert (active_run := mlflow.active_run())
run_id = active_run.info.run_id
pipeline_log_info = _extract_pipeline_informations(self)
# with open("my_dict.json", "w") as f:
# json.dump(pipeline_log_info.params, f)
# mlflow.log_artifact("my_dict.json")
# mlflow.log_dict(pipeline_log_info.params, 'params.json')
autologging_client.log_params(
run_id=run_id,
params=pipeline_log_info.params,
)
if tags := pipeline_log_info.tags:
autologging_client.set_tags(run_id=run_id, tags=tags)

engine_metrics_to_be_tracked = _collect_available_engine_metrics()

engine_metrics_to_be_tracked = _log_pretraining_metadata(autologging_client, self, run_id, *args)
if engine_metrics_to_be_tracked:
stop_event = threading.Event()
metrics_thread = threading.Thread(
Expand Down Expand Up @@ -210,7 +204,39 @@ def patched_score_method(original, self: getml.Pipeline, *args, **kwargs):
)

return original(self, *args, **kwargs)
def _log_pretraining_metadata(autologging_client, self: getml.Pipeline, run_id, *args):

pipeline_log_info = _extract_pipeline_informations(self)
autologging_client.log_params(
run_id=run_id,
params=pipeline_log_info.params,
)
if tags := pipeline_log_info.tags:
autologging_client.set_tags(run_id=run_id, tags=tags)

engine_metrics_to_be_tracked = _collect_available_engine_metrics()

if log_datasets:
try:
datasets = []
population_dataset: PandasDataset = mlflow.data.from_pandas(args[0].population.to_pandas(), name = args[0].population.name) #args[0].population.name returns the wrong name
tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value='Population')]
datasets.append(DatasetInput(dataset=population_dataset._to_mlflow_entity(), tags=tags))

for name, peripheral in args[0].peripheral.items():
tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value='Peripheral')]
peripheral_dataset: PandasDataset = mlflow.data.from_pandas(peripheral.to_pandas(), name = name)
datasets.append(DatasetInput(dataset=peripheral_dataset._to_mlflow_entity(), tags=tags))

autologging_client.log_inputs(
run_id=run_id, datasets=datasets
)

except Exception as e:
_logger.warning(
"Failed to log training dataset information to MLflow Tracking. Reason: %s", e
)
return engine_metrics_to_be_tracked

_patch_pipeline_method(
flavor_name=flavor_name,
Expand Down

0 comments on commit e0fc867

Please sign in to comment.