diff --git a/shinigami/cli.py b/shinigami/cli.py index f6b5371..7929d70 100644 --- a/shinigami/cli.py +++ b/shinigami/cli.py @@ -5,7 +5,7 @@ import logging import logging.config import sys -from argparse import ArgumentParser +from argparse import ArgumentParser, RawTextHelpFormatter from json import loads from typing import List, Collection, Union @@ -14,23 +14,7 @@ from . import __version__, utils -class BaseParser(ArgumentParser): - """Custom argument parser that prints help text on error""" - - def error(self, message: str) -> None: - """Prints a usage message and exit - - Args: - message: The usage message - """ - - if len(sys.argv) == 1: - self.print_help() - - raise SystemExit(message) - - -class Parser(BaseParser): +class Parser(ArgumentParser): """Defines the command-line interface and parses command-line arguments""" def __init__(self) -> None: @@ -38,33 +22,64 @@ def __init__(self) -> None: # Configure the top level parser super().__init__(prog='shinigami', description='Scan Slurm compute nodes and terminate orphan processes.') - subparsers = self.add_subparsers(required=True, parser_class=BaseParser) + subparsers = self.add_subparsers(required=True, parser_class=ArgumentParser) self.add_argument('--version', action='version', version=__version__) - # This parser defines reusable arguments and is not exposed to the user + # The `common` parser holds reusable argument definitions common = ArgumentParser(add_help=False) - ssh_group = common.add_argument_group('ssh options') - ssh_group.add_argument('-m', '--max-concurrent', type=int, default=1, help='maximum concurrent SSH connections') - ssh_group.add_argument('-t', '--ssh-timeout', type=int, default=120, help='SSH connection timeout in seconds') + ssh_group.add_argument('-m', dest='max_concurrent', type=int, default=1, help='maximum concurrent SSH connections (Default: 1)') + ssh_group.add_argument('-t', dest='ssh_timeout', type=int, default=120, help='SSH connection timeout in seconds (Default: 120)') 
debug_group = common.add_argument_group('debugging options') debug_group.add_argument('--debug', action='store_true', help='run the application in debug mode') - debug_group.add_argument('-v', action='count', dest='verbosity', default=0, - help='set verbosity to warning (-v), info (-vv), or debug (-vvv)') + debug_group.add_argument('-v', action='count', dest='verbosity', default=0, help='set verbosity to warning (-v), info (-vv), or debug (-vvv)') # Subparser for the `Application.scan` method - scan = subparsers.add_parser('scan', parents=[common], help='terminate processes on one or more clusters') + scan = subparsers.add_parser( + 'scan', parents=[common], formatter_class=RawTextHelpFormatter, + help='terminate processes on one or more clusters', + description=( + "The `scan` function automatically terminates orphaned processes on all compute nodes in a Slurm cluster.\n" + "It is provided as a shorthand alternative to calling the `terminate` command with manually defined node names.\n\n" + "Slurm nodes are identified using the slurm installation on the current machine.\n" + "User IDs can be specified individually (e.g. `-u 1000 1001 1002 1003`) or as ranges (e.g. 
`-u 1000 [1001,1003]`).")) + scan.set_defaults(callable=Application.scan) - scan.add_argument('-c', '--clusters', nargs='+', required=True, help='cluster names to scan') - scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore the given nodes') - scan.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan') + scan_group = scan.add_argument_group('scanning options') + scan_group.add_argument('-c', dest='clusters', metavar='CLUS', nargs='+', required=True, help='Slurm cluster name(s) to scan') + scan_group.add_argument('-i', dest='ignore_nodes', metavar='NODE', nargs='*', default=[], help='ignore the given node(s)') + scan_group.add_argument('-u', dest='uid_whitelist', metavar='UID', nargs='+', type=loads, default=[0], help='only terminate processes owned by the given user IDs') # Subparser for the `Application.terminate` method - terminate = subparsers.add_parser('terminate', parents=[common], help='terminate processes on a single node') + terminate = subparsers.add_parser( + 'terminate', parents=[common], formatter_class=RawTextHelpFormatter, + help='terminate processes on one or more compute nodes', + description=( + "Automatically terminate orphaned processes on one or more Slurm compute nodes.\n\n" + "Processes are only terminated under the following conditions:\n" + f" 1. The process belongs to a process tree parented by init (PID {utils.INIT_PROCESS_ID})\n" + " 2. The associated user ID is in the given UID whitelist\n" + " 3. The user is not running any Slurm jobs on the parent machine\n\n" + "User IDs can be specified individually (e.g. `-u 1000 1001 1002 1003`) or as ranges (e.g. 
`-u 1000 [1001,1003]`).")) + terminate.set_defaults(callable=Application.terminate) - terminate.add_argument('-n', '--nodes', nargs='+', required=True, help='the DNS name of the node to terminate') - terminate.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan') + terminate_group = terminate.add_argument_group('termination options') + terminate_group.add_argument('-n', dest='nodes', metavar='NODE', nargs='+', required=True, help='the DNS name(s) of the node(s) to terminate') + terminate_group.add_argument('-u', dest='uid_whitelist', metavar='UID', nargs='+', type=loads, default=[0], help='only terminate processes owned by the given user IDs') + + def error(self, message: str) -> None: + """Print a usage message and exit the application + + Args: + message: The usage message + """ + + if len(sys.argv) == 1: + self.print_help() + raise SystemExit + + super().error(message) class Application: @@ -79,12 +94,23 @@ def _configure_logging(verbosity: int) -> None: - file_logger: For logging to the log file only - root: For logging to the console and log file + Console verbosity levels are defined as follows: + - 0: ERROR + - 1: WARNING + - 2: INFO + - 3: DEBUG + - Any other value: DEBUG + Args: - verbosity: The console verbosity defined as the count passed to the commandline + verbosity: The console verbosity level """ - verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG} - console_log_level = verbosity_to_log_level.get(verbosity, logging.DEBUG) + console_log_level = { + 0: logging.ERROR, + 1: logging.WARNING, + 2: logging.INFO, + 3: logging.DEBUG + }.get(verbosity, logging.DEBUG) logging.config.dictConfig({ 'version': 1, @@ -122,7 +148,7 @@ async def scan( clusters: Collection[str], ignore_nodes: Collection[str], uid_whitelist: Collection[Union[int, List[int]]], - max_concurrent: asyncio.Semaphore, + max_concurrent: int, ssh_timeout: int, debug: bool ) -> None: @@ -147,7 +173,7 
@@ async def scan( async def terminate( nodes: Collection[str], uid_whitelist: Collection[Union[int, List[int]]], - max_concurrent: asyncio.Semaphore, + max_concurrent: int, ssh_timeout: int, debug: bool ) -> None: @@ -182,7 +208,7 @@ async def terminate( @classmethod def execute(cls, arg_list: List[str] = None) -> None: - """Parse command-line arguments and execute the application + """Parse command line arguments and execute the application Args: arg_list: Optionally parse the given arguments instead of the command line @@ -192,7 +218,7 @@ def execute(cls, arg_list: List[str] = None) -> None: cls._configure_logging(args.verbosity) try: - # Extract the subset of arguments that are valid for the function `args.callable` + # Extract the subset of arguments that are valid for the `args.callable` function valid_params = inspect.signature(args.callable).parameters valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params} asyncio.run(args.callable(**valid_arguments)) diff --git a/shinigami/utils.py b/shinigami/utils.py index d9c982d..1f843a5 100755 --- a/shinigami/utils.py +++ b/shinigami/utils.py @@ -14,34 +14,6 @@ # architecture, but 1 is an almost universal default INIT_PROCESS_ID = 1 -# Custom type hints -Whitelist = Collection[Union[int, Tuple[int, int]]] - - -def _id_in_whitelist(id_value: int, whitelist: Whitelist) -> bool: - """Return whether an ID is in a list of ID value definitions - - The `whitelist` of ID values can contain a mix of integers and tuples - of integer ranges. For example, [0, 1, (2, 9), 10] includes all IDs from - zero through ten. 
- - Args: - id_value: The ID value to check - whitelist: A collection of ID values and ID ranges - - Returns: - Whether the ID is in the whitelist - """ - - for id_def in whitelist: - if hasattr(id_def, '__getitem__') and (id_def[0] <= id_value <= id_def[1]): - return True - - elif id_value == id_def: - return True - - return False - def get_nodes(cluster: str, ignore_nodes: Collection[str] = tuple()) -> set: """Return a set of nodes included in a given Slurm cluster @@ -100,7 +72,7 @@ def include_orphaned_processes(df: pd.DataFrame) -> pd.DataFrame: return df[df['PPID'] == INIT_PROCESS_ID] -def include_user_whitelist(df: pd.DataFrame, uid_whitelist: Whitelist) -> pd.DataFrame: +def include_user_whitelist(df: pd.DataFrame, uid_whitelist: Collection[Union[int, Tuple[int, int]]]) -> pd.DataFrame: """Filter a DataFrame to only include a subset of user IDs Given a DataFrame with system process data, return a subset of the data @@ -116,8 +88,16 @@ def include_user_whitelist(df: pd.DataFrame, uid_whitelist: Whitelist) -> pd.Dat A copy of the given DataFrame """ - whitelist_index = df['UID'].apply(_id_in_whitelist, whitelist=uid_whitelist) - return df[whitelist_index] + whitelisted_uid_values = [] + for elt in uid_whitelist: + if isinstance(elt, int): + whitelisted_uid_values.append(elt) + + else: + umin, umax = elt + whitelisted_uid_values.extend(range(umin, umax + 1)) # UID ranges are inclusive of both endpoints + + return df[df['UID'].isin(whitelisted_uid_values)] def exclude_active_slurm_users(df: pd.DataFrame) -> pd.DataFrame: @@ -162,7 +142,7 @@ async def terminate_errant_processes( logging.info(f'[{node}] Scanning for processes') process_df = await get_remote_processes(conn) - # Filter them by various whitelist/blacklist criteria + # Filter process data by various whitelist/blacklist criteria process_df = include_orphaned_processes(process_df) process_df = include_user_whitelist(process_df, uid_whitelist) process_df = exclude_active_slurm_users(process_df) diff --git a/tests/cli/test_base_parser.py 
b/tests/cli/test_base_parser.py deleted file mode 100644 index 39239c7..0000000 --- a/tests/cli/test_base_parser.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Tests for the `cli.BaseParser` class""" - -from unittest import TestCase - -from shinigami.cli import BaseParser - - -class BaseParsing(TestCase): - """Test custom parsing logic encapsulated by the `BaseParser` class""" - - def test_errors_raise_system_exit(self) -> None: - """Test error messages are raised as `SystemExit` instances""" - - with self.assertRaises(SystemExit): - BaseParser().error("This is an error message") - - def test_errors_include_message(self) -> None: - """Test parser messages are included as error messages""" - - msg = "This is an error message" - with self.assertRaisesRegex(SystemExit, msg): - BaseParser().error(msg) diff --git a/tests/cli/test_parser.py b/tests/cli/test_parser.py index e8f6112..a06e2c8 100644 --- a/tests/cli/test_parser.py +++ b/tests/cli/test_parser.py @@ -5,6 +5,16 @@ from shinigami.cli import Parser +class ErrorHandling(TestCase): + """Test custom parsing logic encapsulated by the `Parser` class""" + + def test_errors_raise_system_exit(self) -> None: + """Test error messages are raised as `SystemExit` instances""" + + with self.assertRaises(SystemExit): + Parser().error("This is an error message") + + class ScanSubParser(TestCase): """Test the behavior of the `scan` subparser"""