diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 207d7be5..1077797a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,13 +23,13 @@ jobs: run: pip install -r requirements-dev.txt - name: Run ruff linter - run: ruff check . + run: ruff check - name: Run ruff formatter - run: ruff format --check --diff . + run: ruff format --check --diff - name: Run pyright - run: pyright . + run: pyright - name: Run vulture check run: vulture src/ tests/ diff --git a/conf/common/system/example_slurm_cluster.toml b/conf/common/system/example_slurm_cluster.toml index ddf2f210..ff795171 100644 --- a/conf/common/system/example_slurm_cluster.toml +++ b/conf/common/system/example_slurm_cluster.toml @@ -17,7 +17,7 @@ name = "example-cluster" scheduler = "slurm" -install_path = "./install" +install_path = "./install_dir" output_path = "./results" default_partition = "partition_1" diff --git a/pyproject.toml b/pyproject.toml index a6442964..4ef950cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,3 +100,6 @@ min_confidence = 100 [tool.coverage.report] exclude_also = ["@abstractmethod"] + +[tool.pyright] +include = ["src", "tests"] diff --git a/src/cloudai/__init__.py b/src/cloudai/__init__.py index fd394f24..fc4dc7a0 100644 --- a/src/cloudai/__init__.py +++ b/src/cloudai/__init__.py @@ -81,6 +81,13 @@ from .schema.test_template.sleep.slurm_command_gen_strategy import SleepSlurmCommandGenStrategy from .schema.test_template.sleep.standalone_command_gen_strategy import SleepStandaloneCommandGenStrategy from .schema.test_template.sleep.template import Sleep +from .schema.test_template.slurm_container.report_generation_strategy import ( + SlurmContainerReportGenerationStrategy, +) +from .schema.test_template.slurm_container.slurm_command_gen_strategy import ( + SlurmContainerCommandGenStrategy, +) +from .schema.test_template.slurm_container.template import SlurmContainer from .schema.test_template.ucc_test.grading_strategy import UCCTestGradingStrategy from .schema.test_template.ucc_test.report_generation_strategy import UCCTestReportGenerationStrategy from .schema.test_template.ucc_test.slurm_command_gen_strategy import UCCTestSlurmCommandGenStrategy @@ -98,6 +105,7 @@ SleepTestDefinition, UCCTestDefinition, ) +from .test_definitions.slurm_container import SlurmContainerTestDefinition Registry().add_runner("slurm", SlurmRunner) Registry().add_runner("kubernetes", KubernetesRunner) @@ -121,14 +129,21 @@ Registry().add_strategy(JobIdRetrievalStrategy, [SlurmSystem], [NeMoLauncher], NeMoLauncherSlurmJobIdRetrievalStrategy) Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [NeMoLauncher], NeMoLauncherSlurmCommandGenStrategy) Registry().add_strategy(ReportGenerationStrategy, [SlurmSystem], [UCCTest], UCCTestReportGenerationStrategy) +Registry().add_strategy( + ReportGenerationStrategy, + [SlurmSystem], + [SlurmContainer], + SlurmContainerReportGenerationStrategy, +) Registry().add_strategy(GradingStrategy, [SlurmSystem], [NeMoLauncher], NeMoLauncherGradingStrategy) + Registry().add_strategy(GradingStrategy, [SlurmSystem], [JaxToolbox], JaxToolboxGradingStrategy) Registry().add_strategy(GradingStrategy, [SlurmSystem], [UCCTest], UCCTestGradingStrategy) Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [JaxToolbox], JaxToolboxSlurmCommandGenStrategy) Registry().add_strategy( JobIdRetrievalStrategy, [SlurmSystem], - [ChakraReplay, JaxToolbox, NcclTest, UCCTest, Sleep], + [ChakraReplay, JaxToolbox, NcclTest, UCCTest, Sleep, SlurmContainer], 
SlurmJobIdRetrievalStrategy, ) Registry().add_strategy(JobIdRetrievalStrategy, [StandaloneSystem], [Sleep], StandaloneJobIdRetrievalStrategy) @@ -141,13 +156,14 @@ Registry().add_strategy( JobStatusRetrievalStrategy, [SlurmSystem], - [ChakraReplay, UCCTest, NeMoLauncher, Sleep], + [ChakraReplay, UCCTest, NeMoLauncher, Sleep, SlurmContainer], DefaultJobStatusRetrievalStrategy, ) Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [UCCTest], UCCTestSlurmCommandGenStrategy) Registry().add_strategy(ReportGenerationStrategy, [SlurmSystem], [ChakraReplay], ChakraReplayReportGenerationStrategy) Registry().add_strategy(GradingStrategy, [SlurmSystem], [ChakraReplay], ChakraReplayGradingStrategy) Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [ChakraReplay], ChakraReplaySlurmCommandGenStrategy) +Registry().add_strategy(CommandGenStrategy, [SlurmSystem], [SlurmContainer], SlurmContainerCommandGenStrategy) Registry().add_installer("slurm", SlurmInstaller) Registry().add_installer("standalone", StandaloneInstaller) @@ -165,6 +181,7 @@ Registry().add_test_definition("JaxToolboxGPT", GPTTestDefinition) Registry().add_test_definition("JaxToolboxGrok", GrokTestDefinition) Registry().add_test_definition("JaxToolboxNemotron", NemotronTestDefinition) +Registry().add_test_definition("SlurmContainer", SlurmContainerTestDefinition) Registry().add_test_template("ChakraReplay", ChakraReplay) Registry().add_test_template("NcclTest", NcclTest) @@ -174,6 +191,7 @@ Registry().add_test_template("JaxToolboxGPT", JaxToolbox) Registry().add_test_template("JaxToolboxGrok", JaxToolbox) Registry().add_test_template("JaxToolboxNemotron", JaxToolbox) +Registry().add_test_template("SlurmContainer", SlurmContainer) __all__ = [ "BaseInstaller", diff --git a/src/cloudai/installer/slurm_installer.py b/src/cloudai/installer/slurm_installer.py index f7904381..8d542f28 100644 --- a/src/cloudai/installer/slurm_installer.py +++ b/src/cloudai/installer/slurm_installer.py @@ -118,6 +118,8 @@ def install_one(self, item: Installable) -> InstallStatusResult: if isinstance(item, DockerImage): res = self._install_docker_image(item) return InstallStatusResult(res.success, res.message) + elif isinstance(item, GitRepo): + return self._install_one_git_repo(item) elif isinstance(item, PythonExecutable): return self._install_python_executable(item) @@ -139,6 +141,8 @@ def uninstall_one(self, item: Installable) -> InstallStatusResult: return InstallStatusResult(res.success, res.message) elif isinstance(item, PythonExecutable): return self._uninstall_python_executable(item) + elif isinstance(item, GitRepo): + return self._uninstall_git_repo(item) return InstallStatusResult(False, f"Unsupported item type: {type(item)}") @@ -148,6 +152,12 @@ def is_installed_one(self, item: Installable) -> InstallStatusResult: if res.success and res.docker_image_path: item.installed_path = res.docker_image_path return InstallStatusResult(res.success, res.message) + elif isinstance(item, GitRepo): + repo_path = self.system.install_path / item.repo_name + if repo_path.exists(): + item.installed_path = repo_path + return InstallStatusResult(True) + return InstallStatusResult(False, f"Git repository {item.git_url} not cloned") elif isinstance(item, PythonExecutable): return self._is_python_executable_installed(item) diff --git a/src/cloudai/report_generator/report_generator.py b/src/cloudai/report_generator/report_generator.py index 3c8a7e2a..9d7ef563 100644 --- a/src/cloudai/report_generator/report_generator.py +++ 
b/src/cloudai/report_generator/report_generator.py @@ -70,7 +70,14 @@ def _generate_test_report(self, directory_path: Path, tr: TestRun) -> None: tr (TestRun): The test run object. """ for subdir in directory_path.iterdir(): - if subdir.is_dir() and tr.test.test_template.can_handle_directory(subdir): - tr.test.test_template.generate_report(tr.test.name, subdir, tr.sol) - else: - logging.warning(f"Skipping directory '{subdir}' for test '{tr.test.name}'") + if not subdir.is_dir(): + logging.debug(f"Skipping file '{subdir}', not a directory.") + continue + if not tr.test.test_template.can_handle_directory(subdir): + logging.warning( + f"Skipping '{subdir}', can't handle with " + f"strategy={tr.test.test_template.report_generation_strategy}." + ) + continue + + tr.test.test_template.generate_report(tr.test.name, subdir, tr.sol) diff --git a/src/cloudai/runner/slurm/slurm_runner.py b/src/cloudai/runner/slurm/slurm_runner.py index c7b6e12b..7362bcde 100644 --- a/src/cloudai/runner/slurm/slurm_runner.py +++ b/src/cloudai/runner/slurm/slurm_runner.py @@ -68,4 +68,5 @@ def _submit_test(self, tr: TestRun) -> SlurmJob: stderr=stderr, message="Failed to retrieve job ID from command output.", ) + logging.info(f"Submitted slurm job: {job_id}") return SlurmJob(tr, id=job_id) diff --git a/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py index 075e462a..d743b72e 100644 --- a/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py @@ -47,7 +47,11 @@ def gen_exec_command(self, tr: TestRun) -> str: ) self.final_cmd_args["cluster.gpus_per_node"] = self.system.gpus_per_node or "null" - repo_path = tdef.python_executable.git_repo.installed_path + repo_path = ( + tdef.python_executable.git_repo.installed_path.absolute() + if tdef.python_executable.git_repo.installed_path is not None + else None + ) if not repo_path: logging.warning( f"Local clone of git repo {tdef.python_executable.git_repo} does not exist. " diff --git a/src/cloudai/schema/test_template/slurm_container/report_generation_strategy.py b/src/cloudai/schema/test_template/slurm_container/report_generation_strategy.py new file mode 100644 index 00000000..c7c4554d --- /dev/null +++ b/src/cloudai/schema/test_template/slurm_container/report_generation_strategy.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import re +from pathlib import Path +from typing import Optional + +from cloudai import ReportGenerationStrategy + + +class SlurmContainerReportGenerationStrategy(ReportGenerationStrategy): + """Report generation strategy for a generic Slurm container test.""" + + def can_handle_directory(self, directory_path: Path) -> bool: + stdout_path = directory_path / "stdout.txt" + if stdout_path.exists(): + with stdout_path.open("r") as file: + if re.search( + r"Training epoch \d+, iteration \d+/\d+ \| lr: [\d.]+ \| global_batch_size: \d+ \| global_step: \d+ \| " + r"reduced_train_loss: [\d.]+ \| train_step_timing in s: [\d.]+", + file.read(), + ): + return True + return False + + def generate_report(self, test_name: str, directory_path: Path, sol: Optional[float] = None) -> None: + stdout_path = directory_path / "stdout.txt" + if not stdout_path.is_file(): + return + + with stdout_path.open("r") as file: + lines = file.readlines() + with open(directory_path / "report.csv", "w") as csv_file: + csv_file.write( + "epoch,iteration,lr,global_batch_size,global_step,reduced_train_loss,train_step_timing,consumed_samples\n" + ) + for line in lines: + pattern = ( + r"Training epoch (\d+), iteration (\d+)/\d+ \| lr: ([\d.]+) \| global_batch_size: (\d+) \| " + r"global_step: (\d+) \| reduced_train_loss: ([\d.]+) \| train_step_timing in s: ([\d.]+)" + ) + if " | consumed_samples:" in line: + pattern = ( + r"Training epoch (\d+), iteration (\d+)/\d+ \| lr: ([\d.]+) \| global_batch_size: (\d+) \| " + r"global_step: (\d+) \| reduced_train_loss: ([\d.]+) \| train_step_timing in s: ([\d.]+) " + r"\| consumed_samples: (\d+)" + ) + + match = re.match(pattern, line) + if match: + csv_file.write(",".join(match.groups()) + "\n") diff --git a/src/cloudai/schema/test_template/slurm_container/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/slurm_container/slurm_command_gen_strategy.py new file mode 100644 index 00000000..23a22958 --- /dev/null +++ b/src/cloudai/schema/test_template/slurm_container/slurm_command_gen_strategy.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from typing import Any, cast + +from cloudai import TestRun +from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy +from cloudai.test_definitions.slurm_container import SlurmContainerTestDefinition + + +class SlurmContainerCommandGenStrategy(SlurmCommandGenStrategy): + """Command generation strategy for generic Slurm container tests.""" + + def gen_srun_prefix(self, slurm_args: dict[str, Any], tr: TestRun) -> list[str]: + tdef: SlurmContainerTestDefinition = cast(SlurmContainerTestDefinition, tr.test.test_definition) + slurm_args["image_path"] = tdef.docker_image.installed_path + slurm_args["container_mounts"] = ",".join(tdef.container_mounts(self.system.install_path)) + + cmd = super().gen_srun_prefix(slurm_args, tr) + return cmd + ["--no-container-mount-home"] + + def generate_test_command(self, env_vars: dict[str, str], cmd_args: dict[str, str], tr: TestRun) -> list[str]: + srun_command_parts: list[str] = [] + if tr.test.extra_cmd_args: + srun_command_parts.append(tr.test.extra_cmd_args) + + return srun_command_parts diff --git a/src/cloudai/schema/test_template/slurm_container/template.py b/src/cloudai/schema/test_template/slurm_container/template.py new file mode 100644 index 00000000..9e49eb35 --- /dev/null +++ b/src/cloudai/schema/test_template/slurm_container/template.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from cloudai import TestTemplate + + +class SlurmContainer(TestTemplate): + """Generic Slurm container test template.""" + + pass diff --git a/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py b/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py index ee8a463a..910db56c 100644 --- a/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py +++ b/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py @@ -195,11 +195,11 @@ def gen_post_test(self, post_test: TestScenario, base_output_path: Path) -> str: def _gen_srun_command( self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun ) -> str: - srun_command_parts = self.gen_srun_prefix(slurm_args) + srun_command_parts = self.gen_srun_prefix(slurm_args, tr) test_command_parts = self.generate_test_command(env_vars, cmd_args, tr) return " ".join(srun_command_parts + test_command_parts) - def gen_srun_prefix(self, slurm_args: Dict[str, Any]) -> List[str]: + def gen_srun_prefix(self, slurm_args: Dict[str, Any], tr: TestRun) -> List[str]: srun_command_parts = ["srun", f"--mpi={self.system.mpi}"] if slurm_args.get("image_path"): srun_command_parts.append(f'--container-image={slurm_args["image_path"]}') diff --git a/src/cloudai/test_definitions/slurm_container.py b/src/cloudai/test_definitions/slurm_container.py new file mode 100644 index 00000000..e84d3b47 --- /dev/null +++ b/src/cloudai/test_definitions/slurm_container.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import Path +from typing import Optional + +from cloudai import CmdArgs, Installable, TestDefinition +from cloudai.installer.installables import DockerImage, GitRepo + + +class SlurmContainerCmdArgs(CmdArgs): + """Command line arguments for a generic Slurm container test.""" + + docker_image_url: str + repository_url: str + repository_commit_hash: str + mcore_vfm_repo: str + mcore_vfm_commit_hash: str + + +class SlurmContainerTestDefinition(TestDefinition): + """Test definition for a generic Slurm container test.""" + + cmd_args: SlurmContainerCmdArgs + + _docker_image: Optional[DockerImage] = None + _git_repo: Optional[GitRepo] = None + _mcore_git_repo: Optional[GitRepo] = None + + @property + def docker_image(self) -> DockerImage: + if not self._docker_image: + self._docker_image = DockerImage(url=self.cmd_args.docker_image_url) + return self._docker_image + + @property + def git_repo(self) -> GitRepo: + if not self._git_repo: + self._git_repo = GitRepo( + git_url=self.cmd_args.repository_url, commit_hash=self.cmd_args.repository_commit_hash + ) + + return self._git_repo + + @property + def mcore_vfm_git_repo(self) -> GitRepo: + if not self._mcore_git_repo: + self._mcore_git_repo = GitRepo( + git_url=self.cmd_args.mcore_vfm_repo, commit_hash=self.cmd_args.mcore_vfm_commit_hash + ) + + return self._mcore_git_repo + + def container_mounts(self, root: Path) -> list[str]: + repo_path = self.git_repo.installed_path or root / self.git_repo.repo_name + mcore_vfm_path = self.mcore_vfm_git_repo.installed_path or root / self.mcore_vfm_git_repo.repo_name + return [ + f"{repo_path.absolute()}:/work", + f"{mcore_vfm_path.absolute()}:/opt/megatron-lm", + ] + + @property + def installables(self) -> list[Installable]: + return [self.docker_image, self.git_repo, self.mcore_vfm_git_repo] + + @property + def extra_args_str(self) -> str: + parts = [] + for k, v in self.extra_cmd_args.items(): + parts.append(f"{k} {v}" if v else k) + return " ".join(parts) diff --git a/tests/ref_data/gpt-no-hook.sbatch b/tests/ref_data/gpt-no-hook.sbatch index f01e9222..77999bda 100644 --- a/tests/ref_data/gpt-no-hook.sbatch +++ b/tests/ref_data/gpt-no-hook.sbatch @@ -15,8 +15,8 @@ echo "Loading container with srun command" --mpi=none \ \ --export=ALL \ - -o __OUTPUT_DIR__/output-%j-%n-%t.txt \ - -e __OUTPUT_DIR__/error-%j-%n-%t.txt \ + -o __OUTPUT_DIR__/output/output-%j-%n-%t.txt \ + -e __OUTPUT_DIR__/output/error-%j-%n-%t.txt \ --container-name=cont \ - --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ + --container-mounts=__OUTPUT_DIR__/output:/opt/paxml/workspace/ \ /opt/paxml/workspace/run.sh diff --git a/tests/ref_data/gpt-pre-test.sbatch b/tests/ref_data/gpt-pre-test.sbatch index c0f6114f..d21f0ed7 100644 --- a/tests/ref_data/gpt-pre-test.sbatch +++ b/tests/ref_data/gpt-pre-test.sbatch @@ -8,8 +8,8 @@ export COMBINE_THRESHOLD=1 export PER_GPU_COMBINE_THRESHOLD=0 export XLA_FLAGS="--xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD" -srun --output=__OUTPUT_DIR__/pre_test/nccl/stdout.txt --error=__OUTPUT_DIR__/pre_test/nccl/stderr.txt --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0 
-SUCCESS_0=$(grep -q "Avg bus bandwidth" __OUTPUT_DIR__/pre_test/nccl/stdout.txt && echo 1 || echo 0) +srun --output=__OUTPUT_DIR__/output/pre_test/nccl/stdout.txt --error=__OUTPUT_DIR__/output/pre_test/nccl/stderr.txt --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0 +SUCCESS_0=$(grep -q "Avg bus bandwidth" __OUTPUT_DIR__/output/pre_test/nccl/stdout.txt && echo 1 || echo 0) PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && echo 1 || echo 0 ) if [ $PRE_TEST_SUCCESS -eq 1 ]; then echo "Loading container with srun command" @@ -19,9 +19,9 @@ if [ $PRE_TEST_SUCCESS -eq 1 ]; then --mpi=none \ \ --export=ALL \ - -o __OUTPUT_DIR__/output-%j-%n-%t.txt \ - -e __OUTPUT_DIR__/error-%j-%n-%t.txt \ + -o __OUTPUT_DIR__/output/output-%j-%n-%t.txt \ + -e __OUTPUT_DIR__/output/error-%j-%n-%t.txt \ --container-name=cont \ - --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ + --container-mounts=__OUTPUT_DIR__/output:/opt/paxml/workspace/ \ /opt/paxml/workspace/run.sh fi diff --git a/tests/ref_data/grok-no-hook.sbatch b/tests/ref_data/grok-no-hook.sbatch index 7e7adfc2..8d008611 100644 --- a/tests/ref_data/grok-no-hook.sbatch +++ b/tests/ref_data/grok-no-hook.sbatch @@ -15,8 +15,8 @@ echo "Loading container with srun command" --mpi=none \ \ --export=ALL \ - -o __OUTPUT_DIR__/output-%j-%n-%t.txt \ - -e __OUTPUT_DIR__/error-%j-%n-%t.txt \ + -o __OUTPUT_DIR__/output/output-%j-%n-%t.txt \ + -e __OUTPUT_DIR__/output/error-%j-%n-%t.txt \ --container-name=cont \ - --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ + --container-mounts=__OUTPUT_DIR__/output:/opt/paxml/workspace/ \ /opt/paxml/workspace/run.sh diff --git a/tests/ref_data/grok-pre-test.sbatch b/tests/ref_data/grok-pre-test.sbatch index 51730bd7..7d88745a 100644 --- a/tests/ref_data/grok-pre-test.sbatch +++ b/tests/ref_data/grok-pre-test.sbatch @@ -8,8 +8,8 @@ export COMBINE_THRESHOLD=1 export PER_GPU_COMBINE_THRESHOLD=0 export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization --xla_dump_hlo_pass_re=.* --xla_gpu_all_gather_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_all_reduce_combine_threshold_bytes=$COMBINE_THRESHOLD --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_graph_level=0 --xla_gpu_pgle_profile_file_or_directory_path=/opt/paxml/workspace/pgle_output_profile.pbtxt --xla_gpu_reduce_scatter_combine_threshold_bytes=$PER_GPU_COMBINE_THRESHOLD --xla_gpu_run_post_layout_collective_pipeliner=false --xla_gpu_use_memcpy_local_p2p=false" -srun --output=__OUTPUT_DIR__/pre_test/nccl/stdout.txt --error=__OUTPUT_DIR__/pre_test/nccl/stderr.txt --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0 -SUCCESS_0=$(grep -q 
"Avg bus bandwidth" __OUTPUT_DIR__/pre_test/nccl/stdout.txt && echo 1 || echo 0) +srun --output=__OUTPUT_DIR__/output/pre_test/nccl/stdout.txt --error=__OUTPUT_DIR__/output/pre_test/nccl/stderr.txt --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0 +SUCCESS_0=$(grep -q "Avg bus bandwidth" __OUTPUT_DIR__/output/pre_test/nccl/stdout.txt && echo 1 || echo 0) PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && echo 1 || echo 0 ) if [ $PRE_TEST_SUCCESS -eq 1 ]; then echo "Loading container with srun command" @@ -19,9 +19,9 @@ if [ $PRE_TEST_SUCCESS -eq 1 ]; then --mpi=none \ \ --export=ALL \ - -o __OUTPUT_DIR__/output-%j-%n-%t.txt \ - -e __OUTPUT_DIR__/error-%j-%n-%t.txt \ + -o __OUTPUT_DIR__/output/output-%j-%n-%t.txt \ + -e __OUTPUT_DIR__/output/error-%j-%n-%t.txt \ --container-name=cont \ - --container-mounts=__OUTPUT_DIR__:/opt/paxml/workspace/ \ + --container-mounts=__OUTPUT_DIR__/output:/opt/paxml/workspace/ \ /opt/paxml/workspace/run.sh fi diff --git a/tests/ref_data/nccl.sbatch b/tests/ref_data/nccl.sbatch index dc179ba9..2a9f57b4 100644 --- a/tests/ref_data/nccl.sbatch +++ b/tests/ref_data/nccl.sbatch @@ -1,8 +1,8 @@ #!/bin/bash #SBATCH --job-name=__JOB_NAME__ #SBATCH -N 1 -#SBATCH --output=__OUTPUT_DIR__/stdout.txt -#SBATCH --error=__OUTPUT_DIR__/stderr.txt +#SBATCH --output=__OUTPUT_DIR__/output/stdout.txt +#SBATCH --error=__OUTPUT_DIR__/output/stderr.txt #SBATCH --partition=main export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) diff --git a/tests/ref_data/sleep.sbatch b/tests/ref_data/sleep.sbatch index 9262001b..1ce9ca32 100644 --- a/tests/ref_data/sleep.sbatch +++ b/tests/ref_data/sleep.sbatch @@ -1,8 +1,8 @@ #!/bin/bash #SBATCH --job-name=__JOB_NAME__ #SBATCH -N 1 -#SBATCH --output=__OUTPUT_DIR__/stdout.txt -#SBATCH --error=__OUTPUT_DIR__/stderr.txt +#SBATCH --output=__OUTPUT_DIR__/output/stdout.txt +#SBATCH --error=__OUTPUT_DIR__/output/stderr.txt #SBATCH --partition=main export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) diff --git a/tests/ref_data/slurm_container.sbatch b/tests/ref_data/slurm_container.sbatch new file mode 100644 index 00000000..6479402b --- /dev/null +++ b/tests/ref_data/slurm_container.sbatch @@ -0,0 +1,11 @@ +#!/bin/bash +#SBATCH --job-name=__JOB_NAME__ +#SBATCH -N 1 +#SBATCH --output=__OUTPUT_DIR__/output/stdout.txt +#SBATCH --error=__OUTPUT_DIR__/output/stderr.txt +#SBATCH --partition=main + +export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) + + +srun --mpi=pmix --container-image=https://docker/url --container-mounts=__OUTPUT_DIR__/install/url__commit_hash:/work,__OUTPUT_DIR__/install/repo__mcore_vfm_commit_hash:/opt/megatron-lm --no-container-mount-home bash -c "pwd ; ls" diff --git a/tests/ref_data/ucc.sbatch b/tests/ref_data/ucc.sbatch index a9f9e686..a3f5fca8 100644 --- a/tests/ref_data/ucc.sbatch +++ b/tests/ref_data/ucc.sbatch @@ -1,8 +1,8 @@ #!/bin/bash #SBATCH --job-name=__JOB_NAME__ #SBATCH -N 1 -#SBATCH --output=__OUTPUT_DIR__/stdout.txt -#SBATCH --error=__OUTPUT_DIR__/stderr.txt +#SBATCH --output=__OUTPUT_DIR__/output/stdout.txt +#SBATCH --error=__OUTPUT_DIR__/output/stderr.txt #SBATCH --partition=main export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) 
diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py index 7c2e61ce..1d7490fa 100644 --- a/tests/test_acceptance.py +++ b/tests/test_acceptance.py @@ -31,6 +31,8 @@ from cloudai.schema.test_template.nemo_launcher.template import NeMoLauncher from cloudai.schema.test_template.sleep.slurm_command_gen_strategy import SleepSlurmCommandGenStrategy from cloudai.schema.test_template.sleep.template import Sleep +from cloudai.schema.test_template.slurm_container.slurm_command_gen_strategy import SlurmContainerCommandGenStrategy +from cloudai.schema.test_template.slurm_container.template import SlurmContainer from cloudai.schema.test_template.ucc_test.slurm_command_gen_strategy import UCCTestSlurmCommandGenStrategy from cloudai.systems import SlurmSystem from cloudai.test_definitions.gpt import GPTCmdArgs, GPTTestDefinition @@ -38,6 +40,7 @@ from cloudai.test_definitions.nccl import NCCLCmdArgs, NCCLTestDefinition from cloudai.test_definitions.nemo_launcher import NeMoLauncherCmdArgs, NeMoLauncherTestDefinition from cloudai.test_definitions.sleep import SleepCmdArgs, SleepTestDefinition +from cloudai.test_definitions.slurm_container import SlurmContainerCmdArgs, SlurmContainerTestDefinition from cloudai.test_definitions.ucc import UCCCmdArgs, UCCTestDefinition SLURM_TEST_SCENARIOS = [ @@ -99,7 +102,17 @@ def partial_tr(slurm_system: SlurmSystem) -> partial[TestRun]: @pytest.fixture( - params=["ucc", "nccl", "sleep", "gpt-pre-test", "gpt-no-hook", "grok-pre-test", "grok-no-hook", "nemo-launcher"] + params=[ + "ucc", + "nccl", + "sleep", + "gpt-pre-test", + "gpt-no-hook", + "grok-pre-test", + "grok-no-hook", + "nemo-launcher", + "slurm_container", + ] ) def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -> tuple[TestRun, str, Optional[str]]: if request.param == "ucc": @@ -239,6 +252,32 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) - tr.test.test_template.command_gen_strategy.job_name = Mock(return_value="job_name") return (tr, "nemo-launcher.sbatch", None) + elif request.param == "slurm_container": + tr = partial_tr( + name="slurm_container", + test=Test( + test_definition=SlurmContainerTestDefinition( + name="slurm_container", + description="slurm_container", + test_template_name="slurm_container", + cmd_args=SlurmContainerCmdArgs( + docker_image_url="https://docker/url", + repository_url="https://repo/url", + repository_commit_hash="commit_hash", + mcore_vfm_repo="https://mcore_vfm/repo", + mcore_vfm_commit_hash="mcore_vfm_commit_hash", + ), + extra_cmd_args={"bash": '-c "pwd ; ls"'}, + ), + test_template=SlurmContainer(slurm_system, name="slurm_container"), + ), + ) + tr.test.test_template.command_gen_strategy = SlurmContainerCommandGenStrategy( + slurm_system, tr.test.test_definition.cmd_args_dict + ) + tr.test.test_template.command_gen_strategy.job_name = Mock(return_value="job_name") + + return (tr, "slurm_container.sbatch", None) raise ValueError(f"Unknown test: {request.param}") @@ -248,14 +287,12 @@ def test_sbatch_generation(slurm_system: SlurmSystem, test_req: tuple[TestRun, s tr = test_req[0] - sbatch_script = tr.test.test_template.gen_exec_command(tr).split()[-1] ref = (Path(__file__).parent / "ref_data" / test_req[1]).read_text().strip() + ref = ref.replace("__OUTPUT_DIR__", str(slurm_system.output_path.parent)).replace("__JOB_NAME__", "job_name") + + sbatch_script = tr.test.test_template.gen_exec_command(tr).split()[-1] if "nemo-launcher" in test_req[1]: sbatch_script = slurm_system.output_path / 
"generated_command.sh" - ref = ref.replace("__OUTPUT_DIR__", str(slurm_system.output_path.parent)) - else: - ref = ref.replace("__OUTPUT_DIR__", str(slurm_system.output_path)).replace("__JOB_NAME__", "job_name") - curr = Path(sbatch_script).read_text().strip() assert curr == ref diff --git a/tests/test_init.py b/tests/test_init.py index 410e154b..07ca9268 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -23,6 +23,8 @@ JsonGenStrategy, Registry, ReportGenerationStrategy, + SlurmContainer, + SlurmContainerTestDefinition, ) from cloudai.installer.slurm_installer import SlurmInstaller from cloudai.installer.standalone_installer import StandaloneInstaller @@ -127,12 +129,13 @@ def test_strategies(key: tuple, value: type): def test_test_templates(): test_templates = Registry().test_templates_map - assert len(test_templates) == 8 + assert len(test_templates) == 9 assert test_templates["ChakraReplay"] == ChakraReplay assert test_templates["NcclTest"] == NcclTest assert test_templates["NeMoLauncher"] == NeMoLauncher assert test_templates["Sleep"] == Sleep assert test_templates["UCCTest"] == UCCTest + assert test_templates["SlurmContainer"] == SlurmContainer def test_installers(): @@ -144,12 +147,13 @@ def test_installers(): def test_definitions(): test_defs = Registry().test_definitions_map - assert len(test_defs) == 8 + assert len(test_defs) == 9 assert test_defs["UCCTest"] == UCCTestDefinition assert test_defs["NcclTest"] == NCCLTestDefinition assert test_defs["ChakraReplay"] == ChakraReplayTestDefinition assert test_defs["Sleep"] == SleepTestDefinition assert test_defs["NeMoLauncher"] == NeMoLauncherTestDefinition + assert test_defs["SlurmContainer"] == SlurmContainerTestDefinition def test_definitions_matches_templates():