diff --git a/shinigami/cli.py b/shinigami/cli.py index bc3c9db..6295057 100644 --- a/shinigami/cli.py +++ b/shinigami/cli.py @@ -59,7 +59,7 @@ def __init__(self) -> None: scan = subparsers.add_parser('scan', parents=[common], help='terminate processes on one or more clusters') scan.set_defaults(callable=Application.scan) scan.add_argument('-c', '--clusters', nargs='+', required=True, help='cluster names to scan') - scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore given nodes') + scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore the given nodes') scan.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan') # Subparser for the `Application.terminate` method diff --git a/shinigami/utils.py b/shinigami/utils.py index dd15b29..afac877 100755 --- a/shinigami/utils.py +++ b/shinigami/utils.py @@ -14,7 +14,11 @@ def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, int]]]) -> bool: - """Return whether an ID is in a list of ID values + """Return whether an ID is in a list of ID value definitions + + The `whitelist` of ID values can contain a mix of integers and tuples + of integer ranges. For example, [0, 1, (2, 9), 10] includes all IDs from + zero through ten. Args: id_value: The ID value to check @@ -34,12 +38,12 @@ def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, i return False -def get_nodes(cluster: str, ignore_substring: Collection[str]) -> set: +def get_nodes(cluster: str, ignore_nodes: Collection[str] = tuple()) -> set: """Return a set of nodes included in a given Slurm cluster Args: cluster: Name of the cluster to fetch nodes for - ignore_substring: Do not return nodes containing any of the given substrings + ignore_nodes: Do not return nodes included in the provided list Returns: A set of cluster names @@ -52,8 +56,7 @@ def get_nodes(cluster: str, ignore_substring: Collection[str]) -> set: raise RuntimeError(stderr) all_nodes = stdout.decode().strip().split('\n') - is_valid = lambda node: not any(substring in node for substring in ignore_substring) - return set(filter(is_valid, all_nodes)) + return set(node for node in all_nodes if node not in ignore_nodes) async def terminate_errant_processes( diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d9866e1 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,71 @@ +"""Tests for the `utils` module.""" + +from unittest import TestCase + +from shinigami import utils +from shinigami.utils import id_in_whitelist + +# For information on resources defined in the testing environment, +# see https://github.com/pitt-crc/Slurm-Test-Environment/ +TEST_CLUSTER = 'development' +TEST_NODES = set(f'c{i}' for i in range(1, 11)) + + +class Whitelisting(TestCase): + """Tests for the ``id_in_whitelist`` function""" + + def test_empty_whitelist(self) -> None: + """Test the return value is ``False`` for all ID values when the whitelist is empty""" + + self.assertFalse(id_in_whitelist(0, [])) + self.assertFalse(id_in_whitelist(123, [])) + + def test_whitelisted_by_id(self) -> None: + """Test return values for a whitelist of explicit ID values""" + + whitelist = (123, 456, 789) + self.assertTrue(id_in_whitelist(456, whitelist)) + self.assertFalse(id_in_whitelist(0, whitelist)) + + def test_whitelisted_by_id_range(self) -> None: + """Test return values for a whitelist of ID ranges""" + + whitelist = (0, 1, 2, (100, 300)) + self.assertTrue(id_in_whitelist(123, whitelist)) + self.assertFalse(id_in_whitelist(301, whitelist)) + + +class GetNodes(TestCase): + """Tests for the ``get_nodes`` function""" + + def test_nodes_match_test_env(self) -> None: + """Test the returned node list matches values defined in the testing environment""" + + self.assertCountEqual(TEST_NODES, utils.get_nodes(TEST_CLUSTER)) + + def test_ignore_substring(self) -> None: + """Test nodes with the included substring are ignored""" + + exclude_node = 'c1' + + # Create a copy of the test nodes with on element missing + expected_nodes = TEST_NODES.copy() + expected_nodes.remove(exclude_node) + + returned_nodes = utils.get_nodes(TEST_CLUSTER, [exclude_node]) + self.assertCountEqual(expected_nodes, returned_nodes) + + def test_missing_cluster(self) -> None: + """Test an error is raised for a cluster name that does not exist""" + + with self.assertRaisesRegex(RuntimeError, 'No cluster \'fake_cluster\''): + utils.get_nodes('fake_cluster') + + def test_missing_node(self) -> None: + """Test no error is raised when an excuded node des not exist""" + + excluded_nodes = {'c1', 'fake_node'} + expected_nodes = TEST_NODES - excluded_nodes + + returned_nodes = utils.get_nodes(TEST_CLUSTER, excluded_nodes) + self.assertCountEqual(expected_nodes, returned_nodes) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/utils/test_get_nodes.py b/tests/utils/test_get_nodes.py deleted file mode 100644 index 74c31fb..0000000 --- a/tests/utils/test_get_nodes.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Tests for the ``utils.get_nodes`` function""" - -from unittest import TestCase - -from shinigami import utils - -# For information on resources defined in the testing environment, -# see https://github.com/pitt-crc/Slurm-Test-Environment/ -TEST_CLUSTER = 'development' -TEST_NODES = set(f'c{i}' for i in range(1, 11)) - - -class NodesMatchTestEnvironment(TestCase): - """Test the returned node list matches values defined in the testing environment""" - - def test_returned_nodes(self) -> None: - """Test returned nodes match hose defined in the slurm test env""" - - self.assertSequenceEqual(TEST_NODES, utils.get_nodes(TEST_CLUSTER, [])) diff --git a/tests/utils/test_id_in_whitelist.py b/tests/utils/test_id_in_whitelist.py deleted file mode 100644 index b762d6a..0000000 --- a/tests/utils/test_id_in_whitelist.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Tests for the ``utils.id_in_whitelist`` function""" - -from unittest import TestCase - -from shinigami.utils import id_in_whitelist - - -class Whitelisting(TestCase): - """Test ID values are correctly whitelisted""" - - def test_empty_whitelist(self) -> None: - """Test the return value is ``False`` for all ID values when the whitelist is empty""" - - self.assertFalse(id_in_whitelist(0, [])) - self.assertFalse(id_in_whitelist(123, [])) - - def test_whitelisted_by_id(self) -> None: - """Test return values for a whitelist of explicit ID values""" - - whitelist = (123, 456, 789) - self.assertTrue(id_in_whitelist(456, whitelist)) - self.assertFalse(id_in_whitelist(0, whitelist)) - - def test_whitelisted_by_id_range(self) -> None: - """Test return values for a whitelist of ID ranges""" - - whitelist = (0, 1, 2, (100, 300)) - self.assertTrue(id_in_whitelist(123, whitelist)) - self.assertFalse(id_in_whitelist(301, whitelist))