Skip to content

Commit

Permalink
Issue #190: test_experiments.py: added tests for specific run_ids list
Browse files Browse the repository at this point in the history
  • Loading branch information
amesar committed Jul 12, 2024
1 parent afb027d commit f134148
Showing 1 changed file with 87 additions and 4 deletions.
91 changes: 87 additions & 4 deletions tests/open_source/test_experiments.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import mlflow
from mlflow.entities import ViewType
from mlflow_export_import.experiment.export_experiment import export_experiment
from mlflow_export_import.experiment.import_experiment import import_experiment
from tests.open_source.oss_utils_test import create_simple_run, init_output_dirs, mk_dst_experiment_name
from tests.open_source.oss_utils_test import _create_simple_run
from tests.open_source.oss_utils_test import create_test_experiment
from tests.open_source.oss_utils_test import (
create_simple_run, _create_simple_run,
create_experiment, create_test_experiment,
mk_dst_experiment_name,
init_output_dirs)
from tests.compare_utils import compare_runs, compare_experiment_tags
from tests.open_source.init_tests import mlflow_context


# == Setup

def _init_exp_test(mlflow_context, import_source_tags=False):
Expand Down Expand Up @@ -56,6 +60,7 @@ def test_exp_with_source_tags(mlflow_context):
_compare_experiments(exp1, exp2, True)
compare_runs(mlflow_context, run1, run2, import_source_tags=True)


# == Test export/import deleted runs

def test_export_deleted_runs(mlflow_context):
Expand Down Expand Up @@ -92,7 +97,7 @@ def test_export_deleted_runs(mlflow_context):
assert len(runs2) == 3


# == Test start_date filter
# == Test start_date filter

def test_filter_run_no_start_date(mlflow_context):
_run_test_run_start_date(mlflow_context, 0)
Expand Down Expand Up @@ -157,3 +162,81 @@ def _fmt_utc_time_before(seconds_before):
def _fmt_utc_time_now():
from datetime import timezone
return datetime.now(timezone.utc).strftime(TS_FORMAT)


# == Test export of multiple runs

def test_exp_with_multiple_runs(mlflow_context):
client1, client2 = mlflow_context.client_src, mlflow_context.client_dst
init_output_dirs(mlflow_context.output_dir)
exp1 = create_experiment(client1)
mlflow.set_experiment(exp1.name)

num_runs = 4
for j in range(num_runs):
run = _create_simple_run(client1, run_name=f"run_{j}")
runs1 = client1.search_runs(exp1.experiment_id)
assert len(runs1) == num_runs

runs1 = [ runs1[0], runs1[2] ]
run_ids = [ run.info.run_id for run in runs1 ]

export_experiment(
mlflow_client = client1,
experiment_id_or_name = exp1.name,
run_ids = run_ids,
output_dir = mlflow_context.output_dir
)

exp_name2 = mk_dst_experiment_name(exp1.name)
exp2 = import_experiment(
mlflow_client = client2,
experiment_name = exp_name2,
input_dir = mlflow_context.output_dir
)
exp2 = client2.get_experiment_by_name(exp_name2)
runs2 = client2.search_runs(exp2.experiment_id)
assert len(runs2) == len(run_ids)

runs1 = sorted(runs1, key=lambda run: run.info.run_name)
runs2 = sorted(runs2, key=lambda run: run.info.run_name)

run_names1 = [run.info.run_name for run in runs1]
run_names2 = [run.info.run_name for run in runs2]
assert len(run_names1) == len(run_names2)

for run1,run2 in zip(runs1,runs2):
compare_runs(mlflow_context, run1, run2)

def test_exp_with_multiple_runs_nonexistent_run(mlflow_context):
client1, client2 = mlflow_context.client_src, mlflow_context.client_dst
init_output_dirs(mlflow_context.output_dir)
exp1 = create_experiment(client1)
mlflow.set_experiment(exp1.name)

num_runs = 4
for j in range(num_runs):
_create_simple_run(client1, run_name=f"run_{j}")
runs1 = client1.search_runs(exp1.experiment_id)
assert len(runs1) == num_runs

run1_ok = runs1[1]
run_ids = [ "foo", run1_ok ]

export_experiment(
mlflow_client = client1,
experiment_id_or_name = exp1.name,
run_ids = run_ids,
output_dir = mlflow_context.output_dir
)

exp_name2 = mk_dst_experiment_name(exp1.name)
exp2 = import_experiment(
mlflow_client = client2,
experiment_name = exp_name2,
input_dir = mlflow_context.output_dir
)
exp2 = client2.get_experiment_by_name(exp_name2)
runs2 = client2.search_runs(exp2.experiment_id)
assert len(runs2) == 1
compare_runs(mlflow_context, run1_ok, runs2[0])

0 comments on commit f134148

Please sign in to comment.