diff --git a/apps/crc_idle.py b/apps/crc_idle.py index dfa485c..20fb413 100755 --- a/apps/crc_idle.py +++ b/apps/crc_idle.py @@ -5,29 +5,29 @@ Resource summaries are provided for GPU and CPU partitions. """ -from argparse import Namespace -from typing import Dict, Tuple import re +from argparse import Namespace +from collections import defaultdict +from .utils import Shell, Slurm from .utils.cli import BaseParser -from .utils.system_info import Shell, Slurm class CrcIdle(BaseParser): """Display idle Slurm resources.""" - # The type of resource available on a cluster - # Either ``cores`` or ``GPUs`` depending on the cluster type - cluster_types = { - 'smp': 'cores', - 'gpu': 'GPUs', - 'mpi': 'cores', - 'htc': 'cores' - } - default_type = 'cores' + # Specify the type of resource available on each cluster + # Either `cores` or `GPUs` depending on the cluster type + cluster_types = defaultdict( + lambda: 'cores', + smp='cores', + gpu='GPUs', + mpi='cores', + htc='cores', + ) def __init__(self) -> None: - """Define arguments for the command line interface""" + """Define arguments for the command line interface.""" super(CrcIdle, self).__init__() self.add_argument('-s', '--smp', action='store_true', help='list idle resources on the smp cluster') @@ -37,8 +37,8 @@ def __init__(self) -> None: self.add_argument('-d', '--htc', action='store_true', help='list idle resources on the htc cluster') self.add_argument('-p', '--partition', nargs='+', help='only include information for specific partitions') - def get_cluster_list(self, args: Namespace) -> Tuple[str]: - """Return a list of clusters specified by command line arguments + def get_cluster_list(self, args: Namespace) -> tuple[str]: + """Return a list of clusters specified by command line arguments. Returns a tuple of clusters specified by command line arguments. If no clusters were specified, then return a tuple of all cluster names. 
@@ -50,64 +50,61 @@ def get_cluster_list(self, args: Namespace) -> Tuple[str]: A tuple of cluster names """ - argument_options = self.cluster_types - argument_clusters = tuple(filter(lambda cluster: getattr(args, cluster), argument_options)) + # Select only the specified clusters + argument_clusters = tuple(self.cluster_types.keys()) + specified_clusters = tuple(filter(lambda cluster: getattr(args, cluster), argument_clusters)) # Default to returning all clusters - return argument_clusters or argument_options + return specified_clusters or argument_clusters @staticmethod - def _idle_cpu_resources(cluster: str, partition: str) -> Dict[int, int]: - """Return the idle CPU resources on a given cluster partition + def _count_idle_cpu_resources(cluster: str, partition: str) -> dict[int, int]: + """Return the idle CPU resources on a given cluster partition. Args: - cluster: The cluster to print a summary for - partition: The partition in the parent cluster + cluster: The cluster to print a summary for. + partition: The partition in the parent cluster. Returns: - A dictionary mapping idle resources to number of nodes + A dictionary mapping the number of idle resources to the number of nodes with that many idle resources. 
""" # Use `sinfo` command to determine the status of each node in the given partition command = f'sinfo -h -M {cluster} -p {partition} -N -o %N,%C' - stdout = Shell.run_command(command) - slurm_data = stdout.strip().split() + slurm_data = Shell.run_command(command).strip().split() # Count the number of nodes having a given number of idle cores/GPUs return_dict = dict() for node_info in slurm_data: - _, resource_data = node_info.split(',') # Returns: node_name, resource_data - _, idle, _, _ = [int(x) for x in resource_data.split('/')] # Returns: allocated, idle, other, total + node_name, resource_data = node_info.split(',') + allocated, idle, other, total = [int(x) for x in resource_data.split('/')] return_dict[idle] = return_dict.setdefault(idle, 0) + 1 return return_dict @staticmethod - def _idle_gpu_resources(cluster: str, partition: str) -> Dict[int, int]: - """Return idle GPU resources on a given cluster partition + def _count_idle_gpu_resources(cluster: str, partition: str) -> dict[int, int]: + """Return idle GPU resources on a given cluster partition. - If the host node is in a ``drain`` state, the GPUs are reported as unavailable. + If the host node is in a `drain` state, the GPUs are reported as unavailable. Args: - cluster: The cluster to print a summary for - partition: The partition in the parent cluster + cluster: The cluster to print a summary for. + partition: The partition in the parent cluster. Returns: - A dictionary mapping idle resources to number of nodes + A dictionary mapping the number of idle resources to the number of nodes with that many idle resources. 
""" # Use `sinfo` command to determine the status of each node in the given partition - command = f"sinfo -h -M {cluster} -p {partition} -N " \ - f"--Format=NodeList:'_',gres:5'_',gresUsed:12'_',StateCompact:' '" - - stdout = Shell.run_command(command) - slurm_data = stdout.strip().split() + slurm_output_format = "NodeList:'_',gres:5'_',gresUsed:12'_',StateCompact:' '" + command = f"sinfo -h -M {cluster} -p {partition} -N --Format={slurm_output_format}" + slurm_data = Shell.run_command(command).strip().split() # Count the number of nodes having a given number of idle cores/GPUs return_dict = dict() for node_info in slurm_data: - # Returns: node_name, total, allocated, node state - _, total, allocated, state = node_info.split('_') + node_name, total, allocated, state = node_info.split('_') # If the node is in a downed state, report 0 resource availability. if re.search("drain", state): @@ -122,61 +119,61 @@ def _idle_gpu_resources(cluster: str, partition: str) -> Dict[int, int]: return return_dict - def count_idle_resources(self, cluster: str, partition: str) -> Dict[int, int]: - """Determine the number of idle resources on a given cluster partition + def count_idle_resources(self, cluster: str, partition: str) -> dict[int, int]: + """Determine the number of idle resources on a given cluster partition. The returned dictionary maps the number of idle resources (e.g., cores) to the number of nodes in the partition having that many resources idle. Args: - cluster: The cluster to print a summary for - partition: The partition in the parent cluster + cluster: The cluster to print a summary for. + partition: The partition in the parent cluster. Returns: - A dictionary mapping idle resources to number of nodes + A dictionary mapping idle resources to number of nodes. 
""" - cluster_type = self.cluster_types.get(cluster, self.default_type) + cluster_type = self.cluster_types[cluster] if cluster_type == 'GPUs': - return self._idle_gpu_resources(cluster, partition) + return self._count_idle_gpu_resources(cluster, partition) elif cluster_type == 'cores': - return self._idle_cpu_resources(cluster, partition) + return self._count_idle_cpu_resources(cluster, partition) raise ValueError(f'Unknown cluster type: {cluster}') - def print_partition_summary(self, cluster: str, partition: str) -> None: + def print_partition_summary(self, cluster: str, partition: str, idle_resources: dict) -> None: """Print a summary of idle resources in a single partition Args: cluster: The cluster to print a summary for partition: The partition in the parent cluster + idle_resources: Dictionary mapping idle resources to number of nodes """ - resource_allocation = self.count_idle_resources(cluster, partition) - output_width = 30 header = f'Cluster: {cluster}, Partition: {partition}' - unit = self.cluster_types.get(cluster, self.default_type) + unit = self.cluster_types[cluster] print(header) print('=' * output_width) - for idle, nodes in sorted(resource_allocation.items()): + for idle, nodes in sorted(idle_resources.items()): print(f'{nodes:4d} nodes w/ {idle:3d} idle {unit}') - if not resource_allocation: + if not idle_resources: print(' No idle resources') print('') def app_logic(self, args: Namespace) -> None: - """Logic to evaluate when executing the application + """Logic to evaluate when executing the application. Args: - args: Parsed command line arguments + args: Parsed command line arguments. 
""" for cluster in self.get_cluster_list(args): partitions_to_print = args.partition or Slurm.get_partition_names(cluster) for partition in partitions_to_print: - self.print_partition_summary(cluster, partition) + idle_resources = self.count_idle_resources(cluster, partition) + self.print_partition_summary(cluster, partition, idle_resources) diff --git a/apps/crc_interactive.py b/apps/crc_interactive.py index e280bb6..84b382a 100755 --- a/apps/crc_interactive.py +++ b/apps/crc_interactive.py @@ -1,4 +1,4 @@ -"""A simple wrapper around the Slurm ``srun`` command +"""A simple wrapper around the Slurm `srun` command. The application launches users into an interactive Slurm session on a user-selected cluster and (if specified) partition. Dedicated command line @@ -9,7 +9,8 @@ to be manually added (or removed) by updating the application CLI arguments. """ -from argparse import Namespace, ArgumentTypeError +from argparse import ArgumentTypeError, Namespace +from collections import defaultdict from datetime import time from os import system @@ -21,7 +22,7 @@ class CrcInteractive(BaseParser): """Launch an interactive Slurm session.""" min_mpi_nodes = 2 # Minimum limit on requested MPI nodes - min_mpi_cores = {'mpi': 48, 'opa-high-mem': 28} + min_mpi_cores = defaultdict(lambda: 28, {'mpi': 48, 'opa-high-mem': 28}) # Minimum cores per MPI partition min_time = 1 # Minimum limit on requested time in hours max_time = 12 # Maximum limit on requested time in hours @@ -32,6 +33,17 @@ class CrcInteractive(BaseParser): default_mem = 1 # Default memory in GB default_gpus = 0 # Default number of GPUs + # Cluster names to make available from the command line + # Maps cluster name to single character abbreviation used in the CLI + clusters = { + 'smp': 's', + 'gpu': 'g', + 'mpi': 'm', + 'invest': 'i', + 'htc': 'd', + 'teach': 'e' + } + def __init__(self) -> None: """Define arguments for the command line interface.""" @@ -42,13 +54,9 @@ def __init__(self) -> None: # Arguments for
specifying what cluster to start an interactive session on cluster_args = self.add_argument_group('Cluster Arguments') - cluster_args.add_argument('-s', '--smp', action='store_true', help='launch a session on the smp cluster') - cluster_args.add_argument('-g', '--gpu', action='store_true', help='launch a session on the gpu cluster') - cluster_args.add_argument('-m', '--mpi', action='store_true', help='launch a session on the mpi cluster') - cluster_args.add_argument('-i', '--invest', action='store_true', help='launch a session on the invest cluster') - cluster_args.add_argument('-d', '--htc', action='store_true', help='launch a session on the htc cluster') - cluster_args.add_argument('-e', '--teach', action='store_true', help='launch a session on the teach cluster') cluster_args.add_argument('-p', '--partition', help='run the session on a specific partition') + for cluster, abbrev in self.clusters.items(): + cluster_args.add_argument(f'-{abbrev}', f'--{cluster}', action='store_true', help=f'launch a session on the {cluster} cluster') # Arguments for requesting additional hardware resources resource_args = self.add_argument_group('Arguments for Increased Resources') @@ -79,7 +87,7 @@ def __init__(self) -> None: @staticmethod def parse_time(time_str: str) -> time: - """Parse a string representation of time in 'HH:MM:SS' format and return a time object + """Parse a string representation of time in 'HH:MM:SS' format and return a time object. Args: time_str: A string representing time in 'HH:MM:SS' format. 
@@ -101,40 +109,45 @@ def parse_time(time_str: str) -> time: except Exception: raise ArgumentTypeError(f'Could not parse time value {time_str}') - def _validate_arguments(self, args: Namespace) -> None: - """Exit the application if command line arguments are invalid + def parse_args(self, args=None, namespace=None) -> Namespace: + """Parse command line arguments.""" - Args: - args: Parsed commandline arguments - """ + args = super().parse_args(args, namespace) + + # Set defaults that need to be determined dynamically + if not args.num_gpus: + args.num_gpus = 1 if args.gpu else 0 # Check wall time is between limits, enable both %H:%M format and integer hours check_time = args.time.hour + args.time.minute / 60 + args.time.second / 3600 - if not self.min_time <= check_time <= self.max_time: - self.error(f'{check_time} is not in {self.min_time} <= time <= {self.max_time}... exiting') + self.error(f'Requested time must be between {self.min_time} and {self.max_time}.') # Check the minimum number of nodes are requested for mpi if args.mpi and args.num_nodes < self.min_mpi_nodes: - self.error(f'You must use at least {self.min_mpi_nodes} nodes when using the MPI cluster') + self.error(f'You must use at least {self.min_mpi_nodes} nodes when using the MPI cluster.') # Check the minimum number of cores are requested for mpi - if args.mpi and args.num_cores < self.min_mpi_cores.get(args.partition, self.default_mpi_cores): - self.error(f'You must request at least {self.min_mpi_cores.get(args.partition, self.default_mpi_cores)} ' - f'cores per node when using the MPI cluster {args.partition} partition') + min_cores = self.min_mpi_cores[args.partition] + if args.mpi and args.num_cores < min_cores: + self.error( + f'You must request at least {min_cores} cores per node when using the {args.partition} partition on the MPI cluster.' 
+ ) # Check a partition is specified if the user is requesting invest if args.invest and not args.partition: - self.error('You must specify a partition when using the Investor cluster') + self.error('You must specify a partition when using the invest cluster.') + + return args def create_srun_command(self, args: Namespace) -> str: - """Create an ``srun`` command based on parsed command line arguments + """Create an `srun` command based on parsed command line arguments. Args: - args: A dictionary of parsed command line parsed_args + args: A dictionary of parsed command line parsed_args. Return: - The equivalent ``srun`` command as a string + The equivalent `srun` command as a string. """ # Map arguments from the parent application to equivalent srun arguments @@ -161,26 +174,26 @@ def create_srun_command(self, args: Namespace) -> str: if (args.gpu or args.invest) and args.num_gpus: srun_args += ' ' + f'--gres=gpu:{args.num_gpus}' - cluster_to_run = next(cluster for cluster in Slurm.get_cluster_names() if getattr(args, cluster)) + try: + cluster_to_run = next(cluster for cluster in self.clusters if getattr(args, cluster)) + + except StopIteration: + raise RuntimeError('Please specify which cluster to run on.') + return f'srun -M {cluster_to_run} {srun_args} --pty bash' def app_logic(self, args: Namespace) -> None: - """Logic to evaluate when executing the application + """Logic to evaluate when executing the application. Args: - args: Parsed command line arguments + args: Parsed command line arguments. 
""" if not any(getattr(args, cluster, False) for cluster in Slurm.get_cluster_names()): self.print_help() self.exit() - # Set defaults that need to be determined dynamically - if not args.num_gpus: - args.num_gpus = 1 if args.gpu else 0 - # Create the slurm command - self._validate_arguments(args) srun_command = self.create_srun_command(args) if args.print_command: diff --git a/apps/utils/__init__.py b/apps/utils/__init__.py index f73126c..bce24b6 100644 --- a/apps/utils/__init__.py +++ b/apps/utils/__init__.py @@ -1 +1,3 @@ """The ``utils`` module defines helper utilities for building commandline system tools.""" + +from .system_info import Shell, Slurm diff --git a/tests/test_crc_idle.py b/tests/test_crc_idle.py index fb518bd..e64b3c7 100644 --- a/tests/test_crc_idle.py +++ b/tests/test_crc_idle.py @@ -1,6 +1,8 @@ -"""Tests for the ``crc-idle`` application""" +"""Tests for the `crc-idle` application""" -from unittest import TestCase, skip +from argparse import Namespace +from unittest import TestCase +from unittest.mock import call, Mock, patch from apps.crc_idle import CrcIdle from apps.utils.system_info import Slurm @@ -33,9 +35,9 @@ def test_cluster_parsing(self) -> None: self.assertFalse(args.htc) self.assertFalse(args.gpu) - @skip('Requires slurm utilities') + @patch('apps.utils.Slurm.get_cluster_names', new=lambda: tuple(CrcIdle.cluster_types.keys())) def test_clusters_default_to_false(self) -> None: - """Test all cluster flags default to a ``False`` value""" + """Test all cluster flags default to a `False` value""" app = CrcIdle() args, unknown_args = app.parse_known_args([]) @@ -45,25 +47,101 @@ def test_clusters_default_to_false(self) -> None: self.assertFalse(getattr(args, cluster)) -class ClusterList(TestCase): - """Test the selection of what clusters to print""" +class GetClusterList(TestCase): + """Test the selection of which clusters to print""" - @skip('Requires slurm utilities') - def test_defaults_all_clusters(self) -> None: - """Test all 
clusters are returned if none are specified in the parsed arguments""" + def test_get_cluster_list_no_arguments(self) -> None: + """Test returned values when no clusters are specified.""" app = CrcIdle() - args, unknown_args = app.parse_known_args(['-p', 'partition1']) - self.assertFalse(unknown_args) + args = Namespace(smp=False, gpu=False, mpi=False, invest=False, htc=False, partition=None) + result = app.get_cluster_list(args) - returned_clusters = app.get_cluster_list(args) - self.assertCountEqual(Slurm.get_cluster_names(), returned_clusters) + expected = tuple(app.cluster_types.keys()) + self.assertEqual(expected, result) + + def test_get_cluster_list_with_cluster_arguments(self) -> None: + """Test returned values when select clusters are specified.""" - def test_returns_arg_values(self) -> None: - """Test returned cluster names match the clusters specified in the parsed arguments""" app = CrcIdle() - args, unknown_args = app.parse_known_args(['-s', '--mpi']) - self.assertFalse(unknown_args) + args = Namespace(smp=True, gpu=False, mpi=True, invest=False, htc=False, partition=None) + result = app.get_cluster_list(args) + + self.assertEqual(('smp', 'mpi'), result) + + +class CountIdleResources(TestCase): + """Test the counting of idle CPU/GPU resources""" + + @patch('apps.utils.Shell.run_command') + def test_count_idle_cpu_resources(self, mock_run_command: Mock) -> None: + """Test counting idle CPU resources.""" + + cluster = 'smp' + partition = 'default' + mock_run_command.return_value = "node1,2/4/0/4\nnode2,3/2/0/3" + + app = CrcIdle() + result = app.count_idle_resources(cluster, partition) + + expected = {4: 1, 2: 1} + self.assertEqual(expected, result) + + @patch('apps.utils.Shell.run_command') + def test_count_idle_gpu_resources(self, mock_run_command: Mock) -> None: + """Test counting idle GPU resources.""" - returned_clusters = app.get_cluster_list(args) - self.assertCountEqual(['smp', 'mpi'], returned_clusters) + cluster = 'gpu' + partition = 'default'
+ mock_run_command.return_value = "node1_4_2_idle\nnode2_4_4_drain" + + app = CrcIdle() + result = app.count_idle_resources(cluster, partition) + expected = {2: 1, 0: 1} + self.assertEqual(expected, result) + + +class PrintPartitionSummary(TestCase): + """Test the printing of a partition summary""" + + @patch('builtins.print') + def test_print_partition_summary_with_idle_resources(self, mock_print: Mock) -> None: + """Test printing a summary with idle resources.""" + + cluster = 'smp' + partition = 'default' + idle_resources = {2: 3, 4: 1} # 3 nodes with 2 idle resources, 1 node with 4 idle resources + + app = CrcIdle() + app.print_partition_summary(cluster, partition, idle_resources) + + mock_print.assert_has_calls([ + call(f'Cluster: {cluster}, Partition: {partition}'), + call('=' * 30), + call(' 3 nodes w/ 2 idle cores'), + call(' 1 nodes w/ 4 idle cores'), + call('') + ], any_order=False) + + @patch('builtins.print') + def test_print_partition_summary_no_idle_resources(self, mock_print: Mock) -> None: + """Test printing a summary when no idle resources are available.""" + + cluster = 'smp' + partition = 'default' + idle_resources = dict() # No idle resources + + app = CrcIdle() + app.print_partition_summary(cluster, partition, idle_resources) + + mock_print.assert_any_call(f'Cluster: {cluster}, Partition: {partition}') + mock_print.assert_any_call('=' * 30) + mock_print.assert_any_call(' No idle resources') + mock_print.assert_any_call('') + + mock_print.assert_has_calls([ + call(f'Cluster: {cluster}, Partition: {partition}'), + call('=====' * 6), + call(' No idle resources'), + call('') + ], any_order=False) diff --git a/tests/test_crc_interactive.py b/tests/test_crc_interactive.py index 2194d85..787d726 100644 --- a/tests/test_crc_interactive.py +++ b/tests/test_crc_interactive.py @@ -1,7 +1,6 @@ -"""Tests for the ``crc-interactive`` application.""" +"""Tests for the `crc-interactive` application.""" -import unittest -from argparse import ArgumentTypeError 
+from argparse import ArgumentTypeError, Namespace from datetime import time from unittest import TestCase @@ -9,24 +8,72 @@ class ArgumentParsing(TestCase): - """Test the parsing of command line arguments""" + """Test the parsing of command line arguments.""" + + def setUp(self) -> None: + """Set up the test environment.""" + + self.app = CrcInteractive() def test_args_match_class_settings(self) -> None: - """Test parsed args default to the values defined as class settings""" + """Test parsed args default to the values defined as class settings.""" - args, _ = CrcInteractive().parse_known_args(['--mpi']) + args = self.app.parse_args(['--smp']) self.assertEqual(CrcInteractive.default_time, args.time) self.assertEqual(CrcInteractive.default_cores, args.num_cores) self.assertEqual(CrcInteractive.default_mem, args.mem) self.assertEqual(CrcInteractive.default_gpus, args.num_gpus) + def test_default_gpus(self) -> None: + """Test the default number of GPUs is determined dynamically by cluster.""" + + smp_args = self.app.parse_args(['--smp']) + self.assertEqual(0, smp_args.num_gpus) + + gpu_args = self.app.parse_args(['--gpu']) + self.assertEqual(1, gpu_args.num_gpus) + + def test_time_argument_out_of_range(self) -> None: + """Test invalid time arguments raise an error.""" + + # Time too short + with self.assertRaisesRegex(SystemExit, 'Requested time must be', msg='Minimum MPI time not enforced.'): + self.app.parse_args(['--smp', '--time', '00:00:30']) + + # Time too long + with self.assertRaisesRegex(SystemExit, 'Requested time must be', msg='Maximum MPI time not enforced.'): + self.app.parse_args(['--smp', '--time', '00:50:00']) + + def test_mpi_minimums(self) -> None: + """Test minimum usage limits on MPI.""" + + min_nodes = CrcInteractive.min_mpi_nodes + min_cores = CrcInteractive.min_mpi_cores.default_factory() + + # Minimum values should parse without any errors + self.app.parse_args(['--mpi', '--num-nodes', str(min_nodes), '--num-cores', str(min_cores)]) -class 
TestParseTime(unittest.TestCase): - """Test the parsing of time strings""" + nodes_err_message = 'You must use at least .* nodes' + with self.assertRaisesRegex(SystemExit, nodes_err_message, msg='Minimum nodes not enforced.'): + self.app.parse_args(['--mpi', '--num-nodes', str(min_nodes - 1), '--num-cores', str(min_cores)]) + + cores_error_message = 'You must request at least .* cores' + with self.assertRaisesRegex(SystemExit, cores_error_message, msg='Minimum cores not enforced.'): + self.app.parse_args(['--mpi', '--num-nodes', str(min_nodes), '--num-cores', str(min_cores - 1)]) + + def test_invest_partition_required(self) -> None: + """Test a partition must be specified for the invest cluster.""" + + with self.assertRaisesRegex(SystemExit, 'You must specify a partition'): + self.app.parse_args(['--invest']) + + +class TestParseTime(TestCase): + """Test the parsing of time strings.""" def test_valid_time(self) -> None: - """Test the parsing of valid time strings""" + """Test the parsing of valid time strings.""" self.assertEqual(CrcInteractive.parse_time('1'), time(1, 0, 0)) self.assertEqual(CrcInteractive.parse_time('01'), time(1, 0, 0)) @@ -34,7 +81,7 @@ def test_valid_time(self) -> None: self.assertEqual(CrcInteractive.parse_time('12:34:56'), time(12, 34, 56)) def test_invalid_time_format(self) -> None: - """Test an errr is raised for invalid time formatting""" + """Test an error is raised for invalid time formatting.""" # Test with invalid time formats with self.assertRaises(ArgumentTypeError, msg='Error not raised for invalid delimiter'): @@ -47,7 +94,7 @@ def test_invalid_time_format(self) -> None: CrcInteractive.parse_time('12:34:56:78') def test_invalid_time_value(self) -> None: - """Test an errr is raised for invalid time values""" + """Test an error is raised for invalid time values.""" with self.assertRaises(ArgumentTypeError, msg='Error not raised for invalid hour'): CrcInteractive.parse_time('25:00:00') @@ -59,7 +106,171 @@ def
test_invalid_time_value(self) -> None: CrcInteractive.parse_time('12:34:60') def test_empty_string(self) -> None: - """Test an error is raised for empty strings""" + """Test an error is raised for empty strings.""" with self.assertRaises(ArgumentTypeError): CrcInteractive.parse_time('') + + +class CreateSrunCommand(TestCase): + """Test the creation of `srun` commands.""" + + def setUp(self) -> None: + """Set up the test environment.""" + + self.app = CrcInteractive() + + def test_gpu_cluster(self) -> None: + """Test generating an `srun` command for the `gpu` cluster.""" + + args = Namespace( + print_command=False, + smp=False, + gpu=True, + mpi=False, + invest=False, + htc=False, + teach=False, + partition=None, + mem=2, + time=time(2, 0), + num_nodes=2, + num_cores=4, + num_gpus=1, + account=None, + reservation=None, + license=None, + feature=None, + openmp=False + ) + + expected_command = ( + 'srun -M gpu --export=ALL ' + '--nodes=2 --time=02:00:00 --mem=2g --ntasks-per-node=4 --gres=gpu:1 --pty bash' + ) + + actual_command = self.app.create_srun_command(args) + self.assertEqual(expected_command, actual_command) + + def test_mpi_cluster(self) -> None: + """Test generating an `srun` command for the `mpi` cluster.""" + + args = Namespace( + print_command=False, + smp=False, + gpu=False, + mpi=True, + invest=False, + htc=False, + teach=False, + partition='mpi', + mem=4, + time=time(3, 0), + num_nodes=3, + num_cores=48, + num_gpus=0, + account=None, + reservation=None, + license=None, + feature=None, + openmp=False + ) + + expected_command = ( + 'srun -M mpi --export=ALL --partition=mpi ' + '--nodes=3 --time=03:00:00 --mem=4g --ntasks-per-node=48 --pty bash' + ) + + actual_command = self.app.create_srun_command(args) + self.assertEqual(expected_command, actual_command) + + def test_invest_command(self) -> None: + """Test srun command for the invest cluster.""" + + args = Namespace( + print_command=False, + smp=False, + gpu=False, + mpi=False, + invest=True, +
htc=False, + teach=False, + partition='invest-partition', + mem=2, + time=time(1, 0), + num_nodes=1, + num_cores=4, + num_gpus=0, + account=None, + reservation=None, + license=None, + feature=None, + openmp=False + ) + + expected_command = ( + 'srun -M invest --export=ALL --partition=invest-partition ' + '--nodes=1 --time=01:00:00 --mem=2g --ntasks-per-node=4 --pty bash' + ) + + actual_command = self.app.create_srun_command(args) + self.assertEqual(expected_command, actual_command) + + def test_partition_specific_cores(self) -> None: + """Test srun command with partition-specific core requirements.""" + + args = Namespace( + print_command=False, + smp=False, + gpu=False, + mpi=True, + invest=False, + htc=False, + teach=False, + partition='opa-high-mem', + mem=8, + time=time(2, 0), + num_nodes=2, + num_cores=28, + num_gpus=0, + account=None, + reservation=None, + license=None, + feature=None, + openmp=False + ) + + expected_command = ( + 'srun -M mpi --export=ALL --partition=opa-high-mem ' + '--nodes=2 --time=02:00:00 --mem=8g --ntasks-per-node=28 --pty bash' + ) + + actual_command = self.app.create_srun_command(args) + self.assertEqual(expected_command, actual_command) + + def test_no_cluster_specified(self) -> None: + """Test an error is raised when no cluster is specified.""" + + args = Namespace( + print_command=False, + smp=False, + gpu=False, + mpi=False, + invest=False, + htc=False, + teach=False, + partition=None, + mem=1, + time=time(1, 0), + num_nodes=1, + num_cores=4, + num_gpus=0, + account=None, + reservation=None, + license=None, + feature=None, + openmp=True + ) + + with self.assertRaises(RuntimeError): + self.app.create_srun_command(args)