Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Job on partition new #261

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def _prepare_environment(self, cmd_args: Dict[str, str], extra_env_vars: Dict[st
if "training.values" in self.final_cmd_args:
self.final_cmd_args["training"] = self.final_cmd_args.pop("training.values")

self.final_cmd_args["cluster.partition"] = self.system.default_partition
self._handle_reservation()

def _handle_reservation(self) -> None:
Expand Down
138 changes: 98 additions & 40 deletions src/cloudai/systems/slurm/slurm_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,14 @@ def groups(self) -> Dict[str, Dict[str, List[SlurmNode]]]:
groups: Dict[str, Dict[str, List[SlurmNode]]] = {}
for part in self.partitions:
groups[part.name] = {}
for group in part.groups:
node_names = set()
for group_nodes in group.nodes:
node_names.update(set(parse_node_list(group_nodes)))
groups[part.name][group.name] = [node for node in part.slurm_nodes if node.name in node_names]
if part.groups:
for group in part.groups:
node_names = set()
for group_nodes in group.nodes:
node_names.update(set(parse_node_list(group_nodes)))
groups[part.name][group.name] = [node for node in part.slurm_nodes if node.name in node_names]
else:
groups[part.name][part.name] = [node for node in part.slurm_nodes]

return groups

Expand Down Expand Up @@ -445,7 +448,7 @@ def get_group_node_names(self, partition_name: str, group_name: str) -> List[str
return [node.name for node in self.get_group_nodes(partition_name, group_name)]

def get_available_nodes_from_group(
self, partition_name: str, group_name: str, number_of_nodes: Union[int, str]
self, number_of_nodes: Union[int, str], partition_name: str, group_name: Optional[str] = None
) -> List[SlurmNode]:
"""
Retrieve a specific number of potentially available nodes from a group within a partition.
Expand All @@ -466,31 +469,24 @@ def get_available_nodes_from_group(
ValueError: If the partition or group is not found, or if the requested number of nodes exceeds the
available nodes.
"""
group_print = f"group '{group_name}' in " if group_name else ""

self.validate_partition_and_group(partition_name, group_name)

self.update_node_states()

grouped_nodes = self.group_nodes_by_state(partition_name, group_name)

try:
allocated_nodes = self.allocate_nodes(grouped_nodes, number_of_nodes, group_name)

logging.info(
f"Allocated nodes from group '{group_name}' in partition '{partition_name}': "
f"{[node.name for node in allocated_nodes]}"
)

return allocated_nodes
allocated_nodes = self.allocate_nodes(grouped_nodes, number_of_nodes, partition_name, group_name)

except ValueError as e:
logging.error(
f"Error occurred while allocating nodes from group '{group_name}' in partition '{partition_name}': {e}",
exc_info=True,
)
logging.info(
f"Allocated nodes from {group_print}partition '{partition_name}': "
f"{[node.name for node in allocated_nodes]}"
)

return []
return allocated_nodes

def validate_partition_and_group(self, partition_name: str, group_name: str) -> None:
def validate_partition_and_group(self, partition_name: str, group_name: Optional[str] = None) -> None:
"""
Validate that the partition and group exist.

Expand All @@ -504,10 +500,12 @@ def validate_partition_and_group(self, partition_name: str, group_name: str) ->
"""
if partition_name not in self.groups:
raise ValueError(f"Partition '{partition_name}' not found.")
if group_name not in self.groups[partition_name]:
if group_name and group_name not in self.groups[partition_name]:
raise ValueError(f"Group '{group_name}' not found in partition '{partition_name}'.")

def group_nodes_by_state(self, partition_name: str, group_name: str) -> Dict[SlurmNodeState, List[SlurmNode]]:
def group_nodes_by_state(
self, partition_name: str, group_name: Optional[str]
) -> Dict[SlurmNodeState, List[SlurmNode]]:
"""
Group nodes by their states, excluding nodes allocated to the current user.

Expand All @@ -519,20 +517,41 @@ def group_nodes_by_state(self, partition_name: str, group_name: str) -> Dict[Slu
Returns:
Dict[SlurmNodeState, List[SlurmNode]]: A dictionary grouping nodes by their state.
"""
grouped_nodes = {
SlurmNodeState.IDLE: [],
SlurmNodeState.COMPLETING: [],
SlurmNodeState.ALLOCATED: [],
}
if self.extra_srun_args and "reservation" in self.extra_srun_args:
reservation_key = "--reservation "
reservation_name = self.extra_srun_args.split(reservation_key, 1)[1].split(" ", 1)[0]
reservation_output = self.get_reservation()
reserved_nodes = self.parse_reservation_output(reservation_output, reservation_name)
grouped_nodes = {
SlurmNodeState.RESERVED: [],
}
else:
reserved_nodes = []
grouped_nodes = {
SlurmNodeState.IDLE: [],
SlurmNodeState.COMPLETING: [],
SlurmNodeState.ALLOCATED: [],
}

if group_name:
nodes = self.groups[partition_name][group_name]
else:
nodes = []
for group_name in self.groups[partition_name]:
nodes.extend(self.groups[partition_name][group_name])

for node in self.groups[partition_name][group_name]:
if node.state in grouped_nodes:
for node in nodes:
if node.state in grouped_nodes and (not reserved_nodes or node.name in reserved_nodes):
grouped_nodes[node.state].append(node)

return grouped_nodes

def allocate_nodes(
self, grouped_nodes: Dict[SlurmNodeState, List[SlurmNode]], number_of_nodes: Union[int, str], group_name: str
self,
grouped_nodes: Dict[SlurmNodeState, List[SlurmNode]],
number_of_nodes: Union[int, str],
partition_name: str,
group_name: Optional[str],
) -> List[SlurmNode]:
"""
Allocate nodes based on the requested number or maximum availability.
Expand All @@ -541,6 +560,7 @@ def allocate_nodes(
grouped_nodes (Dict[SlurmNodeState, List[SlurmNode]]): Nodes grouped by their state.
number_of_nodes (Union[int, str]): The number of nodes to allocate, or 'max_avail' to allocate
all available nodes.
partition_name (str): The name of the partition.
group_name (str): The name of the group.

Returns:
Expand All @@ -549,6 +569,8 @@ def allocate_nodes(
Raises:
ValueError: If the requested number of nodes exceeds the available nodes.
"""
# Allocate nodes based on priority: idle, then completing, then allocated
group_or_partition = f"group '{group_name}'" if group_name else f"partition '{partition_name}'"
allocated_nodes = []

if isinstance(number_of_nodes, str) and number_of_nodes == "max_avail":
Expand All @@ -557,7 +579,7 @@ def allocate_nodes(

if len(allocated_nodes) == 0:
raise ValueError(
f"CloudAI is requesting the maximum available nodes from the group '{group_name}', "
f"CloudAI is requesting the maximum available nodes from the {group_or_partition}, "
f"but no nodes are available. Please review the available nodes in the system and ensure "
f"there are sufficient resources to meet the requirements of the test scenario. Additionally, "
f"verify that the system is capable of hosting the maximum number of nodes specified in the test "
Expand All @@ -571,9 +593,9 @@ def allocate_nodes(

if len(allocated_nodes) < number_of_nodes:
raise ValueError(
f"CloudAI is requesting {number_of_nodes} nodes from the group '{group_name}', but only "
f"{len(allocated_nodes)} nodes are available. Please review the available nodes in the system "
f"and ensure there are enough resources to meet the requested node count. Additionally, "
f"CloudAI is requesting {number_of_nodes} nodes from the {group_or_partition}, but there are only "
f"{len(allocated_nodes)} nodes in {group_or_partition}. Please review the available nodes in the "
f"system and ensure there are enough resources to meet the requested node count. Additionally, "
f"verify that the system can accommodate the number of nodes required by the test scenario."
)
else:
Expand Down Expand Up @@ -639,6 +661,16 @@ def get_sinfo(self) -> str:
sinfo_output, _ = self.fetch_command_output("sinfo")
return sinfo_output

def get_reservation(self) -> str:
    """
    Fetch the output from the 'scontrol show reservation' command.

    Returns:
        str: The stdout from the 'scontrol show reservation' command execution.
    """
    # Only stdout is relevant here; stderr from fetch_command_output is discarded.
    reservation_output, _ = self.fetch_command_output("scontrol show reservation")
    return reservation_output

def fetch_command_output(self, command: str) -> Tuple[str, str]:
"""
Execute a system command and return its output.
Expand Down Expand Up @@ -715,6 +747,27 @@ def parse_sinfo_output(self, sinfo_output: str, node_user_map: Dict[str, str]) -
node.user = node_user_map.get(node_name, "N/A")
break

def parse_reservation_output(self, reservation_output: str, reservation_name: str) -> List[str]:
    """
    Parse 'scontrol show reservation' output to get the nodes reserved under a named reservation.

    The expected format is a sequence of reservation records, each starting with
    'ReservationName=<name>' and containing a 'Nodes=<node_list>' field.

    Args:
        reservation_output (str): The raw output from the scontrol show reservation command.
        reservation_name (str): The name of the reservation specified.

    Returns:
        List[str]: The expanded list of node names in the named reservation, or an empty
            list if the reservation is not found.
    """
    node_list: List[str] = []
    for reservation in reservation_output.split("ReservationName"):
        # Each chunk after the split starts with '=<name> ...'. Compare the exact
        # reservation name so e.g. 'res1' does not match a reservation named 'res10'
        # or the name appearing in an unrelated field such as Users=.
        tokens = reservation.lstrip("= ").split(None, 1)
        if not tokens or tokens[0] != reservation_name:
            continue
        # Guard against a matching record without a Nodes= field (would otherwise
        # raise IndexError); split() also handles the field ending at a newline.
        if "Nodes=" in reservation:
            nodes = reservation.split("Nodes=", 1)[1].split()[0]
            node_list = parse_node_list(nodes)

    return node_list

def convert_state_to_enum(self, state_str: str) -> SlurmNodeState:
"""
Convert a Slurm node state string to its corresponding enum member.
Expand Down Expand Up @@ -797,11 +850,16 @@ def parse_nodes(self, nodes: List[str]) -> List[str]:
for node_spec in nodes:
if ":" in node_spec:
parts = node_spec.split(":")
if len(parts) != 3:
raise ValueError("Format should be partition:group:num_nodes")
partition_name, group_name, num_nodes_spec = parts
if len(parts) == 2:
partition_name, num_nodes_spec = parts
group_name = None
self.default_partition = partition_name
elif len(parts) == 3:
partition_name, group_name, num_nodes_spec = parts
else:
raise ValueError("Format should be partition:group:num_nodes or partition:num_nodes")
num_nodes = int(num_nodes_spec) if num_nodes_spec != "max_avail" else num_nodes_spec
group_nodes = self.get_available_nodes_from_group(partition_name, group_name, num_nodes)
group_nodes = self.get_available_nodes_from_group(num_nodes, partition_name, group_name)
parsed_nodes += [node.name for node in group_nodes]
else:
# Handle both individual node names and ranges
Expand Down
34 changes: 12 additions & 22 deletions tests/test_slurm_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import re
from typing import Dict, List
from typing import List
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -168,23 +168,11 @@ def grouped_nodes() -> dict[SlurmNodeState, list[SlurmNode]]:
return grouped_nodes


def test_get_available_nodes_exceeding_limit_no_callstack(
slurm_system: SlurmSystem, grouped_nodes: Dict[SlurmNodeState, List[SlurmNode]], caplog
):
group_name = "group1"
partition_name = "main"
num_nodes = 5

slurm_system.get_available_nodes_from_group(partition_name, group_name, num_nodes)

log_message = "CloudAI is requesting 5 nodes from the group 'group1', but only 0 nodes are available."
assert log_message in caplog.text


def test_allocate_nodes_max_avail(slurm_system: SlurmSystem, grouped_nodes: dict[SlurmNodeState, list[SlurmNode]]):
partition_name = "main"
group_name = "group_name"

available_nodes = slurm_system.allocate_nodes(grouped_nodes, "max_avail", group_name)
available_nodes = slurm_system.allocate_nodes(grouped_nodes, "max_avail", partition_name, group_name)
expected_node_names = [
grouped_nodes[SlurmNodeState.IDLE][0].name,
grouped_nodes[SlurmNodeState.IDLE][1].name,
Expand All @@ -202,9 +190,10 @@ def test_allocate_nodes_max_avail(slurm_system: SlurmSystem, grouped_nodes: dict
def test_allocate_nodes_num_nodes_integers(
slurm_system: SlurmSystem, grouped_nodes: dict[SlurmNodeState, list[SlurmNode]]
):
partition_name = "main"
group_name = "group_name"

available_nodes = slurm_system.allocate_nodes(grouped_nodes, 1, group_name)
available_nodes = slurm_system.allocate_nodes(grouped_nodes, 1, partition_name, group_name)
expected_node_names = [grouped_nodes[SlurmNodeState.IDLE][0].name]

returned_node_names = [node.name for node in available_nodes]
Expand All @@ -215,17 +204,18 @@ def test_allocate_nodes_num_nodes_integers(
def test_allocate_nodes_exceeding_limit(
    slurm_system: SlurmSystem, grouped_nodes: dict[SlurmNodeState, list[SlurmNode]]
):
    """Requesting more nodes than the group holds raises ValueError with a descriptive message."""
    partition_name = "main"
    group_name = "group1"
    requested_nodes = 5
    total_nodes = 4

    expected_message = (
        f"CloudAI is requesting {requested_nodes} nodes from the group '{group_name}', but there are only "
        f"{total_nodes} nodes in group '{group_name}'. Please review the available nodes in the "
        "system and ensure there are enough resources to meet the requested node count. Additionally, "
        "verify that the system can accommodate the number of nodes required by the test scenario."
    )

    with pytest.raises(ValueError, match=re.escape(expected_message)):
        slurm_system.allocate_nodes(grouped_nodes, requested_nodes, partition_name, group_name)
Loading