diff --git a/e2e_tests/tests/cluster/test_log_policies.py b/e2e_tests/tests/cluster/test_log_policies.py index 8a7aa8da7fe..a500e4a341d 100644 --- a/e2e_tests/tests/cluster/test_log_policies.py +++ b/e2e_tests/tests/cluster/test_log_policies.py @@ -2,12 +2,10 @@ from determined.common.api import bindings from determined.experimental import client -from tests import api_utils -from tests import experiment as exp -from tests.experiment import noop -from tests import detproc +from tests import api_utils, detproc from tests import experiment as exp from tests.cluster import utils +from tests.experiment import noop @pytest.mark.e2e_cpu @@ -50,7 +48,7 @@ def test_log_policy_exclude_node_k8s(should_match: bool) -> None: assert agents[0].slots is not None config = { - "log_policies": [{"pattern": regex, "action": {"type": "exclude_node"}}], + "log_policies": [{"pattern": regex, "actions": [{"type": "exclude_node"}]}], "resources": {"slots_per_trial": len(agents[0].slots)}, "max_restarts": 1, } @@ -93,7 +91,7 @@ def test_log_policy_exclude_node_single_agent(should_match: bool) -> None: assert agents[0].slots is not None config = { - "log_policies": [{"pattern": regex, "action": {"type": "exclude_node"}}], + "log_policies": [{"pattern": regex, "actions": [{"type": "exclude_node"}]}], "resources": {"slots_per_trial": len(agents[0].slots)}, "max_restarts": 1, } @@ -136,7 +134,7 @@ def test_log_policy_exclude_slurm(should_match: bool) -> None: regex = r"(.*) this should not match (.*)" config = { - "log_policies": [{"pattern": regex, "action": {"type": "exclude_node"}}], + "log_policies": [{"pattern": regex, "actions": [{"type": "exclude_node"}]}], "max_restarts": 1, } exp_ref = noop.create_experiment(sess, [noop.Exit(7)], config=config) @@ -173,7 +171,6 @@ def test_log_signal(should_match: bool) -> None: exp_ref = noop.create_experiment(sess, [noop.Exit(7)], config=config) assert exp_ref.wait(interval=0.01) == client.ExperimentState.ERROR - searchRes = utils.get_run_by_exp_id(sess, exp_ref.id) runSignal = searchRes.runs[0].logSignal @@ -212,8 +209,17 @@ def test_signal_clear_after_exp_continue() -> None: assert runSignal == expected_signal assert trialSignal == expected_signal - - detproc.check_call(sess, ["det", "e", "continue", str(exp_ref.id), "--config", "hyperparameters.crash_on_startup=false"]) + detproc.check_call( + sess, + [ + "det", + "e", + "continue", + str(exp_ref.id), + "--config", + "hyperparameters.crash_on_startup=false", + ], + ) exp.wait_for_experiment_state(sess, exp_ref.id, bindings.experimentv1State.COMPLETED) searchRes = utils.get_run_by_exp_id(sess, exp_ref.id) diff --git a/e2e_tests/tests/cluster/utils.py b/e2e_tests/tests/cluster/utils.py index f71e1cb290e..9bef663f01f 100644 --- a/e2e_tests/tests/cluster/utils.py +++ b/e2e_tests/tests/cluster/utils.py @@ -200,7 +200,7 @@ def set_master_port(config: str) -> None: conf.MASTER_PORT = port -def get_run_by_exp_id(sess, exp_id) -> int: +def get_run_by_exp_id(sess: api.Session, exp_id: int) -> bindings.v1SearchRunsResponse: return bindings.post_SearchRuns( sess, body=bindings.v1SearchRunsRequest(