From 95b3681d3e602a8fc3855598579f11d118d75a86 Mon Sep 17 00:00:00 2001 From: Taekyung Heo Date: Thu, 7 Nov 2024 12:09:54 -0500 Subject: [PATCH] Test hook support (#263) --- conf/common/test_scenario/nccl_test.toml | 4 + conf/hook/nccl_test.toml | 22 +++ conf/hook/test/nccl_test_all_gather.toml | 33 ++++ src/cloudai/_core/command_gen_strategy.py | 26 ++++ src/cloudai/_core/test_scenario.py | 2 + src/cloudai/_core/test_scenario_parser.py | 42 ++++- src/cloudai/_core/test_template.py | 34 ++++ src/cloudai/cli/handlers.py | 43 +++++- src/cloudai/parser.py | 78 ++++++++-- .../jax_toolbox/slurm_command_gen_strategy.py | 95 +----------- .../nccl_test/slurm_command_gen_strategy.py | 4 + .../strategy/slurm_command_gen_strategy.py | 146 ++++++++++++++---- src/cloudai/test_definitions/gpt.py | 5 +- src/cloudai/test_definitions/grok.py | 5 +- src/cloudai/test_definitions/jax_toolbox.py | 34 +--- ...t-no-pretest.sbatch => gpt-no-hook.sbatch} | 4 +- ...gpt-pretest.sbatch => gpt-pre-test.sbatch} | 40 +---- ...-no-pretest.sbatch => grok-no-hook.sbatch} | 4 +- ...ok-pretest.sbatch => grok-pre-test.sbatch} | 40 +---- tests/ref_data/nccl.sbatch | 21 +-- tests/ref_data/sleep.sbatch | 4 +- tests/ref_data/ucc.sbatch | 10 +- .../test_common_slurm_command_gen_strategy.py | 132 +++++++++++++++- ..._jax_toolbox_slurm_command_gen_strategy.py | 105 +------------ tests/test_acceptance.py | 47 ++++-- tests/test_parser.py | 77 ++++++++- tests/test_test_scenario.py | 2 +- 27 files changed, 653 insertions(+), 406 deletions(-) create mode 100644 conf/hook/nccl_test.toml create mode 100644 conf/hook/test/nccl_test_all_gather.toml rename tests/ref_data/{gpt-no-pretest.sbatch => gpt-no-hook.sbatch} (91%) rename tests/ref_data/{gpt-pretest.sbatch => gpt-pre-test.sbatch} (52%) rename tests/ref_data/{grok-no-pretest.sbatch => grok-no-hook.sbatch} (95%) rename tests/ref_data/{grok-pretest.sbatch => grok-pre-test.sbatch} (67%) diff --git a/conf/common/test_scenario/nccl_test.toml b/conf/common/test_scenario/nccl_test.toml index f6ccf02c..15064561 100644 --- a/conf/common/test_scenario/nccl_test.toml +++ b/conf/common/test_scenario/nccl_test.toml @@ -15,6 +15,10 @@ # limitations under the License. name = "nccl-test" + +pre_test = "nccl_test" +post_test = "nccl_test" + [[Tests]] id = "Tests.1" test_name = "nccl_test_all_reduce" diff --git a/conf/hook/nccl_test.toml b/conf/hook/nccl_test.toml new file mode 100644 index 00000000..53349c43 --- /dev/null +++ b/conf/hook/nccl_test.toml @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name = "nccl_test" + +[[Tests]] +id = "Tests.1" +test_name = "nccl_test_all_gather" +time_limit = "00:20:00" diff --git a/conf/hook/test/nccl_test_all_gather.toml b/conf/hook/test/nccl_test_all_gather.toml new file mode 100644 index 00000000..4fec288a --- /dev/null +++ b/conf/hook/test/nccl_test_all_gather.toml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name = "nccl_test_all_gather" +description = "all_gather" +test_template_name = "NcclTest" + +[cmd_args] +"subtest_name" = "all_gather_perf_mpi" +"ngpus" = "1" +"minbytes" = "128" +"maxbytes" = "4G" +"iters" = "100" +"warmup_iters" = "50" + +[extra_cmd_args] +"--stepfactor" = "2" + +[extra_env_vars] +"NCCL_TEST_SPLIT_MASK" = "0x7" diff --git a/src/cloudai/_core/command_gen_strategy.py b/src/cloudai/_core/command_gen_strategy.py index 16bd04f9..9c8bb389 100644 --- a/src/cloudai/_core/command_gen_strategy.py +++ b/src/cloudai/_core/command_gen_strategy.py @@ -39,3 +39,29 @@ def gen_exec_command(self, tr: TestRun) -> str: str: The generated execution command. """ pass + + @abstractmethod + def gen_srun_command(self, tr: TestRun) -> str: + """ + Generate the Slurm srun command for a test based on the given parameters. + + Args: + tr (TestRun): Contains the test and its run-specific configurations. + + Returns: + str: The generated Slurm srun command. + """ + pass + + @abstractmethod + def gen_srun_success_check(self, tr: TestRun) -> str: + """ + Generate the Slurm success check command to verify if a test run was successful. + + Args: + tr (TestRun): Contains the test and its run-specific configurations. + + Returns: + str: The generated command to check the success of the test run. + """ + pass diff --git a/src/cloudai/_core/test_scenario.py b/src/cloudai/_core/test_scenario.py index 3a60c036..39f1bd21 100644 --- a/src/cloudai/_core/test_scenario.py +++ b/src/cloudai/_core/test_scenario.py @@ -58,6 +58,8 @@ class TestRun: weight: float = 0.0 ideal_perf: float = 1.0 dependencies: dict[str, TestDependency] = field(default_factory=dict) + pre_test: Optional["TestScenario"] = None + post_test: Optional["TestScenario"] = None def __hash__(self) -> int: return hash(self.name + self.test.name + str(self.iterations) + str(self.current_iteration)) diff --git a/src/cloudai/_core/test_scenario_parser.py b/src/cloudai/_core/test_scenario_parser.py index 0f0e0bae..d111d002 100644 --- a/src/cloudai/_core/test_scenario_parser.py +++ b/src/cloudai/_core/test_scenario_parser.py @@ -54,6 +54,8 @@ class _TestScenarioTOML(BaseModel): name: str job_status_check: bool = True tests: list[_TestRunTOML] = Field(alias="Tests", min_length=1) + pre_test: Optional[str] = None + post_test: Optional[str] = None @model_validator(mode="after") def check_no_self_dependency(self): @@ -99,9 +101,10 @@ class TestScenarioParser: __test__ = False - def __init__(self, file_path: Path, test_mapping: Dict[str, Test]) -> None: + def __init__(self, file_path: Path, test_mapping: Dict[str, Test], hook_mapping: Dict[str, TestScenario]) -> None: self.file_path = file_path self.test_mapping = test_mapping + self.hook_mapping = hook_mapping def parse(self) -> TestScenario: """ @@ -136,8 +139,31 @@ def _parse_data(self, data: Dict[str, Any]) -> TestScenario: total_weight = sum(tr.weight for tr in ts_model.tests) normalized_weight = 0 if total_weight == 0 else 100 / total_weight + pre_test, post_test = None, None + if ts_model.pre_test: + pre_test = self.hook_mapping.get(ts_model.pre_test) + if pre_test is None: + msg = ( + f"Pre-test hook '{ts_model.pre_test}' not found in hook mapping. " + "A corresponding hook should exist under 'conf/hook'. " + "Ensure that a proper hook directory is set under the working directory." + ) + logging.error(msg) + raise TestScenarioParsingError(msg) + + if ts_model.post_test: + post_test = self.hook_mapping.get(ts_model.post_test) + if post_test is None: + msg = ( + f"Post-test hook '{ts_model.post_test}' not found in hook mapping. " + "A corresponding hook should exist under 'conf/hook'. " + "Ensure that a proper hook directory is set under the working directory." + ) + logging.error(msg) + raise TestScenarioParsingError(msg) + test_runs_by_id: dict[str, TestRun] = { - tr.id: self._create_test_run(tr, normalized_weight) for tr in ts_model.tests + tr.id: self._create_test_run(tr, normalized_weight, pre_test, post_test) for tr in ts_model.tests } tests_data: dict[str, _TestRunTOML] = {tr.id: tr for tr in ts_model.tests} @@ -153,13 +179,21 @@ def _parse_data(self, data: Dict[str, Any]) -> TestScenario: job_status_check=ts_model.job_status_check, ) - def _create_test_run(self, test_info: _TestRunTOML, normalized_weight: float) -> TestRun: + def _create_test_run( + self, + test_info: _TestRunTOML, + normalized_weight: float, + pre_test: Optional[TestScenario] = None, + post_test: Optional[TestScenario] = None, + ) -> TestRun: """ Create a section-specific Test object by copying from the test mapping. Args: test_info (Dict[str, Any]): Information of the test. normalized_weight (float): Normalized weight for the test. + pre_test (Optional[TestScenario]): TestScenario object representing the pre-test sequence. + post_test (Optional[TestScenario]): TestScenario object representing the post-test sequence. Returns: Test: Copied and updated Test object for the section. @@ -192,5 +226,7 @@ def _create_test_run(self, test_info: _TestRunTOML, normalized_weight: float) -> sol=test_info.sol, weight=test_info.weight * normalized_weight, ideal_perf=test_info.ideal_perf, + pre_test=pre_test, + post_test=post_test, ) return tr diff --git a/src/cloudai/_core/test_template.py b/src/cloudai/_core/test_template.py index e1f98e9e..d45fb544 100644 --- a/src/cloudai/_core/test_template.py +++ b/src/cloudai/_core/test_template.py @@ -93,6 +93,40 @@ def gen_exec_command(self, tr: TestRun) -> str: ) return self.command_gen_strategy.gen_exec_command(tr) + def gen_srun_command(self, tr: TestRun) -> str: + """ + Generate an Slurm srun command for a test using the provided command generation strategy. + + Args: + tr (TestRun): Contains the test and its run-specific configurations. + + Returns: + str: The generated Slurm srun command. + """ + if self.command_gen_strategy is None: + raise ValueError( + "command_gen_strategy is missing. Ensure the strategy is registered in the Registry " + "by calling the appropriate registration function for the system type." + ) + return self.command_gen_strategy.gen_srun_command(tr) + + def gen_srun_success_check(self, tr: TestRun) -> str: + """ + Generate a Slurm success check command for a test using the provided command generation strategy. + + Args: + tr (TestRun): Contains the test and its run-specific configurations. + + Returns: + str: The generated command to check the success of the test run. + """ + if self.command_gen_strategy is None: + raise ValueError( + "command_gen_strategy is missing. Ensure the strategy is registered in the Registry " + "by calling the appropriate registration function for the system type." + ) + return self.command_gen_strategy.gen_srun_success_check(tr) + def gen_json(self, tr: TestRun) -> Dict[Any, Any]: """ Generate a JSON string representing the Kubernetes job specification for this test using this template. diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index 6105bc24..30fb7a90 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -23,6 +23,8 @@ from cloudai import Installable, Parser, Registry, ReportGenerator, Runner, System +from ..parser import HOOK_ROOT + def handle_install_and_uninstall(args: argparse.Namespace) -> int: """ @@ -212,7 +214,11 @@ def verify_test_configs(test_tomls: List[Path]) -> int: def verify_test_scenarios( - scenario_tomls: List[Path], test_tomls: list[Path], system_config: Optional[Path] = None + scenario_tomls: List[Path], + test_tomls: list[Path], + hook_tomls: List[Path], + hook_test_tomls: list[Path], + system_config: Optional[Path] = None, ) -> int: system = Mock(spec=System) if system_config: @@ -225,7 +231,9 @@ def verify_test_scenarios( logging.debug(f"Verifying Test Scenario: {scenario_file}...") try: tests = Parser.parse_tests(test_tomls, system) - Parser.parse_test_scenario(scenario_file, {t.name: t for t in tests}) + hook_tests = Parser.parse_tests(hook_test_tomls, system) + hooks = Parser.parse_hooks(hook_tomls, {t.name: t for t in hook_tests}) + Parser.parse_test_scenario(scenario_file, {t.name: t for t in tests}, hooks) except Exception: nfailed += 1 @@ -243,6 +251,9 @@ def handle_verify_all_configs(args: argparse.Namespace) -> int: if err: return err + err, hook_tomls = expand_file_list(HOOK_ROOT, glob="**/*.toml") + tomls += hook_tomls + files = load_tomls_by_type(tomls) test_tomls = files["test"] @@ -259,7 +270,9 @@ def handle_verify_all_configs(args: argparse.Namespace) -> int: if files["test"]: nfailed += verify_test_configs(files["test"]) if files["scenario"]: - nfailed += verify_test_scenarios(files["scenario"], test_tomls, args.system_config) + nfailed += verify_test_scenarios( + files["scenario"], test_tomls, files["hook"], files["hook_test"], args.system_config + ) if files["unknown"]: logging.error(f"Unknown configuration files: {[str(f) for f in files['unknown']]}") nfailed += len(files["unknown"]) @@ -273,9 +286,31 @@ def handle_verify_all_configs(args: argparse.Namespace) -> int: def load_tomls_by_type(tomls: List[Path]) -> dict[str, List[Path]]: - files: dict[str, List[Path]] = {"system": [], "test": [], "scenario": [], "unknown": []} + files: dict[str, List[Path]] = { + "system": [], + "test": [], + "scenario": [], + "hook_test": [], + "hook": [], + "unknown": [], + } for toml_file in tomls: content = toml_file.read_text() + + is_in_hook_root = False + try: + toml_file.relative_to(HOOK_ROOT) + is_in_hook_root = True + except ValueError: + pass + + if is_in_hook_root: + if "test" in toml_file.parts: + files["hook_test"].append(toml_file) + else: + files["hook"].append(toml_file) + continue + if "scheduler =" in content: files["system"].append(toml_file) elif "test_template_name =" in content: diff --git a/src/cloudai/parser.py b/src/cloudai/parser.py index a627b312..6f59f9a3 100644 --- a/src/cloudai/parser.py +++ b/src/cloudai/parser.py @@ -34,6 +34,9 @@ format_validation_error, ) +HOOK_ROOT = Path("conf/hook") +HOOK_TEST_ROOT = HOOK_ROOT / "test" + class Parser: """Main parser for parsing all types of configurations.""" @@ -49,14 +52,21 @@ def __init__(self, system_config_path: Path) -> None: self.system_config_path = system_config_path def parse( - self, test_path: Path, test_scenario_path: Optional[Path] = None + self, + test_path: Path, + test_scenario_path: Optional[Path] = None, ) -> Tuple[System, List[Test], Optional[TestScenario]]: """ Parse configurations for system, test templates, and test scenarios. - Returns - Tuple[System, List[TestTemplate], TestScenario]: A tuple containing the system object, a list of test - template objects, and the test scenario object. + Args: + test_path (Path): The file path for tests. + test_scenario_path (Optional[Path]): The file path for the main test scenario. + If None, all tests are included. + + Returns: + Tuple[System, List[Test], Optional[TestScenario]]: A tuple containing the system object, a list of filtered + test template objects, and the main test scenario object if provided. """ if not test_path.exists(): raise FileNotFoundError(f"Test path '{test_path}' not found.") @@ -71,24 +81,64 @@ def parse( except TestConfigParsingError: exit(1) # exit right away to keep error message readable for users - logging.debug(f"Parsed {len(tests)} tests: {[t.name for t in tests]}") - test_mapping = {t.name: t for t in tests} + if not HOOK_ROOT.exists(): + logging.debug(f"HOOK_ROOT path '{HOOK_ROOT}' does not exist.") + + try: + hook_tests = ( + self.parse_tests(list(HOOK_TEST_ROOT.glob("*.toml")), system) if HOOK_TEST_ROOT.exists() else [] + ) + except TestConfigParsingError: + exit(1) # exit right away to keep error message readable for users + + if not test_scenario_path: + all_tests = list({test.name: test for test in tests + hook_tests}.values()) + return system, all_tests, None - filtered_tests = tests - test_scenario: Optional[TestScenario] = None - if test_scenario_path: + test_mapping = {t.name: t for t in tests} + hook_test_scenario_mapping = {} + if HOOK_ROOT.exists() and list(HOOK_ROOT.glob("*.toml")): try: - test_scenario = self.parse_test_scenario(test_scenario_path, test_mapping) + hook_test_scenario_mapping = self.parse_hooks( + list(HOOK_ROOT.glob("*.toml")), {t.name: t for t in hook_tests} + ) except TestScenarioParsingError: exit(1) # exit right away to keep error message readable for users - scenario_tests = set(tr.test.name for tr in test_scenario.test_runs) - filtered_tests = [t for t in tests if t.name in scenario_tests] + + try: + test_scenario = self.parse_test_scenario(test_scenario_path, test_mapping, hook_test_scenario_mapping) + except TestScenarioParsingError: + exit(1) # exit right away to keep error message readable for users + + scenario_tests = {tr.test.name for tr in test_scenario.test_runs} + hook_scenario_tests = { + tr.test.name for hook_scenario in hook_test_scenario_mapping.values() for tr in hook_scenario.test_runs + } + + relevant_test_names = scenario_tests.union(hook_scenario_tests) + filtered_tests = [t for t in tests if t.name in relevant_test_names] + hook_tests + filtered_tests = list({test.name: test for test in filtered_tests}.values()) return system, filtered_tests, test_scenario @staticmethod - def parse_test_scenario(test_scenario_path: Path, test_mapping: Dict[str, Test]) -> TestScenario: - test_scenario_parser = TestScenarioParser(test_scenario_path, test_mapping) + def parse_hooks(hook_tomls: List[Path], test_mapping: Dict[str, Test]) -> Dict[str, TestScenario]: + hook_mapping = {} + for hook_test_scenario_path in hook_tomls: + hook_scenario = Parser.parse_test_scenario(hook_test_scenario_path, test_mapping) + hook_mapping[hook_scenario.name] = hook_scenario + return hook_mapping + + @staticmethod + def parse_test_scenario( + test_scenario_path: Path, + test_mapping: Dict[str, Test], + hook_mapping: Optional[Dict[str, TestScenario]] = None, + ) -> TestScenario: + if hook_mapping is None: + hook_mapping = {} + + test_scenario_parser = TestScenarioParser(test_scenario_path, test_mapping, hook_mapping) test_scenario = test_scenario_parser.parse() return test_scenario diff --git a/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py index 9e20aae5..fafb407d 100644 --- a/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py @@ -146,31 +146,17 @@ def _parse_slurm_args( return base_args - def generate_srun_command( + def _gen_srun_command( self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, Any], tr: TestRun ) -> str: self._create_run_script(env_vars, cmd_args, tr.test.extra_cmd_args) commands = [] - - run_pre_test = cmd_args.get("pre_test.enable", False) - - if run_pre_test: - output_path = Path(cmd_args["output_path"]).resolve() / "output_pretest-%j-%n-%t.txt" - error_path = Path(cmd_args["output_path"]).resolve() / "error_pretest-%j-%n-%t.txt" - commands.append(self._generate_pre_test_command(cmd_args, output_path, error_path)) - commands.append(self._generate_pre_test_check_command(cmd_args, output_path)) - commands.append('if [ "$PRE_TEST_SUCCESS" = true ]; then') - load_container = cmd_args.get("load_container", False) if load_container: commands += self._generate_container_load_command(slurm_args) - commands += self._generate_run_command(slurm_args) - if run_pre_test: - commands.append("fi") - return "\n".join(commands) def _create_run_script( @@ -341,85 +327,6 @@ def _create_pgo_nsys_converter_command(self, stage: str, cmd_args: Dict[str, str ["", 'if [ "$SLURM_NODEID" -eq 0 ] && [ "$SLURM_PROCID" -eq 0 ]; then', f" {command}", "fi"] ) - def _generate_pre_test_command(self, cmd_args: Dict[str, Any], output_path: Path, error_path: Path) -> str: - """ - Generate the pre-test command for running a test. - - This method constructs the pre-test command based on the command-line - arguments provided. - - Args: - cmd_args (Dict[str, Any]): A dictionary containing command arguments. - output_path (Path): The path to the output file. - error_path (Path): The path to the error file. - - Returns: - str: The generated pre-test command. - """ - nccl_test_prefix = "pre_test.nccl_test." - nccl_test = {} - - for key, value in cmd_args.items(): - if key.startswith(nccl_test_prefix): - flag_name = key[len(nccl_test_prefix) :] - nccl_test[flag_name] = value - pre_test_command_parts = [ - "srun", - "--mpi=pmix", - f"-N {nccl_test.get('num_nodes', 2)}", - f"-o {output_path}", - f"-e {error_path}", - f"--container-image={nccl_test.get('docker_image_url', 'nvcr.io/nvidia/pytorch:24.02-py3')}", - f"/usr/local/bin/{nccl_test.get('subtest_name', 'all_gather_perf_mpi')}", - f"--nthreads {nccl_test.get('nthreads', 1)}", - f"--ngpus {nccl_test.get('ngpus', 1)}", - f"--minbytes {nccl_test.get('minbytes', '32M')}", - f"--maxbytes {nccl_test.get('maxbytes', '16G')}", - f"--stepbytes {nccl_test.get('stepbytes', '1M')}", - f"--op {nccl_test.get('op', 'sum')}", - f"--datatype {nccl_test.get('datatype', 'float')}", - f"--root {nccl_test.get('root', 0)}", - f"--iters {nccl_test.get('iters', 20)}", - f"--warmup_iters {nccl_test.get('warmup_iters', 5)}", - f"--agg_iters {nccl_test.get('agg_iters', 1)}", - f"--average {nccl_test.get('average', 1)}", - f"--parallel_init {nccl_test.get('parallel_init', 0)}", - f"--check {nccl_test.get('check', 1)}", - f"--blocking {nccl_test.get('blocking', 0)}", - f"--cudagraph {nccl_test.get('cudagraph', 0)}", - f"--stepfactor {nccl_test.get('stepfactor', 2)}", - ] - return " \\\n".join(pre_test_command_parts) - - def _generate_pre_test_check_command(self, cmd_args: Dict[str, str], output_path: Path) -> str: - """ - Generate the command for pre-test check. - - This method generates the command that checks the output of the pre-test to determine if the main test should - be run. - - Args: - cmd_args (Dict[str, str]): Command-line arguments for the job. - output_path (str): The path to the output file. - - Returns: - str: The generated command for pre-test check. - """ - pretest_output_files = str(Path(output_path).parent / "output_pretest-*.txt") - keyword = cmd_args.get("keyword", "Avg bus bandwidth") - - return "\n".join( - [ - f'PRETEST_OUTPUT_FILES="{pretest_output_files}"', - f'keyword="{keyword}"', - "", - "# Use grep to search for the keyword in the files", - 'if grep -q "$keyword" $PRETEST_OUTPUT_FILES; then', - " PRE_TEST_SUCCESS=true", - "fi", - ] - ) - def _generate_container_load_command(self, slurm_args: Dict[str, Any]) -> List[str]: """Generate the command for loading a container.""" container_image = slurm_args.get("image_path") diff --git a/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py index d1d8d5fc..cc404bd7 100644 --- a/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py @@ -71,3 +71,7 @@ def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, st srun_command_parts.append(tr.test.extra_cmd_args) return srun_command_parts + + def gen_srun_success_check(self, tr: TestRun) -> str: + output_file = Path(tr.output_path) / "stdout.txt" + return f'grep -q "Avg bus bandwidth" {output_file} && echo 1 || echo 0' diff --git a/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py b/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py index 7c5f69c3..ee8a463a 100644 --- a/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py +++ b/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py @@ -18,7 +18,7 @@ from pathlib import Path from typing import Any, Dict, List -from cloudai import CommandGenStrategy, TestRun +from cloudai import CommandGenStrategy, TestRun, TestScenario from cloudai.systems import SlurmSystem @@ -51,22 +51,37 @@ def __init__(self, system: SlurmSystem, cmd_args: Dict[str, Any]) -> None: self.docker_image_url = self.cmd_args.get("docker_image_url", "") - def _format_env_vars(self, env_vars: Dict[str, Any]) -> str: - """ - Format environment variables for inclusion in a batch script. + def gen_exec_command(self, tr: TestRun) -> str: + env_vars = self._override_env_vars(self.system.global_env_vars, tr.test.extra_env_vars) + cmd_args = self._override_cmd_args(self.default_cmd_args, tr.test.cmd_args) + slurm_args = self._parse_slurm_args(tr.test.test_template.__class__.__name__, env_vars, cmd_args, tr) - Args: - env_vars (Dict[str, Any]): Environment variables to format. + srun_command = self._gen_srun_command(slurm_args, env_vars, cmd_args, tr) + command_list = [] + indent = "" - Returns: - str: A string representation of the formatted environment variables. - """ - formatted_vars = [] - for key in sorted(env_vars.keys()): - value = env_vars[key] - formatted_value = str(value["default"]) if isinstance(value, dict) and "default" in value else str(value) - formatted_vars.append(f"export {key}={formatted_value}") - return "\n".join(formatted_vars) + if tr.pre_test: + pre_test_command = self.gen_pre_test(tr.pre_test, tr.output_path) + command_list = [pre_test_command, "if [ $PRE_TEST_SUCCESS -eq 1 ]; then"] + indent = " " + + command_list.append(f"{indent}{srun_command}") + + if tr.post_test: + post_test_command = self.gen_post_test(tr.post_test, tr.output_path) + command_list.append(f"{indent}{post_test_command}") + + if tr.pre_test: + command_list.append("fi") + + full_command = "\n".join(command_list).strip() + return self._write_sbatch_script(slurm_args, env_vars, full_command, tr) + + def gen_srun_command(self, tr: TestRun) -> str: + env_vars = self._override_env_vars(self.system.global_env_vars, tr.test.extra_env_vars) + cmd_args = self._override_cmd_args(self.default_cmd_args, tr.test.cmd_args) + slurm_args = self._parse_slurm_args(tr.test.test_template.__class__.__name__, env_vars, cmd_args, tr) + return self._gen_srun_command(slurm_args, env_vars, cmd_args, tr) def _parse_slurm_args( self, job_name_prefix: str, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun @@ -108,14 +123,83 @@ def job_name(self, job_name_prefix: str) -> str: job_name = f"{self.system.account}-{job_name_prefix}.{datetime.now().strftime('%Y%m%d_%H%M%S')}" return job_name - def generate_srun_command( + def gen_pre_test(self, pre_test: TestScenario, base_output_path: Path) -> str: + """ + Generate the pre-test command by running all tests defined in the pre-test test scenario. + + Args: + pre_test (TestScenario): The pre-test test scenario containing the tests to be run. + base_output_path (Path): The base output directory path for storing pre-test outputs. + + Returns: + str: A string with all the Slurm srun commands generated for the pre_test. + """ + pre_test_output_dir = base_output_path / "pre_test" + pre_test_output_dir.mkdir(parents=True, exist_ok=True) + + pre_test_commands = [] + success_vars = [] + + for idx, tr in enumerate(pre_test.test_runs): + hook_dir = pre_test_output_dir / tr.test.name + hook_dir.mkdir(parents=True, exist_ok=True) + tr.output_path = hook_dir + + srun_command = tr.test.test_template.gen_srun_command(tr) + srun_command_with_output = srun_command.replace( + "srun ", f"srun --output={hook_dir / 'stdout.txt'} --error={hook_dir / 'stderr.txt'} " + ) + pre_test_commands.append(srun_command_with_output) + + success_var = f"SUCCESS_{idx}" + success_vars.append(success_var) + + success_check_command = tr.test.test_template.gen_srun_success_check(tr) + pre_test_commands.append(f"{success_var}=$({success_check_command})") + + combined_success_var = " && ".join([f"[ ${var} -eq 1 ]" for var in success_vars]) + + pre_test_commands.append(f"PRE_TEST_SUCCESS=$( {combined_success_var} && echo 1 || echo 0 )") + + return "\n".join(pre_test_commands) + + def gen_post_test(self, post_test: TestScenario, base_output_path: Path) -> str: + """ + Generate the post-test command by running all tests defined in the post-test test scenario. + + Args: + post_test (TestScenario): The post-test test scenario containing the tests to be run. + base_output_path (Path): The base output directory path for storing post-test outputs. + + Returns: + str: A string with all the Slurm srun commands generated for the post-test. + """ + post_test_output_dir = base_output_path / "post_test" + post_test_output_dir.mkdir(parents=True, exist_ok=True) + + post_test_commands = [] + + for tr in post_test.test_runs: + hook_dir = post_test_output_dir / tr.test.name + hook_dir.mkdir(parents=True, exist_ok=True) + tr.output_path = hook_dir + + srun_command = tr.test.test_template.gen_srun_command(tr) + srun_command_with_output = srun_command.replace( + "srun ", f"srun --output={hook_dir / 'stdout.txt'} --error={hook_dir / 'stderr.txt'} " + ) + post_test_commands.append(srun_command_with_output) + + return "\n".join(post_test_commands) + + def _gen_srun_command( self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun ) -> str: - srun_command_parts = self.generate_srun_prefix(slurm_args, tr) + srun_command_parts = self.gen_srun_prefix(slurm_args) test_command_parts = self.generate_test_command(env_vars, cmd_args, tr) - return " \\\n".join(srun_command_parts + test_command_parts) + return " ".join(srun_command_parts + test_command_parts) - def generate_srun_prefix(self, slurm_args: Dict[str, Any], tr: TestRun) -> List[str]: + def gen_srun_prefix(self, slurm_args: Dict[str, Any]) -> List[str]: srun_command_parts = ["srun", f"--mpi={self.system.mpi}"] if slurm_args.get("image_path"): srun_command_parts.append(f'--container-image={slurm_args["image_path"]}') @@ -127,13 +211,6 @@ def generate_srun_prefix(self, slurm_args: Dict[str, Any], tr: TestRun) -> List[ return srun_command_parts - def gen_exec_command(self, tr: TestRun) -> str: - env_vars = self._override_env_vars(self.system.global_env_vars, tr.test.extra_env_vars) - cmd_args = self._override_cmd_args(self.default_cmd_args, tr.test.cmd_args) - slurm_args = self._parse_slurm_args(tr.test.test_template.__class__.__name__, env_vars, cmd_args, tr) - srun_command = self.generate_srun_command(slurm_args, env_vars, cmd_args, tr) - return self._write_sbatch_script(slurm_args, env_vars, srun_command, tr) - def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]: return [] @@ -221,3 +298,20 @@ def _append_sbatch_directives( batch_script_content.append( "\nexport SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)" ) + + def _format_env_vars(self, env_vars: Dict[str, Any]) -> str: + """ + Format environment variables for inclusion in a batch script. + + Args: + env_vars (Dict[str, Any]): Environment variables to format. + + Returns: + str: A string representation of the formatted environment variables. + """ + formatted_vars = [] + for key in sorted(env_vars.keys()): + value = env_vars[key] + formatted_value = str(value["default"]) if isinstance(value, dict) and "default" in value else str(value) + formatted_vars.append(f"export {key}={formatted_value}") + return "\n".join(formatted_vars) diff --git a/src/cloudai/test_definitions/gpt.py b/src/cloudai/test_definitions/gpt.py index ff1e8f1e..353d97fe 100644 --- a/src/cloudai/test_definitions/gpt.py +++ b/src/cloudai/test_definitions/gpt.py @@ -21,7 +21,7 @@ from cloudai import Installable from cloudai.installer.installables import DockerImage -from .jax_toolbox import JaxFdl, JaxToolboxCmdArgs, JaxToolboxTestDefinition, PreTest, SetupFlags, XLAFlags +from .jax_toolbox import JaxFdl, JaxToolboxCmdArgs, JaxToolboxTestDefinition, SetupFlags, XLAFlags class GPTFdl(JaxFdl): @@ -48,7 +48,6 @@ class GPTCmdArgs(JaxToolboxCmdArgs): fdl_config: str fdl: GPTFdl = Field(default_factory=GPTFdl) - pre_test: PreTest = Field(default_factory=PreTest) xla_flags: GPTXLAFlags = Field(default_factory=GPTXLAFlags) setup_flags: GPTSetupFlags = Field(default_factory=GPTSetupFlags) @@ -64,7 +63,7 @@ def cmd_args_dict(self): d = self.cmd_args.model_dump() res = {} for k, v in d.items(): - if k in {"pre_test", "docker_image_url", "load_container", "output_path"}: + if k in {"docker_image_url", "load_container", "output_path"}: res[k] = v else: if k == "xla_flags": diff --git a/src/cloudai/test_definitions/grok.py b/src/cloudai/test_definitions/grok.py index c87c6e44..88a358be 100644 --- a/src/cloudai/test_definitions/grok.py +++ b/src/cloudai/test_definitions/grok.py @@ -21,7 +21,7 @@ from cloudai import Installable from cloudai.installer.installables import DockerImage -from .jax_toolbox import JaxFdl, JaxToolboxCmdArgs, JaxToolboxTestDefinition, PreTest, SetupFlags, XLAFlags +from .jax_toolbox import JaxFdl, JaxToolboxCmdArgs, JaxToolboxTestDefinition, SetupFlags, XLAFlags class GrokFdl(JaxFdl): @@ -77,7 +77,6 @@ class GrokCmdArgs(JaxToolboxCmdArgs): setup_flags: SetupFlags = Field(default_factory=SetupFlags) profile: GrokProfileXLAFlags = Field(default_factory=GrokProfileXLAFlags) perf: GrokPerfXLAFlags = Field(default_factory=GrokPerfXLAFlags) - pre_test: PreTest = Field(default_factory=PreTest) class GrokTestDefinition(JaxToolboxTestDefinition): @@ -97,7 +96,7 @@ def cmd_args_dict(self): if k in {"profile", "perf"}: res.setdefault(f"Grok.{k}", {}) res[f"Grok.{k}"]["XLA_FLAGS"] = v - elif k in {"pre_test", "docker_image_url", "load_container", "output_path"}: + elif k in {"docker_image_url", "load_container", "output_path"}: res[k] = v else: res[f"Grok.{k}"] = v diff --git a/src/cloudai/test_definitions/jax_toolbox.py b/src/cloudai/test_definitions/jax_toolbox.py index 079e5b4e..4593028a 100644 --- a/src/cloudai/test_definitions/jax_toolbox.py +++ b/src/cloudai/test_definitions/jax_toolbox.py @@ -14,12 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Optional -from pydantic import BaseModel, ConfigDict, Field, field_serializer +from pydantic import BaseModel, ConfigDict, field_serializer from cloudai import CmdArgs, TestDefinition -from cloudai.test_definitions.nccl import NCCLCmdArgs class JaxFdl(BaseModel): @@ -54,35 +53,6 @@ def checkpoint_policy_serializer(self, value: str) -> str: return f'\\"{value}\\"' -class NCCLCmdAgrsPreTest(NCCLCmdArgs): - """NCCL pre-test command arguments.""" - - num_nodes: int = 8 - stepfactor: int = 2 - minbytes: str = "8M" - maxbytes: str = "16G" - blocking: int = 1 - - def model_post_init(self, _: Any) -> None: - self.subtest_name = "all_gather_perf_mpi" - self.docker_image_url = "nvcr.io/nvidia/pytorch:24.02-py3" - - -class PreTest(BaseModel): - """Pre-test configuration.""" - - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - enable: bool = True - nccl_test: NCCLCmdAgrsPreTest = Field(default_factory=NCCLCmdAgrsPreTest) - - -class NCCLPreTest(BaseModel): - """Pre-test configuration.""" - - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - nccl_test: Optional[NCCLCmdAgrsPreTest] = None - - class JaxToolboxCmdArgs(CmdArgs): """JAX Toolbox test command arguments.""" diff --git a/tests/ref_data/gpt-no-pretest.sbatch b/tests/ref_data/gpt-no-hook.sbatch similarity index 91% rename from tests/ref_data/gpt-no-pretest.sbatch rename to tests/ref_data/gpt-no-hook.sbatch index edc4d19c..f01e9222 100644 --- a/tests/ref_data/gpt-no-pretest.sbatch +++ b/tests/ref_data/gpt-no-hook.sbatch @@ -8,7 +8,7 @@ export COMBINE_THRESHOLD=1 export PER_GPU_COMBINE_THRESHOLD=0 export XLA_FLAGS="--xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD" - echo "Loading container with srun command" +echo "Loading container with srun command" srun --mpi=none --container-image=https://docker/url --container-name=cont true echo "Running srun command" srun \ @@ -19,4 +19,4 @@ export XLA_FLAGS="--xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOL -e __OUTPUT_DIR__/error-%j-%n-%t.txt \ --container-name=cont \ --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ - /opt/paxml/workspace/run.sh \ No newline at end of file + /opt/paxml/workspace/run.sh diff --git a/tests/ref_data/gpt-pretest.sbatch b/tests/ref_data/gpt-pre-test.sbatch similarity index 52% rename from tests/ref_data/gpt-pretest.sbatch rename to tests/ref_data/gpt-pre-test.sbatch index 3a64823c..c0f6114f 100644 --- a/tests/ref_data/gpt-pretest.sbatch +++ b/tests/ref_data/gpt-pre-test.sbatch @@ -8,39 +8,11 @@ export COMBINE_THRESHOLD=1 export PER_GPU_COMBINE_THRESHOLD=0 export XLA_FLAGS="--xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD" -srun \ ---mpi=pmix \ --N 8 \ --o __OUTPUT_DIR__/output_pretest-%j-%n-%t.txt \ --e __OUTPUT_DIR__/error_pretest-%j-%n-%t.txt \ ---container-image=nvcr.io/nvidia/pytorch:24.02-py3 \ -/usr/local/bin/all_gather_perf_mpi \ ---nthreads 1 \ ---ngpus 1 \ ---minbytes 8M \ ---maxbytes 16G \ ---stepbytes 1M \ ---op sum \ ---datatype float \ ---root 0 \ ---iters 20 \ ---warmup_iters 5 \ ---agg_iters 1 \ ---average 1 \ ---parallel_init 0 \ ---check 1 \ ---blocking 1 \ ---cudagraph 0 \ ---stepfactor 2 -PRETEST_OUTPUT_FILES="__OUTPUT_DIR__/output_pretest-*.txt" -keyword="Avg bus bandwidth" - -# Use grep to search for the keyword in the files -if grep -q "$keyword" $PRETEST_OUTPUT_FILES; then - PRE_TEST_SUCCESS=true -fi -if [ "$PRE_TEST_SUCCESS" = true ]; then - echo "Loading container with srun command" +srun --output=__OUTPUT_DIR__/pre_test/nccl/stdout.txt --error=__OUTPUT_DIR__/pre_test/nccl/stderr.txt --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0 +SUCCESS_0=$(grep -q "Avg bus bandwidth" __OUTPUT_DIR__/pre_test/nccl/stdout.txt && echo 1 || echo 0) +PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && echo 1 || echo 0 ) +if [ $PRE_TEST_SUCCESS -eq 1 ]; then + echo "Loading container with srun command" srun --mpi=none --container-image=https://docker/url --container-name=cont true echo "Running srun command" srun \ @@ -52,4 +24,4 @@ if [ "$PRE_TEST_SUCCESS" = true ]; then --container-name=cont \ --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ /opt/paxml/workspace/run.sh -fi \ No newline at end of file +fi diff --git a/tests/ref_data/grok-no-pretest.sbatch b/tests/ref_data/grok-no-hook.sbatch similarity index 95% rename from tests/ref_data/grok-no-pretest.sbatch rename to tests/ref_data/grok-no-hook.sbatch index a8274477..7e7adfc2 100644 --- a/tests/ref_data/grok-no-pretest.sbatch +++ b/tests/ref_data/grok-no-hook.sbatch @@ -8,7 +8,7 @@ export COMBINE_THRESHOLD=1 export PER_GPU_COMBINE_THRESHOLD=0 export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization --xla_dump_hlo_pass_re=.* --xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_graph_level=0 --xla_gpu_pgle_profile_file_or_directory_path=/opt/paxml/workspace/pgle_output_profile.pbtxt --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD --xla_gpu_run_post_layout_collective_pipeliner=false --xla_gpu_use_memcpy_local_p2p=false" - echo "Loading container with srun command" +echo "Loading container with srun command" srun --mpi=none --container-image=https://docker/url --container-name=cont true echo "Running srun command" srun \ @@ -19,4 +19,4 @@ export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization --xla_dump_hlo_pass -e __OUTPUT_DIR__/error-%j-%n-%t.txt \ --container-name=cont \ --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ - /opt/paxml/workspace/run.sh \ No newline at end of file + /opt/paxml/workspace/run.sh diff --git a/tests/ref_data/grok-pretest.sbatch b/tests/ref_data/grok-pre-test.sbatch similarity index 67% rename from tests/ref_data/grok-pretest.sbatch rename to tests/ref_data/grok-pre-test.sbatch index 0e2672d5..51730bd7 100644 --- a/tests/ref_data/grok-pretest.sbatch +++ b/tests/ref_data/grok-pre-test.sbatch @@ -8,39 +8,11 @@ export COMBINE_THRESHOLD=1 export PER_GPU_COMBINE_THRESHOLD=0 export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization --xla_dump_hlo_pass_re=.* --xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_graph_level=0 --xla_gpu_pgle_profile_file_or_directory_path=/opt/paxml/workspace/pgle_output_profile.pbtxt --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD --xla_gpu_run_post_layout_collective_pipeliner=false --xla_gpu_use_memcpy_local_p2p=false" -srun \ ---mpi=pmix \ --N 8 \ --o __OUTPUT_DIR__/output_pretest-%j-%n-%t.txt \ --e __OUTPUT_DIR__/error_pretest-%j-%n-%t.txt \ ---container-image=nvcr.io/nvidia/pytorch:24.02-py3 \ -/usr/local/bin/all_gather_perf_mpi \ ---nthreads 1 \ ---ngpus 1 \ ---minbytes 8M \ ---maxbytes 16G \ ---stepbytes 1M \ ---op sum \ ---datatype float \ ---root 0 \ ---iters 20 \ ---warmup_iters 5 \ ---agg_iters 1 \ ---average 1 \ ---parallel_init 0 \ ---check 1 \ ---blocking 1 \ ---cudagraph 0 \ ---stepfactor 2 -PRETEST_OUTPUT_FILES="__OUTPUT_DIR__/output_pretest-*.txt" -keyword="Avg bus bandwidth" - -# Use grep to search for the keyword in the files -if grep -q "$keyword" $PRETEST_OUTPUT_FILES; then - PRE_TEST_SUCCESS=true -fi -if [ "$PRE_TEST_SUCCESS" = true ]; then - echo "Loading container with srun command" +srun --output=__OUTPUT_DIR__/pre_test/nccl/stdout.txt --error=__OUTPUT_DIR__/pre_test/nccl/stderr.txt --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0 +SUCCESS_0=$(grep -q "Avg bus bandwidth" __OUTPUT_DIR__/pre_test/nccl/stdout.txt && echo 1 || echo 0) +PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && echo 1 || echo 0 ) +if [ $PRE_TEST_SUCCESS -eq 1 ]; then + echo "Loading container with srun command" srun --mpi=none --container-image=https://docker/url --container-name=cont true echo "Running srun command" srun \ @@ -52,4 +24,4 @@ if [ "$PRE_TEST_SUCCESS" = true ]; then --container-name=cont \ --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ /opt/paxml/workspace/run.sh -fi \ No newline at end of file +fi diff --git a/tests/ref_data/nccl.sbatch b/tests/ref_data/nccl.sbatch index 3ac39077..dc179ba9 100644 --- a/tests/ref_data/nccl.sbatch +++ b/tests/ref_data/nccl.sbatch @@ -8,23 +8,4 @@ export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) -srun \ ---mpi=pmix \ ---container-image=nvcr.io/nvidia/pytorch:24.02-py3 \ -/usr/local/bin/all_reduce_perf_mpi \ ---nthreads 1 \ ---ngpus 1 \ ---minbytes 32M \ ---maxbytes 32M \ ---stepbytes 1M \ ---op sum \ ---datatype float \ ---root 0 \ ---iters 20 \ ---warmup_iters 5 \ ---agg_iters 1 \ ---average 1 \ ---parallel_init 0 \ ---check 1 \ ---blocking 0 \ ---cudagraph 0 \ No newline at end of file +srun --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0 diff --git a/tests/ref_data/sleep.sbatch b/tests/ref_data/sleep.sbatch index 7c24ec14..9262001b 100644 --- a/tests/ref_data/sleep.sbatch +++ b/tests/ref_data/sleep.sbatch @@ -8,6 +8,4 @@ export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) -srun \ ---mpi=pmix \ -sleep 5 \ No newline at end of file +srun --mpi=pmix sleep 5 diff --git a/tests/ref_data/ucc.sbatch b/tests/ref_data/ucc.sbatch index 74fa7799..a9f9e686 100644 --- a/tests/ref_data/ucc.sbatch +++ b/tests/ref_data/ucc.sbatch @@ -8,12 +8,4 @@ export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) -srun \ ---mpi=pmix \ ---container-image=nvcr.io/nvidia/pytorch:24.02-py3 \ -/opt/hpcx/ucc/bin/ucc_perftest \ --c alltoall \ --b 1 \ --e 8M \ --m cuda \ --F \ No newline at end of file +srun --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /opt/hpcx/ucc/bin/ucc_perftest -c alltoall -b 1 -e 8M -m cuda -F diff --git a/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py index 87927f9c..534d9cd1 100644 --- a/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py @@ -19,7 +19,7 @@ import pytest -from cloudai import Test, TestDefinition, TestRun, TestTemplate +from cloudai import Test, TestDefinition, TestRun, TestScenario, TestTemplate from cloudai.systems import SlurmSystem from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy @@ -54,7 +54,7 @@ def test_filename_generation(strategy_fixture: SlurmCommandGenStrategy, testrun_ env_vars = {"TEST_VAR": "VALUE"} cmd_args = {"test_arg": "test_value"} slurm_args = strategy_fixture._parse_slurm_args(job_name_prefix, env_vars, cmd_args, testrun_fixture) - srun_command = strategy_fixture.generate_srun_command(slurm_args, env_vars, cmd_args, testrun_fixture) + srun_command = strategy_fixture._gen_srun_command(slurm_args, env_vars, cmd_args, testrun_fixture) sbatch_command = strategy_fixture._write_sbatch_script(slurm_args, env_vars, srun_command, testrun_fixture) filepath_from_command = sbatch_command.split()[-1] @@ -120,3 +120,131 @@ def test_raises_if_no_default_partition(slurm_system: SlurmSystem): "system configuration. Please ensure that 'default_partition' is set correctly " "in the corresponding system configuration (e.g., system.toml)." ) in str(exc_info.value) + + +@pytest.mark.parametrize( + "pre_test,post_test,expected_script_lines", + [ + # No pre_test, no post_test + (None, None, ["srun"]), + # One pre_test, no post_test + ( + [Mock(test=Mock(name="test1", test_template=Mock()))], + None, + [ + "pre_test", + "PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && echo 1 || echo 0 )", + "if [ $PRE_TEST_SUCCESS -eq 1 ]; then", + " srun", + "fi", + ], + ), + # No pre_test, one post_test + ( + None, + [Mock(test=Mock(name="test2", test_template=Mock()))], + [ + "srun", + "post_test", + ], + ), + # One pre_test, one post_test + ( + [Mock(test=Mock(name="test1", test_template=Mock()))], + [Mock(test=Mock(name="test2", test_template=Mock()))], + [ + "pre_test", + "PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && echo 1 || echo 0 )", + "if [ $PRE_TEST_SUCCESS -eq 1 ]; then", + " srun", + " post_test", + "fi", + ], + ), + # Multiple pre_tests, multiple post_tests + ( + [Mock(test=Mock(name="test1", test_template=Mock())), Mock(test=Mock(name="test2", test_template=Mock()))], + [Mock(test=Mock(name="test3", test_template=Mock())), Mock(test=Mock(name="test4", test_template=Mock()))], + [ + "pre_test", + "pre_test", + "PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && [ $SUCCESS_1 -eq 1 ] && echo 1 || echo 0 )", + "if [ $PRE_TEST_SUCCESS -eq 1 ]; then", + " srun", + " post_test", + " post_test", + "fi", + ], + ), + # Multiple pre_tests, no post_test + ( + [Mock(test=Mock(name="test1", test_template=Mock())), Mock(test=Mock(name="test2", test_template=Mock()))], + None, + [ + "pre_test", + "pre_test", + "PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && [ $SUCCESS_1 -eq 1 ] && echo 1 || echo 0 )", + "if [ $PRE_TEST_SUCCESS -eq 1 ]; then", + " srun", + "fi", + ], + ), + # No pre_test, multiple post_tests + ( + None, + [Mock(test=Mock(name="test3", test_template=Mock())), Mock(test=Mock(name="test4", test_template=Mock()))], + [ + "srun", + "post_test", + "post_test", + ], + ), + # Multiple pre_tests, single post_test + ( + [Mock(test=Mock(name="test1", test_template=Mock())), Mock(test=Mock(name="test2", test_template=Mock()))], + [Mock(test=Mock(name="test3", test_template=Mock()))], + [ + "pre_test", + "pre_test", + "PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && [ $SUCCESS_1 -eq 1 ] && echo 1 || echo 0 )", + "if [ $PRE_TEST_SUCCESS -eq 1 ]; then", + " srun", + " post_test", + "fi", + ], + ), + ], +) +def test_pre_test_post_test_combinations( + strategy_fixture: SlurmCommandGenStrategy, + testrun_fixture: TestRun, + pre_test, + post_test, + expected_script_lines, +): + testrun_fixture.pre_test = Mock(spec=TestScenario) if pre_test else None + testrun_fixture.post_test = Mock(spec=TestScenario) if post_test else None + + if pre_test is not None: + testrun_fixture.pre_test = Mock(spec=TestScenario) + testrun_fixture.pre_test.test_runs = pre_test + for idx, run in enumerate(pre_test): + run.test.test_template.gen_srun_success_check.return_value = "pre_test" + run.test.test_template.gen_srun_command.return_value = "srun" + run.test.name = f"test{idx+1}" + + if post_test is not None: + testrun_fixture.post_test = Mock(spec=TestScenario) + testrun_fixture.post_test.test_runs = post_test + for idx, run in enumerate(post_test): + run.test.test_template.gen_srun_command.return_value = "post_test" + run.test.name = f"test{idx+1}" + + sbatch_command = strategy_fixture.gen_exec_command(testrun_fixture) + script_file_path = sbatch_command.split()[-1] + + with open(script_file_path, "r") as script_file: + script_content = script_file.read() + + for expected_line in expected_script_lines: + assert expected_line in script_content, f"Expected '{expected_line}' in generated script but it was missing." diff --git a/tests/slurm_command_gen_strategy/test_jax_toolbox_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_jax_toolbox_slurm_command_gen_strategy.py index 5db0d1bd..131e4a55 100644 --- a/tests/slurm_command_gen_strategy/test_jax_toolbox_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_jax_toolbox_slurm_command_gen_strategy.py @@ -25,7 +25,7 @@ from cloudai.systems import SlurmSystem from cloudai.test_definitions.gpt import GPTCmdArgs, GPTTestDefinition from cloudai.test_definitions.grok import GrokCmdArgs, GrokTestDefinition -from cloudai.test_definitions.jax_toolbox import JaxFdl, PreTest +from cloudai.test_definitions.jax_toolbox import JaxFdl class TestJaxToolboxSlurmCommandGenStrategy: @@ -63,7 +63,6 @@ def test_gen_exec_command( test_fixture, ) -> None: test_def = request.getfixturevalue(test_fixture) - test_def.cmd_args.pre_test = PreTest(enable=True) test = Test(test_definition=test_def, test_template=JaxToolbox(slurm_system, "name")) test_run = TestRun( @@ -74,14 +73,10 @@ def test_gen_exec_command( name="test-job", ) - cmd_gen_strategy._generate_pre_test_command = MagicMock(return_value="pre_test_command") cmd = cmd_gen_strategy.gen_exec_command(test_run) assert cmd == f"sbatch {test_run.output_path}/cloudai_sbatch_script.sh" assert (test_run.output_path / "run.sh").exists() - content = Path(f"{test_run.output_path}/cloudai_sbatch_script.sh").read_text() - assert "pre_test_command" in content - @pytest.mark.parametrize( "cmd_args, expected", [ @@ -215,100 +210,6 @@ def test_generate_python_command( "fi", ] - def test_generate_pre_test_command( - self, cmd_gen_strategy: JaxToolboxSlurmCommandGenStrategy, grok_test: GrokTestDefinition, tmp_path: Path - ) -> None: - grok_test.cmd_args.pre_test = PreTest(enable=True) - - nccl_test = grok_test.cmd_args.pre_test.nccl_test - nccl_test.num_nodes = 2 - nccl_test.minbytes = "32M" - nccl_test.blocking = 0 - - cargs = {"output_path": str(tmp_path), **grok_test.cmd_args_dict} - - pre_test_cli = cmd_gen_strategy._generate_pre_test_command(cargs, tmp_path, tmp_path).splitlines() - - expected_pre_test_cli = [ - "srun \\", - "--mpi=pmix \\", - f"-N {nccl_test.num_nodes} \\", - f"-o {tmp_path} \\", - f"-e {tmp_path} \\", - f"--container-image={nccl_test.docker_image_url} \\", - f"/usr/local/bin/{nccl_test.subtest_name} \\", - f"--nthreads {nccl_test.nthreads} \\", - f"--ngpus {nccl_test.ngpus} \\", - f"--minbytes {nccl_test.minbytes} \\", - f"--maxbytes {nccl_test.maxbytes} \\", - f"--stepbytes {nccl_test.stepbytes} \\", - f"--op {nccl_test.op} \\", - f"--datatype {nccl_test.datatype} \\", - f"--root {nccl_test.root} \\", - f"--iters {nccl_test.iters} \\", - f"--warmup_iters {nccl_test.warmup_iters} \\", - f"--agg_iters {nccl_test.agg_iters} \\", - f"--average {nccl_test.average} \\", - f"--parallel_init {nccl_test.parallel_init} \\", - f"--check {nccl_test.check} \\", - f"--blocking {nccl_test.blocking} \\", - f"--cudagraph {nccl_test.cudagraph} \\", - f"--stepfactor {nccl_test.stepfactor}", - ] - - assert pre_test_cli == expected_pre_test_cli, ( - "The generated pre-test command did not match the expected command.\n" - f"Expected: {expected_pre_test_cli}\n" - f"Actual: {pre_test_cli}" - ) - - def test_generate_srun_command(self, slurm_system, cmd_gen_strategy, grok_test): - cmd_gen_strategy.test_name = grok_test.name - Path("/tmp/output").mkdir(parents=True, exist_ok=True) - - output_path = Path("/tmp/output/output") - output_path.mkdir(parents=True, exist_ok=True) - - # Use the existing setup for mocking internal methods - cmd_gen_strategy._generate_pre_test_command = MagicMock(return_value="srun --mpi=none pre_test_command") - cmd_gen_strategy._generate_run_command = MagicMock(return_value="srun --mpi=none run_command") - cmd_gen_strategy._generate_container_load_command = MagicMock( - return_value="srun --mpi=none container_load_command" - ) - - slurm_args = { - "output": "/tmp/output/output-%j.txt", - "error": "/tmp/output/error-%j.txt", - "image_path": "fake_image_url", - "container_mounts": "/tmp/output:/workspace", - } - cmd_args = { - "output_path": "/tmp/output", - "pre_test": {"enable": True}, - f"{grok_test.name}.setup_flags.docker_workspace_dir": "/workspace/docker", - f"{grok_test.name}.setup_flags.tfds_data_dir": "/workspace/tfds", - f"{grok_test.name}.setup_flags.enable_checkpoint_saving": True, - } - - pre_test_command = cmd_gen_strategy._generate_pre_test_command( - cmd_args, Path("/tmp/output"), Path("/tmp/output") - ) - run_command = cmd_gen_strategy._generate_run_command(slurm_args) - container_load_command = cmd_gen_strategy._generate_container_load_command(slurm_args) - - result_command = f"{pre_test_command}\n{container_load_command}\n{run_command}" - - # Assert expected parts of the command are in the generated result - assert "pre_test_command" in result_command - assert "container_load_command" in result_command - assert "run_command" in result_command - assert "srun" in result_command - assert "--mpi=none" in result_command - - cmd_gen_strategy._generate_pre_test_command.assert_called_once() - cmd_gen_strategy._generate_run_command.assert_called_once() - cmd_gen_strategy._generate_container_load_command.assert_called_once() - def test_gpt_test_definition_cmd_args_dict(): gpt = GPTTestDefinition( @@ -324,7 +225,7 @@ def test_gpt_test_definition_cmd_args_dict(): assert "GPT.setup_flags" in cargs assert "GPT.XLA_FLAGS" in cargs - for k in {"pre_test", "docker_image_url", "load_container"}: + for k in {"docker_image_url", "load_container"}: assert k in cargs assert f"GPT.{k}" not in cargs @@ -348,7 +249,7 @@ def test_grok_test_definition_cmd_args_dict(): assert "Grok.perf" in cargs assert "XLA_FLAGS" in cargs["Grok.perf"] - for k in {"pre_test", "docker_image_url", "load_container"}: + for k in {"docker_image_url", "load_container"}: assert k in cargs assert f"Grok.{k}" not in cargs diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py index be5f1299..d1e57782 100644 --- a/tests/test_acceptance.py +++ b/tests/test_acceptance.py @@ -22,7 +22,7 @@ import pytest -from cloudai import NcclTest, Test, TestRun, UCCTest +from cloudai import NcclTest, Test, TestRun, TestScenario, UCCTest from cloudai.cli import handle_dry_run_and_run, setup_logging from cloudai.schema.test_template.jax_toolbox.slurm_command_gen_strategy import JaxToolboxSlurmCommandGenStrategy from cloudai.schema.test_template.jax_toolbox.template import JaxToolbox @@ -60,6 +60,7 @@ def test_slurm(tmp_path: Path, scenario: Dict): system_config=Path("conf/common/system/example_slurm_cluster.toml"), test_templates_dir=Path("conf/common/test_template"), tests_dir=Path("conf/common/test"), + hook_dir=Path("conf/common/hook"), test_scenario=test_scenario_path, output_dir=tmp_path, ) @@ -90,7 +91,7 @@ def partial_tr(slurm_system: SlurmSystem) -> partial[TestRun]: return partial(TestRun, num_nodes=1, nodes=[], output_path=slurm_system.output_path) -@pytest.fixture(params=["ucc", "nccl", "sleep", "gpt-pretest", "gpt-no-pretest", "grok-pretest", "grok-no-pretest"]) +@pytest.fixture(params=["ucc", "nccl", "sleep", "gpt-pre-test", "gpt-no-hook", "grok-pre-test", "grok-no-hook"]) def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -> tuple[TestRun, str, Optional[str]]: if request.param == "ucc": tr = partial_tr( @@ -158,10 +159,21 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) - slurm_system, tr.test.test_definition.cmd_args_dict ) tr.test.test_template.command_gen_strategy.job_name = Mock(return_value="job_name") - if "no-pretest" in request.param: - tr.test.test_definition.cmd_args.pre_test.enable = False - else: - tr.test.test_definition.cmd_args.pre_test.enable = True + if "pre-test" in request.param: + pre_test_tr = partial_tr( + name="nccl", + test=Test( + test_definition=NCCLTestDefinition( + name="nccl", description="nccl", test_template_name="nccl", cmd_args=NCCLCmdArgs() + ), + test_template=NcclTest(slurm_system, name="nccl"), + ), + ) + pre_test_tr.test.test_template.command_gen_strategy = NcclTestSlurmCommandGenStrategy( + slurm_system, pre_test_tr.test.test_definition.cmd_args_dict + ) + pre_test_tr.test.test_template.command_gen_strategy.job_name = Mock(return_value="job_name") + tr.pre_test = TestScenario(name=f"{pre_test_tr.name} NCCL pre-test", test_runs=[pre_test_tr]) return (tr, f"{request.param}.sbatch", "gpt.run") elif request.param.startswith("grok-"): @@ -182,10 +194,21 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) - slurm_system, tr.test.test_definition.cmd_args_dict ) tr.test.test_template.command_gen_strategy.job_name = Mock(return_value="job_name") - if "no-pretest" in request.param: - tr.test.test_definition.cmd_args.pre_test.enable = False - else: - tr.test.test_definition.cmd_args.pre_test.enable = True + if "pre-test" in request.param: + pre_test_tr = partial_tr( + name="nccl", + test=Test( + test_definition=NCCLTestDefinition( + name="nccl", description="nccl", test_template_name="nccl", cmd_args=NCCLCmdArgs() + ), + test_template=NcclTest(slurm_system, name="nccl"), + ), + ) + pre_test_tr.test.test_template.command_gen_strategy = NcclTestSlurmCommandGenStrategy( + slurm_system, pre_test_tr.test.test_definition.cmd_args_dict + ) + pre_test_tr.test.test_template.command_gen_strategy.job_name = Mock(return_value="job_name") + tr.pre_test = TestScenario(name=f"{pre_test_tr.name} NCCL pre-test", test_runs=[pre_test_tr]) return (tr, f"{request.param}.sbatch", "grok.run") @@ -199,8 +222,8 @@ def test_sbatch_generation(slurm_system: SlurmSystem, test_req: tuple[TestRun, s sbatch_script = tr.test.test_template.gen_exec_command(tr).split()[-1] - curr = Path(sbatch_script).read_text() - ref = (Path(__file__).parent / "ref_data" / test_req[1]).read_text() + curr = Path(sbatch_script).read_text().strip() + ref = (Path(__file__).parent / "ref_data" / test_req[1]).read_text().strip() ref = ref.replace("__OUTPUT_DIR__", str(slurm_system.output_path)).replace("__JOB_NAME__", "job_name") assert curr == ref diff --git a/tests/test_parser.py b/tests/test_parser.py index d35896a9..3f901e0d 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -50,19 +50,84 @@ def test_no_scenario(self, test_parser: Mock, parser: Parser): @patch("cloudai._core.test_parser.TestParser.parse_all") @patch("cloudai._core.test_scenario_parser.TestScenarioParser.parse") - def test_scenario_filters_tests(self, test_scenario_parser: Mock, test_parser: Mock, parser: Parser): + def test_scenario_without_hook(self, test_scenario_parser: Mock, test_parser: Mock, parser: Parser): tests_dir = parser.system_config_path.parent.parent / "test" - fake_tests = [] - for i in range(3): - fake_tests.append(Mock()) - fake_tests[-1].name = f"test-{i}" - test_parser.return_value = fake_tests + + fake_tests = [Mock(name=f"test-{i}") for i in range(3)] + for i, test in enumerate(fake_tests): + test.name = f"test-{i}" + + test_parser.side_effect = [fake_tests, []] + fake_scenario = Mock() fake_scenario.test_runs = [Mock()] fake_scenario.test_runs[0].test.name = "test-1" test_scenario_parser.return_value = fake_scenario + _, tests, _ = parser.parse(tests_dir, Path()) + assert len(tests) == 1 + assert tests[0].name == "test-1" + + @patch("cloudai._core.test_parser.TestParser.parse_all") + @patch("cloudai._core.test_scenario_parser.TestScenarioParser.parse") + @patch("cloudai.parser.Parser.parse_hooks") + def test_scenario_with_hook_common_tests( + self, parse_hooks: Mock, test_scenario_parser: Mock, test_parser: Mock, parser: Parser + ): + tests_dir = parser.system_config_path.parent.parent / "test" + + main_tests = [Mock() for _ in range(3)] + for i, test in enumerate(main_tests): + test.name = f"test-{i}" + hook_tests = [Mock()] + hook_tests[0].name = "test-1" + + test_parser.side_effect = [main_tests, hook_tests] + + fake_scenario = Mock() + fake_scenario.test_runs = [Mock()] + fake_scenario.test_runs[0].test.name = "test-1" + test_scenario_parser.return_value = fake_scenario + + fake_hook = Mock() + fake_hook.test_runs = [Mock()] + fake_hook.test_runs[0].test.name = "test-1" + parse_hooks.return_value = {"hook-1": fake_hook} + + _, tests, _ = parser.parse(tests_dir, Path()) + + filtered_test_names = {"test-1"} + assert len(tests) == 1 + assert "test-1" in filtered_test_names + + @patch("cloudai._core.test_parser.TestParser.parse_all") + @patch("cloudai._core.test_scenario_parser.TestScenarioParser.parse") + def test_scenario_with_hook_exclusive_tests(self, test_scenario_parser: Mock, test_parser: Mock, parser: Parser): + tests_dir = parser.system_config_path.parent.parent / "test" + test_scenario_path = Path("/mock/test_scenario.toml") + + main_tests = [Mock() for _ in range(3)] + hook_tests = [Mock()] + for i, test in enumerate(main_tests): + test.name = f"test-{i}" + hook_tests[0].name = "hook-test-1" + + test_parser.side_effect = [main_tests, hook_tests] + + fake_scenario = Mock() + fake_scenario.test_runs = [Mock()] + fake_scenario.test_runs[0].test.name = "test-1" + test_scenario_parser.return_value = fake_scenario + + _, filtered_tests, _ = parser.parse(tests_dir, test_scenario_path) + + filtered_test_names = {t.name for t in filtered_tests} + assert len(filtered_tests) == 2 + assert "test-1" in filtered_test_names + assert "hook-test-1" in filtered_test_names + assert "test-0" not in filtered_test_names + assert "test-2" not in filtered_test_names def test_parse_system(self, parser: Parser): parser.system_config_path = Path("conf/common/system/example_slurm_cluster.toml") diff --git a/tests/test_test_scenario.py b/tests/test_test_scenario.py index ab81bdbd..72639068 100644 --- a/tests/test_test_scenario.py +++ b/tests/test_test_scenario.py @@ -27,7 +27,7 @@ @pytest.fixture def test_scenario_parser(tmp_path: Path) -> TestScenarioParser: - tsp = TestScenarioParser(Path(""), {}) + tsp = TestScenarioParser(Path(""), {}, {}) return tsp