diff --git a/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py
index d743b72e..fb0443af 100644
--- a/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py
+++ b/src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py
@@ -110,7 +110,6 @@ def _prepare_environment(self, cmd_args: Dict[str, str], extra_env_vars: Dict[st
         if "training.values" in self.final_cmd_args:
             self.final_cmd_args["training"] = self.final_cmd_args.pop("training.values")
 
-        self.final_cmd_args["cluster.partition"] = self.system.default_partition
         self._handle_reservation()
 
     def _handle_reservation(self) -> None:
diff --git a/src/cloudai/systems/slurm/slurm_system.py b/src/cloudai/systems/slurm/slurm_system.py
index e716f4a9..9ce87aec 100644
--- a/src/cloudai/systems/slurm/slurm_system.py
+++ b/src/cloudai/systems/slurm/slurm_system.py
@@ -142,11 +142,14 @@ def groups(self) -> Dict[str, Dict[str, List[SlurmNode]]]:
         groups: Dict[str, Dict[str, List[SlurmNode]]] = {}
         for part in self.partitions:
            groups[part.name] = {}
-            for group in part.groups:
-                node_names = set()
-                for group_nodes in group.nodes:
-                    node_names.update(set(parse_node_list(group_nodes)))
-                groups[part.name][group.name] = [node for node in part.slurm_nodes if node.name in node_names]
+            if part.groups:
+                for group in part.groups:
+                    node_names = set()
+                    for group_nodes in group.nodes:
+                        node_names.update(set(parse_node_list(group_nodes)))
+                    groups[part.name][group.name] = [node for node in part.slurm_nodes if node.name in node_names]
+            else:
+                groups[part.name][part.name] = [node for node in part.slurm_nodes]
 
         return groups
 
@@ -474,23 +477,14 @@ def get_available_nodes_from_group(
         grouped_nodes = self.group_nodes_by_state(partition_name, group_name)
 
-        try:
-            allocated_nodes = self.allocate_nodes(grouped_nodes, number_of_nodes, partition_name, group_name)
+        allocated_nodes = self.allocate_nodes(grouped_nodes, number_of_nodes, partition_name, group_name)
 
-            logging.info(
-                f"Allocated nodes from {group_print}partition '{partition_name}': "
-                f"{[node.name for node in allocated_nodes]}"
-            )
+        logging.info(
+            f"Allocated nodes from {group_print}partition '{partition_name}': "
+            f"{[node.name for node in allocated_nodes]}"
+        )
 
-            return allocated_nodes
-
-        except ValueError as e:
-            logging.error(
-                f"Error occurred while allocating nodes from group '{group_name}' in partition '{partition_name}': {e}",
-                exc_info=True,
-            )
-
-            return []
+        return allocated_nodes
 
     def validate_partition_and_group(self, partition_name: str, group_name: Optional[str] = None) -> None:
         """
@@ -538,12 +532,14 @@ def group_nodes_by_state(
             SlurmNodeState.COMPLETING: [],
             SlurmNodeState.ALLOCATED: [],
         }
+
         if group_name:
             nodes = self.groups[partition_name][group_name]
         else:
             nodes = []
             for group_name in self.groups[partition_name]:
                 nodes.extend(self.groups[partition_name][group_name])
+
         for node in nodes:
             if node.state in grouped_nodes and (not reserved_nodes or node.name in reserved_nodes):
                 grouped_nodes[node.state].append(node)
@@ -597,9 +593,9 @@ def allocate_nodes(
 
         if len(allocated_nodes) < number_of_nodes:
             raise ValueError(
-                f"CloudAI is requesting {number_of_nodes} nodes from the {group_or_partition}, but only "
-                f"{len(allocated_nodes)} nodes are available. Please review the available nodes in the system "
-                f"and ensure there are enough resources to meet the requested node count. Additionally, "
+                f"CloudAI is requesting {number_of_nodes} nodes from the {group_or_partition}, but there are only "
+                f"{len(allocated_nodes)} nodes in {group_or_partition}. Please review the available nodes in the "
+                f"system and ensure there are enough resources to meet the requested node count. Additionally, "
                 f"verify that the system can accommodate the number of nodes required by the test scenario."
             )
         else:
@@ -857,6 +853,7 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
             if len(parts) == 2:
                 partition_name, num_nodes_spec = parts
                 group_name = None
+                self.default_partition = partition_name
             elif len(parts) == 3:
                 partition_name, group_name, num_nodes_spec = parts
             else:
diff --git a/tests/test_slurm_system.py b/tests/test_slurm_system.py
index 1a5346dd..e63c94c0 100644
--- a/tests/test_slurm_system.py
+++ b/tests/test_slurm_system.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import re
-from typing import Dict, List
+from typing import List
 from unittest.mock import patch
 
 import pytest
@@ -168,19 +168,6 @@ def grouped_nodes() -> dict[SlurmNodeState, list[SlurmNode]]:
     return grouped_nodes
 
 
-def test_get_available_nodes_exceeding_limit_no_callstack(
-    slurm_system: SlurmSystem, grouped_nodes: Dict[SlurmNodeState, List[SlurmNode]], caplog
-):
-    group_name = "group1"
-    partition_name = "main"
-    num_nodes = 5
-
-    slurm_system.get_available_nodes_from_group(num_nodes, partition_name, group_name)
-
-    log_message = "CloudAI is requesting 5 nodes from the group 'group1', but only 0 nodes are available."
-    assert log_message in caplog.text
-
-
 def test_allocate_nodes_max_avail(slurm_system: SlurmSystem, grouped_nodes: dict[SlurmNodeState, list[SlurmNode]]):
     partition_name = "main"
     group_name = "group_name"
@@ -218,16 +205,16 @@ def test_allocate_nodes_exceeding_limit(
     slurm_system: SlurmSystem, grouped_nodes: dict[SlurmNodeState, list[SlurmNode]]
 ):
     partition_name = "main"
-    group_name = "group_name"
+    group_name = "group1"
     num_nodes = 5
-    available_nodes = 4
+    total_nodes = 4
 
     with pytest.raises(
         ValueError,
         match=re.escape(
-            f"CloudAI is requesting {num_nodes} nodes from the group '{group_name}', but only "
-            f"{available_nodes} nodes are available. Please review the available nodes in the system "
-            f"and ensure there are enough resources to meet the requested node count. Additionally, "
+            f"CloudAI is requesting {num_nodes} nodes from the group '{group_name}', but there are only "
+            f"{total_nodes} nodes in group '{group_name}'. Please review the available nodes in the "
+            f"system and ensure there are enough resources to meet the requested node count. Additionally, "
             f"verify that the system can accommodate the number of nodes required by the test scenario."
         ),
     ):