Skip to content

Commit

Permalink
Issue #191: added OSS implementation of nested runs
Browse files Browse the repository at this point in the history
  • Loading branch information
amesar committed Jul 29, 2024
1 parent da3a4ea commit a572e94
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
9 changes: 3 additions & 6 deletions mlflow_export_import/experiment/nested_runs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@ def get_nested_runs(client, runs):
Return set of run_ids and their nested run descendants from list of run IDs.
"""
if utils.calling_databricks():
return get_by_rootRunId(client, runs)
return get_nested_runs_by_rootRunId(client, runs)
else:
#_logger.warning(f"OSS MLflow nested run export not yet supported")
#return runs
from . import oss_nested_runs_utils
descendant_runs = oss_nested_runs_utils.get_descendant_runs(client, runs)
return runs + descendant_runs
return runs + oss_nested_runs_utils.get_nested_runs(client, runs)


def get_by_rootRunId(client, runs):
def get_nested_runs_by_rootRunId(client, runs):
"""
Return list of nested run descendants (includes the root run).
Unlike Databricks MLflow, OSS MLflow does not add the 'mlflow.rootRunId' tag to child runs.
Expand Down
42 changes: 42 additions & 0 deletions mlflow_export_import/experiment/oss_nested_runs_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from mlflow_export_import.common.iterators import SearchRunsIterator


def get_nested_runs(client, runs, parent_runs=None):
    """
    Return a flat list with the nested (descendant) runs of every run in 'runs'.

    :param client: MLflow client passed through to the per-run lookup.
    :param runs: Iterable of run objects whose descendants are wanted.
    :param parent_runs: Optional pre-fetched runs forwarded unchanged to the
        per-run lookup; when omitted that lookup fetches them itself.
    :return: List of descendant runs, concatenated in the order of 'runs'.
        The input runs themselves are not added by this function.
    """
    return [
        descendant
        for run in runs
        for descendant in _get_nested_runs_for_run(client, run, parent_runs)
    ]

def get_nested_runs_for_experiment(client, experiment_id):
    """
    Return the runs of an experiment that have an 'mlflow.parentRunId' tag.

    :param client: MLflow client handed to SearchRunsIterator.
    :param experiment_id: ID of the experiment to search.
    :return: List of the runs yielded by the search.
    """
    # "like '%'" accepts any tag value, so the filter selects on the tag
    # being present rather than on a particular parent.
    has_parent_filter = "tags.mlflow.parentRunId like '%'"
    matching_runs = SearchRunsIterator(client, experiment_id, filter=has_parent_filter)
    return list(matching_runs)


def _get_nested_runs_for_run(client, run, parent_runs=None):
    """
    Return the descendant runs of a single run as full run objects.

    :param client: MLflow client used to fetch each descendant via get_run().
    :param run: Run whose descendants are wanted; its experiment ID and run ID are read.
    :param parent_runs: Optional pre-fetched runs forwarded to _build_nested_runs.
    :return: List with one fetched run per descendant run ID.
    """
    children_by_parent = _build_nested_runs(client, run.info.experiment_id, parent_runs)
    descendants = []
    for descendant_id in _get_run_ids(run.info.run_id, children_by_parent):
        descendants.append(client.get_run(descendant_id))
    return descendants

def _get_run_ids(root_id, nested_runs):
    """
    Return the IDs of all descendants (children, grandchildren, ...) of a run.

    :param root_id: Run ID whose descendants are wanted.
    :param nested_runs: Dict mapping a parent run ID to the list of its direct
        child run IDs. It is only read, never modified.
    :return: Set of descendant run IDs; empty if the run has no children.
        'root_id' itself is never part of the result.
    """
    descendants = set()
    # Work on a copy so the caller's child lists are never extended in place.
    pending = list(nested_runs.get(root_id) or ())
    while pending:
        run_id = pending.pop()
        if run_id in descendants:
            # Already expanded: avoids repeated work and guards against cycles.
            continue
        descendants.add(run_id)
        pending.extend(nested_runs.get(run_id) or ())
    # Only reachable through a cyclic mapping; a run is not its own descendant.
    descendants.discard(root_id)
    return descendants

def _build_nested_runs(client, experiment_id, parent_runs=None):
    """
    Build a flat dict mapping each parent run ID to the IDs of its direct child runs.

    The parent of a run is taken from its 'mlflow.parentRunId' tag.

    :param client: MLflow client, used only when 'parent_runs' is not supplied.
    :param experiment_id: Experiment to search when 'parent_runs' is not supplied.
    :param parent_runs: Optional pre-fetched runs to index. Despite the name these
        are the runs that carry a parent tag (i.e. the child runs). If empty or
        None, they are fetched with get_nested_runs_for_experiment().
    :return: Dict of parent run ID -> list of child run IDs.
    """
    if not parent_runs:
        parent_runs = get_nested_runs_for_experiment(client, experiment_id)
    # Keyed by run ID, so a run that is listed twice is only indexed once.
    parent_by_run_id = {
        run.info.run_id: run.data.tags.get("mlflow.parentRunId")
        for run in parent_runs
    }
    nested_runs = {}
    for run_id, parent_id in parent_by_run_id.items():
        if parent_id is None:
            # A caller-supplied run without the parent tag is a top-level run:
            # it is nobody's child, so skip it instead of raising KeyError.
            continue
        nested_runs.setdefault(parent_id, []).append(run_id)
    return nested_runs

0 comments on commit a572e94

Please sign in to comment.