From 4567201d861c2080485241684de6fbe08673d0c2 Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Thu, 5 Oct 2023 12:03:19 -0400 Subject: [PATCH 1/5] Fix termination signal (#84) --- shinigami/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/shinigami/utils.py b/shinigami/utils.py index cce3166..d7a84b6 100755 --- a/shinigami/utils.py +++ b/shinigami/utils.py @@ -10,6 +10,8 @@ import asyncssh import pandas as pd +INIT_PROCESS_ID = 1 + def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, int]]]) -> bool: """Return whether an ID is in a list of ID values @@ -84,7 +86,7 @@ async def terminate_errant_processes( process_df = pd.read_fwf(StringIO(ps_return.stdout), widths=[11, 11, 11, 11, 500]) # Identify orphaned processes and filter them by the UID whitelist - orphaned = process_df[process_df.PPID == 1] + orphaned = process_df[process_df.PPID == INIT_PROCESS_ID] terminate = orphaned[orphaned['UID'].apply(id_in_whitelist, whitelist=uid_whitelist)] for _, row in terminate.iterrows(): logging.debug(f'[{node}] Marking for termination {dict(row)}') @@ -93,6 +95,6 @@ async def terminate_errant_processes( logging.info(f'[{node}] No orphans found') elif not debug: - proc_id_str = ','.join(terminate.PGID.astype(str)) + proc_id_str = ','.join(terminate.PGID.unique().astype(str)) logging.info(f"[{node}] Sending termination signal for process groups {proc_id_str}") - await conn.run(f"pkill --signal -9 --pgroup {proc_id_str}", check=True) + await conn.run(f"pkill --signal 9 --pgroup {proc_id_str}", check=True) From 7b5291d6d6819769a6352e551325145aed5d0b23 Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Tue, 17 Oct 2023 10:54:13 -0400 Subject: [PATCH 2/5] Replace settings file with commandline arguments (#87) --- shinigami/cli.py | 206 +++++++++++++++++++++----------- shinigami/settings.py | 79 ------------ shinigami/utils.py | 20 ++-- tests/cli/test_application.py | 81 ------------- tests/cli/test_parser.py | 152 
++++++++++++++++++++--- tests/settings/__init__.py | 0 tests/settings/test_settings.py | 46 ------- 7 files changed, 280 insertions(+), 304 deletions(-) delete mode 100644 shinigami/settings.py delete mode 100644 tests/cli/test_application.py delete mode 100644 tests/settings/__init__.py delete mode 100644 tests/settings/test_settings.py diff --git a/shinigami/cli.py b/shinigami/cli.py index ef598c8..bc3c9db 100644 --- a/shinigami/cli.py +++ b/shinigami/cli.py @@ -1,58 +1,93 @@ -"""The application command-line interface.""" +"""The executable application and its command-line interface.""" import asyncio +import inspect import logging import logging.config -import logging.handlers -from argparse import RawTextHelpFormatter, ArgumentParser -from typing import List +import sys +from argparse import ArgumentParser +from json import loads +from typing import List, Collection, Union + +from asyncssh import SSHClientConnectionOptions from . import __version__, utils -from .settings import Settings, SETTINGS_PATH -class Parser(ArgumentParser): +class BaseParser(ArgumentParser): + """Custom argument parser that prints help text on error""" + + def error(self, message: str) -> None: + """Prints a usage message and exit + + Args: + message: The usage message + """ + + if len(sys.argv) == 1: + self.print_help() + super().exit(1) + + else: + super().error(message) + + +class Parser(BaseParser): """Defines the command-line interface and parses command-line arguments""" def __init__(self) -> None: """Define the command-line interface""" - super().__init__( - prog='shinigami', - formatter_class=RawTextHelpFormatter, # Allow newlines in description text - description=( - 'Scan Slurm compute nodes and terminate errant processes.\n\n' - f'See {SETTINGS_PATH} for the current application settings.' 
- )) - + # Configure the top level parser + super().__init__(prog='shinigami', description='Scan Slurm compute nodes and terminate orphan processes.') + subparsers = self.add_subparsers(required=True, parser_class=BaseParser) self.add_argument('--version', action='version', version=__version__) - self.add_argument('--debug', action='store_true', help='force the application to run in debug mode') - self.add_argument('-v', action='count', dest='verbosity', default=0, - help='set output verbosity to warning (-v), info (-vv), or debug (-vvv)') + # This parser defines reusable arguments and is not exposed to the user + common = ArgumentParser(add_help=False) -class Application: - """Entry point for instantiating and executing the application""" + ssh_group = common.add_argument_group('ssh options') + ssh_group.add_argument('-m', '--max-concurrent', type=int, default=1, help='maximum concurrent SSH connections') + ssh_group.add_argument('-t', '--ssh-timeout', type=int, default=120, help='SSH connection timeout in seconds') - def __init__(self, settings: Settings) -> None: - """Instantiate a new instance of the application + debug_group = common.add_argument_group('debugging options') + debug_group.add_argument('--debug', action='store_true', help='run the application in debug mode') + debug_group.add_argument('-v', action='count', dest='verbosity', default=0, + help='set verbosity to warning (-v), info (-vv), or debug (-vvv)') + + # Subparser for the `Application.scan` method + scan = subparsers.add_parser('scan', parents=[common], help='terminate processes on one or more clusters') + scan.set_defaults(callable=Application.scan) + scan.add_argument('-c', '--clusters', nargs='+', required=True, help='cluster names to scan') + scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore given nodes') + scan.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan') + + # Subparser for the `Application.terminate` 
method + terminate = subparsers.add_parser('terminate', parents=[common], help='terminate processes on a single node') + terminate.set_defaults(callable=Application.terminate) + terminate.add_argument('-n', '--nodes', nargs='+', required=True, help='the DNS name of the node to terminate') + terminate.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan') - Args: - settings: Settings to use when configuring and executing the application - """ - self._settings = settings - self._configure_logging() +class Application: + """Entry point for instantiating and executing the application""" - def _configure_logging(self) -> None: + @staticmethod + def _configure_logging(verbosity: int) -> None: """Configure Python logging - Configured loggers: - console_logger: For logging to the console only - file_logger: For logging to the log file only - root: For logging to the console and log file + Configured loggers include the following: + - console_logger: For logging to the console only + - file_logger: For logging to the log file only + - root: For logging to the console and log file + + Args: + verbosity: The console verbosity defined as the count passed to the commandline """ + verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG} + console_log_level = verbosity_to_log_level.get(verbosity, logging.DEBUG) + logging.config.dictConfig({ 'version': 1, 'disable_existing_loggers': True, @@ -61,7 +96,7 @@ def _configure_logging(self) -> None: 'format': '%(levelname)8s: %(message)s' }, 'log_file_formatter': { - 'format': '%(levelname)8s | %(asctime)s | %(message)s' + 'format': '%(asctime)s | %(levelname)8s | %(message)s' }, }, 'handlers': { @@ -69,13 +104,12 @@ def _configure_logging(self) -> None: 'class': 'logging.StreamHandler', 'stream': 'ext://sys.stdout', 'formatter': 'console_formatter', - 'level': self._settings.verbosity + 'level': console_log_level }, 'log_file_handler': { - 'class': 
'logging.FileHandler', + 'class': 'logging.handlers.SysLogHandler', 'formatter': 'log_file_formatter', - 'level': self._settings.log_level, - 'filename': self._settings.log_path + 'level': 'DEBUG', }, }, 'loggers': { @@ -85,49 +119,85 @@ def _configure_logging(self) -> None: } }) - async def run(self) -> None: - """Terminate errant processes on all clusters/nodes configured in application settings.""" + @staticmethod + async def scan( + clusters: Collection[str], + ignore_nodes: Collection[str], + uid_whitelist: Collection[Union[int, List[int]]], + max_concurrent: asyncio.Semaphore, + ssh_timeout: int, + debug: bool + ) -> None: + """Terminate orphaned processes on all clusters/nodes configured in application settings. - if not self._settings.clusters: - logging.warning('No cluster names configured in application settings.') + Args: + clusters: Slurm cluster names + ignore_nodes: List of nodes to ignore + uid_whitelist: UID values to terminate orphaned processes for + max_concurrent: Maximum number of concurrent ssh connections + ssh_timeout: Timeout for SSH connections + debug: Optionally log but do not terminate processes + """ - ssh_limit = asyncio.Semaphore(self._settings.max_concurrent) - for cluster in self._settings.clusters: + # Clusters are handled synchronously, nodes are handled asynchronously + for cluster in clusters: logging.info(f'Starting scan for nodes in cluster {cluster}') + nodes = utils.get_nodes(cluster, ignore_nodes) + await Application.terminate(nodes, uid_whitelist, max_concurrent, ssh_timeout, debug) + + @staticmethod + async def terminate( + nodes: Collection[str], + uid_whitelist: Collection[Union[int, List[int]]], + max_concurrent: asyncio.Semaphore, + ssh_timeout: int, + debug: bool + ) -> None: + """Terminate processes on a given node - # Launch a concurrent job for each node in the cluster - nodes = utils.get_nodes(cluster, self._settings.ignore_nodes) - coroutines = [ - utils.terminate_errant_processes( - node=node, - 
ssh_limit=ssh_limit, - uid_whitelist=self._settings.uid_whitelist, - timeout=self._settings.ssh_timeout, - debug=self._settings.debug) - for node in nodes - ] - - # Gather results from each concurrent run and check for errors - results = await asyncio.gather(*coroutines, return_exceptions=True) - for node, result in zip(nodes, results): - if isinstance(result, Exception): - logging.error(f'Error with node {node}: {result}') + Args: + nodes: + uid_whitelist: UID values to terminate orphaned processes for + max_concurrent: Maximum number of concurrent ssh connections + ssh_timeout: Timeout for SSH connections + debug: Optionally log but do not terminate processes + """ + + ssh_options = SSHClientConnectionOptions(connect_timeout=ssh_timeout) + + # Launch a concurrent job for each node in the cluster + coroutines = [ + utils.terminate_errant_processes( + node=node, + uid_whitelist=uid_whitelist, + ssh_limit=asyncio.Semaphore(max_concurrent), + ssh_options=ssh_options, + debug=debug) + for node in nodes + ] + + # Gather results from each concurrent run and check for errors + results = await asyncio.gather(*coroutines, return_exceptions=True) + for node, result in zip(nodes, results): + if isinstance(result, Exception): + logging.error(f'Error with node {node}: {result}') @classmethod def execute(cls, arg_list: List[str] = None) -> None: - """Parse command-line arguments and execute the application""" + """Parse command-line arguments and execute the application - args = Parser().parse_args(arg_list) - verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG} + Args: + arg_list: Optionally parse the given arguments instead of the command line + """ - # Load application settings - override defaults using parsed arguments - settings = Settings.load() - settings.verbosity = verbosity_to_log_level.get(args.verbosity, logging.DEBUG) - settings.debug = settings.debug or args.debug + args = Parser().parse_args(arg_list) + 
cls._configure_logging(args.verbosity) try: - application = cls(settings) - asyncio.run(application.run()) + # Extract the subset of arguments that are valid for the function ``args.callable`` + valid_params = inspect.signature(args.callable).parameters + valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params} + asyncio.run(args.callable(**valid_arguments)) except KeyboardInterrupt: pass diff --git a/shinigami/settings.py b/shinigami/settings.py deleted file mode 100644 index 7e9eb4a..0000000 --- a/shinigami/settings.py +++ /dev/null @@ -1,79 +0,0 @@ -"""The application settings schema.""" - -from __future__ import annotations - -import json -from pathlib import Path -from typing import Tuple, Union, Literal, Optional - -from pydantic import Field -from pydantic_settings import BaseSettings - -SETTINGS_PATH = Path('/etc/shinigami/settings.json') - - -class Settings(BaseSettings): - """Defines the settings schema and default settings values""" - - debug: bool = Field( - title='Debug Mode', - default=False, - description='When enabled, processes are scanned and logged but not terminated.') - - uid_whitelist: Tuple[Union[int, Tuple[int, int]], ...] = Field( - title='Whitelisted User IDs', - default=(0,), - description='Only terminate processes launched by users with the given UID values.') - - clusters: Tuple[str, ...] = Field( - title='Clusters to Scan', - default=tuple(), - description='Scan and terminate processes on the given Slurm clusters.') - - ignore_nodes: Tuple[str, ...] 
= Field( - title='Ignore Nodes', - default=tuple(), - description='Ignore nodes with Slurm names containing any of the provided substrings.') - - max_concurrent: int = Field( - title='Maximum SSH Connections', - default=10, - description='The maximum number of simultaneous SSH connections to open.') - - log_level: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR'] = Field( - title='Logging Level', - default='INFO', - description='Application logging level.') - - log_path: Optional[Path] = Field( - title='Log Path', - default_factory=lambda: Path('/tmp/shinigami.log'), - description='Optionally log application events to a file.') - - verbosity: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR'] = Field( - title='Default console verbosity', - default='ERROR', - description='Default verbosity level for console output.') - - ssh_timeout: int = Field( - title='SSH Timeout', - default=120, - description='Maximum time in seconds to complete an outbound SSH connection.') - - @classmethod - def load(cls, path: Path = SETTINGS_PATH) -> Settings: - """Factory method for loading application settings from disk - - If a settings file does not exist, return default settings values. 
- - Args: - path: The settings file to read - - Returns: - An instance of the parent class - """ - - if path.exists(): - return cls.model_validate(json.loads(path.read_text())) - - return cls() # Returns default settings diff --git a/shinigami/utils.py b/shinigami/utils.py index d7a84b6..dd15b29 100755 --- a/shinigami/utils.py +++ b/shinigami/utils.py @@ -5,7 +5,7 @@ from io import StringIO from shlex import split from subprocess import Popen, PIPE -from typing import Union, Tuple, Collection +from typing import Union, Tuple, Collection, List import asyncssh import pandas as pd @@ -58,24 +58,21 @@ def get_nodes(cluster: str, ignore_substring: Collection[str]) -> set: async def terminate_errant_processes( node: str, - ssh_limit: asyncio.Semaphore, - uid_whitelist, - timeout: int = 120, + uid_whitelist: Collection[Union[int, List[int]]], + ssh_limit: asyncio.Semaphore = asyncio.Semaphore(1), + ssh_options: asyncssh.SSHClientConnectionOptions = None, debug: bool = False ) -> None: """Terminate non-Slurm processes on a given node Args: node: The DNS resolvable name of the node to terminate processes on - ssh_limit: Semaphore object used to limit concurrent SSH connections uid_whitelist: Do not terminate processes owned by the given UID - timeout: Maximum time in seconds to complete an outbound SSH connection + ssh_limit: Semaphore object used to limit concurrent SSH connections + ssh_options: Options for configuring the outbound SSH connection debug: Log which process to terminate but do not terminate them """ - # Define SSH connection settings - ssh_options = asyncssh.SSHClientConnectionOptions(connect_timeout=timeout) - logging.debug(f'[{node}] Waiting for SSH pool') async with ssh_limit, asyncssh.connect(node, options=ssh_options) as conn: logging.info(f'[{node}] Scanning for processes') @@ -88,11 +85,12 @@ async def terminate_errant_processes( # Identify orphaned processes and filter them by the UID whitelist orphaned = process_df[process_df.PPID == 
INIT_PROCESS_ID] terminate = orphaned[orphaned['UID'].apply(id_in_whitelist, whitelist=uid_whitelist)] + for _, row in terminate.iterrows(): - logging.debug(f'[{node}] Marking for termination {dict(row)}') + logging.info(f'[{node}] Marking for termination {dict(row)}') if terminate.empty: - logging.info(f'[{node}] No orphans found') + logging.info(f'[{node}] no processes found') elif not debug: proc_id_str = ','.join(terminate.PGID.unique().astype(str)) diff --git a/tests/cli/test_application.py b/tests/cli/test_application.py deleted file mode 100644 index 441ecfe..0000000 --- a/tests/cli/test_application.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Tests for the ``cli.Application`` class.""" - -import logging -from unittest import TestCase - -from shinigami.cli import Application - - -class ConsoleLoggingConfiguration(TestCase): - """Test the application verbosity is set to match command-line arguments""" - - def test_root_logs_to_console(self) -> None: - """Test the root logger logs to the console""" - - Application.execute(['--debug']) - handler_names = [handler.name for handler in logging.getLogger().handlers] - self.assertIn('console_handler', handler_names) - - def test_console_logger_has_stream_handler(self) -> None: - """Test the ``console`` logger has a single ``StreamHandler``""" - - Application.execute(['--debug']) - handlers = logging.getLogger('console_logger').handlers - - self.assertEqual(1, len(handlers)) - self.assertIsInstance(handlers[0], logging.StreamHandler) - - def test_verbose_level_zero(self): - """Test the application defaults to logging errors and above in the console""" - - Application.execute(['--debug']) - for handler in logging.getLogger('console_logger').handlers: - self.assertEqual(logging.ERROR, handler.level) - - def test_verbose_level_one(self): - """Test a single verbose flag sets the logging level to ``WARNING``""" - - Application.execute(['-v', '--debug']) - for handler in logging.getLogger('console_logger').handlers: - 
self.assertEqual(logging.WARNING, handler.level) - - def test_verbose_level_two(self): - """Test two verbose flags sets the logging level to ``INFO``""" - - Application.execute(['-vv', '--debug']) - for handler in logging.getLogger('console_logger').handlers: - self.assertEqual(logging.INFO, handler.level) - - def test_verbose_level_three(self): - """Test three verbose flags sets the logging level to ``DEBUG``""" - - Application.execute(['-vvv', '--debug']) - for handler in logging.getLogger('console_logger').handlers: - self.assertEqual(logging.DEBUG, handler.level) - - def test_verbose_level_many(self): - """Test several verbose flags sets the logging level to ``DEBUG``""" - - Application.execute(['-vvvvvvvvvv', '--debug']) - for handler in logging.getLogger('console_logger').handlers: - self.assertEqual(logging.DEBUG, handler.level) - - -class FileLoggingConfiguration(TestCase): - """Test the configuration for logging to a file""" - - def test_root_logs_to_file(self) -> None: - """Test the root logger logs to the log file""" - - Application.execute(['--debug']) - handler_names = [handler.name for handler in logging.getLogger().handlers] - self.assertIn('log_file_handler', handler_names) - - def test_file_logger_has_file_handler(self) -> None: - """Test the ``file_logger`` logger has a single ``FileHandler``""" - - Application.execute(['--debug']) - handlers = logging.getLogger('file_logger').handlers - - self.assertEqual(1, len(handlers)) - self.assertIsInstance(handlers[0], logging.FileHandler) diff --git a/tests/cli/test_parser.py b/tests/cli/test_parser.py index c2499f1..c3e4243 100644 --- a/tests/cli/test_parser.py +++ b/tests/cli/test_parser.py @@ -5,31 +5,145 @@ from shinigami.cli import Parser -class DebugOption(TestCase): - """Test the behavior of the ``debug`` option""" +class ScanParser(TestCase): + """Test the behavior of the ``scan`` subparser""" - def test_default_is_false(self) -> None: - """Test the ``debug`` argument defaults to ``False``""" + 
def test_debug_option(self) -> None: + """Test the ``debug`` argument""" - args = Parser().parse_args([]) - self.assertFalse(args.debug) + parser = Parser() + + scan_command = ['scan', '-c', 'development', '-u' '100'] + self.assertFalse(parser.parse_args(scan_command).debug) + + scan_command_debug = ['scan', '-c', 'development', '-u' '100', '--debug'] + self.assertTrue(parser.parse_args(scan_command_debug).debug) + + def test_verbose_arg(self) -> None: + """Test the verbosity argument counts the number of provided flags""" + + parser = Parser() + base_command = ['scan', '-c', 'development', '-u' '100'] + + self.assertEqual(0, parser.parse_args(base_command).verbosity) + self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity) + self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity) + self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity) + self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity) + + def test_clusters_arg(self) -> None: + """Test parsing of the ``clusters`` argument""" + + parser = Parser() + + single_cluster_out = ['development'] + single_cluster_cmd = ['scan', '-c', *single_cluster_out, '-u', '100'] + self.assertSequenceEqual(single_cluster_out, parser.parse_args(single_cluster_cmd).clusters) + + multi_cluster_out = ['dev1', 'dev2', 'dev3'] + multi_cluster_cmd = ['scan', '-c', *multi_cluster_out, '-u', '100'] + self.assertSequenceEqual(multi_cluster_out, parser.parse_args(multi_cluster_cmd).clusters) + + def test_ignore_nodes(self) -> None: + """Test parsing of the ``ignore-nodes`` argument""" + + parser = Parser() + base_command = ['scan', '-c', 'development', '-u' '100'] + + single_node_out = ['node1'] + single_node_cmd = base_command + ['-i', 'node1'] + self.assertSequenceEqual(single_node_out, parser.parse_args(single_node_cmd).ignore_nodes) + + multi_node_out = ['node1', 'node2'] + multi_node_cmd = base_command + ['-i', 'node1', 'node2'] + 
self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).ignore_nodes) + + def test_uid_whitelist_arg(self) -> None: + """Test parsing of the ``uid-whitelist`` argument""" + + parser = Parser() - def test_enabled_is_true(self) -> None: - """Test the ``debug`` flag stores a ``True`` value""" + # Test for a single integer + single_int_command = 'scan -c development -u 100'.split() + single_int_out = [100] + self.assertSequenceEqual(single_int_out, parser.parse_args(single_int_command).uid_whitelist) - args = Parser().parse_args(['--debug']) - self.assertTrue(args.debug) + # Test for a multiple integers + multi_int_command = 'scan -c development -u 100 200'.split() + multi_int_out = [100, 200] + self.assertSequenceEqual(multi_int_out, parser.parse_args(multi_int_command).uid_whitelist) + # Test for a list type + single_list_command = 'scan -c development -u [100,200]'.split() + single_list_out = [[100, 200]] + self.assertSequenceEqual(single_list_out, parser.parse_args(single_list_command).uid_whitelist) -class VerboseOption(TestCase): - """Test the verbosity flag""" + # Test for a mix of types + mixed_command = 'scan -c development -u 100 [200,300] 400 [500,600]'.split() + mixed_out = [100, [200, 300], 400, [500, 600]] + self.assertSequenceEqual(mixed_out, parser.parse_args(mixed_command).uid_whitelist) - def test_counts_instances(self) -> None: - """Test the parser counts the number of provided flags""" + +class TerminateParser(TestCase): + """Test the behavior of the ``terminate`` subparser""" + + def test_debug_option(self) -> None: + """Test the ``debug`` argument""" + + parser = Parser() + + terminate_command = ['terminate', '-n', 'node1', '-u', '100'] + self.assertFalse(parser.parse_args(terminate_command).debug) + + terminate_command_debug = ['terminate', '-n', 'node1', '-u', '100', '--debug'] + self.assertTrue(parser.parse_args(terminate_command_debug).debug) + + def test_verbose_arg(self) -> None: + """Test the verbosity argument counts the 
number of provided flags""" + + parser = Parser() + base_command = ['terminate', '-n', 'node', '-u' '100'] + + self.assertEqual(0, parser.parse_args(base_command).verbosity) + self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity) + self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity) + self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity) + self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity) + + def test_nodes_arg(self) -> None: + """Test parsing of the ``nodes`` argument""" + + parser = Parser() + + single_node_out = ['development'] + single_node_cmd = ['terminate', '-n', *single_node_out, '-u', '100'] + self.assertSequenceEqual(single_node_out, parser.parse_args(single_node_cmd).nodes) + + multi_node_out = ['dev1', 'dev2', 'dev3'] + multi_node_cmd = ['terminate', '-n', *multi_node_out, '-u', '100'] + self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).nodes) + + def test_uid_whitelist_arg(self) -> None: + """Test parsing of the ``uid-whitelist`` argument""" parser = Parser() - self.assertEqual(0, parser.parse_args([]).verbosity) - self.assertEqual(1, parser.parse_args(['-v']).verbosity) - self.assertEqual(2, parser.parse_args(['-vv']).verbosity) - self.assertEqual(3, parser.parse_args(['-vvv']).verbosity) - self.assertEqual(5, parser.parse_args(['-vvvvv']).verbosity) + + # Test for a single integer + single_int_command = 'terminate -n node -u 100'.split() + single_int_out = [100] + self.assertSequenceEqual(single_int_out, parser.parse_args(single_int_command).uid_whitelist) + + # Test for a multiple integers + multi_int_command = 'terminate -n node -u 100 200'.split() + multi_int_out = [100, 200] + self.assertSequenceEqual(multi_int_out, parser.parse_args(multi_int_command).uid_whitelist) + + # Test for a list type + single_list_command = 'terminate -n node -u [100,200]'.split() + single_list_out = [[100, 200]] + self.assertSequenceEqual(single_list_out, 
parser.parse_args(single_list_command).uid_whitelist) + + # Test for a mix of types + mixed_command = 'terminate -n node -u 100 [200,300] 400 [500,600]'.split() + mixed_out = [100, [200, 300], 400, [500, 600]] + self.assertSequenceEqual(mixed_out, parser.parse_args(mixed_command).uid_whitelist) diff --git a/tests/settings/__init__.py b/tests/settings/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/settings/test_settings.py b/tests/settings/test_settings.py deleted file mode 100644 index bb99d1e..0000000 --- a/tests/settings/test_settings.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Tests for the ``settings.Settings`` class""" - -from pathlib import Path -from tempfile import NamedTemporaryFile -from unittest import TestCase - -from shinigami.settings import Settings - - -class Defaults(TestCase): - """Test default values""" - - def test_defaults_allow_init(self) -> None: - """Test default values are sufficient for instantiating a new Settings object""" - - Settings() - - def test_debug_is_false(self) -> None: - """Test the ``debug`` setting defaults to ``False``""" - - self.assertFalse(Settings().debug) - - def test_ignore_nodes_empty(self) -> None: - """Test the ``ignore_nodes`` setting is empty""" - - self.assertEqual(tuple(), Settings().ignore_nodes) - - -class Load(TestCase): - """Test loading settings from disk via the ``load`` function""" - - def test_file_does_not_exist(self) -> None: - """Test default settings are used when the file does not exist""" - - self.assertEqual(Settings(), Settings.load()) - - def test_file_exists(self) -> None: - """Test settings are successfully parsed from disk""" - - with NamedTemporaryFile() as temp_file: - file_path = Path(temp_file.name) - file_path.write_text('{ "debug": true, "clusters": ["cluster1", "cluster2"] }') - settings = Settings.load(file_path) - - self.assertTrue(settings.debug) - self.assertSequenceEqual(["cluster1", "cluster2"], settings.clusters) From 
f8a6969936cdc26637ca71aa4b65645bd8e69ff9 Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Tue, 17 Oct 2023 12:08:06 -0400 Subject: [PATCH 3/5] Run tests in Slurm test env (#88) * Adds CI for running tests (#15) * Adds CI for running tests * Fix CI typo * Adds setup step for python * Adds minimal pyproject.toml file * Update pyproject.toml * Adds missing --with directive * Adds missing description field * Adds missing docstring [skip ci] * Updates the way coverage is run * Expand .gitignore * Exceptions top (#16) * Moves exception pattern checking into dedicated method * Update top of file comments for clarity * Deletes test placeholder * Fix boolean typo * Add codacy badges to readme (#17) * Adds tests for piping shell commands (#20) * Drop accumulator pattern for node names (#21) * Use builtin support for logging (#22) * Switch to builtin logging * Explicitly set logging level * Moves logging to /var * Log directly to syslog * Add missing import * Replace accumulator (#23) * Replace accumulator patterns with single iteration * Fix typo in list unpacking * Eod cleanup (#24) * Drop support for regex ignore patterns * Increase number/detail of log messages * Drop Python2 support and impliment Python3 features * Delete test_check_ignore_node.py * Simplify set logic * Adds ability to whitelist UID and UID ranges (#25) * Adds basic whitelist * Replaces usernames with uid * Updates docs and comments * Ignore uid from 0 to 15000 * Formally package source code for pip installability (#28) * Move source into a package * Add GPL3 license * Update pyproject.toml * Removes 'if __name__' conditional * Drop /home/djperrefort/GitHub/pitt-crc/shinigami reference in testing CI * Update test suite structure to match package structure * Add CI workflow for publishing to PyPI (#30) * Updates Readme and Package Docs (#31) * Updates README * Updates package docstring * Adds ability to load settings from settings file (#32) * Introduces settings module * Configure settings to load 
from disk * Add dedicated method for loading settings from disk * Adds ability to skip FileNotFoundError error * Adds basic CLI parsing (#34) * Initial commit of cli module * Fixes outdated import signature * Moves main function into Application class * Mid-flight cleanup pass * Updates import signatures in test suite * Adds support for async execution across nodes in the same cluster (#35) * Execute SSH commands asynchronously * Run async functions using asyncio.gather * Fix typo in SSH connection * Run all remote commands through SSH object * Imposes limit on max SSH connections * PEP8 * Drops old tests * Adds dummy test coverage for ci * Drops Settings.load_from_disk method (#36) * Adds debug option to CLI (#37) * Build out test coverage (#38) * Adds tests for parser debug option * Adds tests for Settings class * Adds tests for id_in_whitelist function * Adds verbosity argument to CLI (#39) * Adds configurable logging * Adds verbosity argument * Adds tests for logging configuration * Adds tests for CLI --debug option (#40) * Delete lock file * Fix syntax typo in async generator * Eliminate global semaphore and global settings to address concurrency locks * Logs errors collected by asyncio.gather * Move semaphore inside event loop * Adds node name to SSH related logs * Pre-review cleanup * Minor capitalization typos * Drop yaml for json * Update setup instructions README.md There was a typo in the cron job example. 
* Run AI linter (#50) * Makes Application.settings private * Runs AI linter * Adds ssh timeout option (#49) * Drop whitelisting functionality for GIDs (#51) * Drops GID whitelist * Drops GID related tests * Move logic for loading settings into Settings class (#53) * Move logic for loading settings into Settings class * Revert changes to parser class * Raise error on file not found * Revert last commit * Terminate orphaned processes matching blacklisted user IDs (#55) * Terminate orphaned processes in userlist * Rename whitelist to blacklist * Kill processes using group ID * Fix bug in fetching of remote process data * updates tests * Lower pandas requirement * Adds test coverage for settings file parsing (#56) * Adds test coverage for settings file parsing * PEP8 * Update testing CI to use slurm * Adds rudimentary test to try out slurm support * Replace settings file with CLI arguments (#80) * Outline CLI options to replace settings module * Print help text on error * Parse UID list as json string * Dynamically determine valid arguments from func signature * Add log message when there are no processes to terminate * Add nargs='+' to uid-whitelist arg * Updates docstring * Adds missing type hints * Introduce dedicated ssh argument group * Abstract away SSH options in function signatures * Adds tests for argument parsing * Drops application tests * Makes logging setup private * Shorten line lengths * Merge V0.3.3 updates into v0.4.x (#86) * Bump actions/checkout from 3 to 4 (#81) Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Updates CI (#82) * Fix termination signal (#84) --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Fix conflicts in workflow files * Add docker test env back into testing matrix * Drops application logging tests * Updates test node names * Adds newer python versions to CI * Updates expected node names * Run codacy report from bash shell --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/PackageTest.yml | 25 +++++++++++++++++++------ pyproject.toml | 2 +- tests/utils/test_get_nodes.py | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) create mode 100644 tests/utils/test_get_nodes.py diff --git a/.github/workflows/PackageTest.yml b/.github/workflows/PackageTest.yml index 973a58b..b241b8f 100644 --- a/.github/workflows/PackageTest.yml +++ b/.github/workflows/PackageTest.yml @@ -14,17 +14,29 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.8", "3.9", "3.10", "3.11" ] + slurm_version: + - 20.11.9.1 + - 22.05.2.1 + - 23.02.5.1 + python_version: + - 3.8 + - 3.9 + - 3.10 + - 3.11 + + container: + image: ghcr.io/pitt-crc/test-env:${{ matrix.slurm_version }} + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} steps: + - name: Setup environment + run: /usr/local/bin/entrypoint.sh + - name: Checkout repository uses: actions/checkout@v4 - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install Poetry uses: snok/install-poetry@v1 with: @@ -43,6 +55,7 @@ jobs: - name: Report partial coverage results if: github.event_name != 'release' run: bash <(curl -Ls https://coverage.codacy.com/get.sh) report --partial -l Python -r coverage.xml + shell: bash env: 
CODACY_PROJECT_TOKEN: ${{ secrets.CODACY_PROJECT_TOKEN }} diff --git a/pyproject.toml b/pyproject.toml index f9db1e2..41460fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ shinigami = "shinigami.cli:Application.execute" python = ">=3.8" pydantic = "^2.0.3" pydantic-settings = "^2.0.2" -asyncssh = {extras = ["bcrypt", "fido2"], version = "^2.13.2"} +asyncssh = { extras = ["bcrypt", "fido2"], version = "^2.13.2" } pandas = "1.5.3" # Newer versions are incompatible with Python 3.8 [tool.poetry.group.tests] diff --git a/tests/utils/test_get_nodes.py b/tests/utils/test_get_nodes.py new file mode 100644 index 0000000..74c31fb --- /dev/null +++ b/tests/utils/test_get_nodes.py @@ -0,0 +1,19 @@ +"""Tests for the ``utils.get_nodes`` function""" + +from unittest import TestCase + +from shinigami import utils + +# For information on resources defined in the testing environment, +# see https://github.com/pitt-crc/Slurm-Test-Environment/ +TEST_CLUSTER = 'development' +TEST_NODES = set(f'c{i}' for i in range(1, 11)) + + +class NodesMatchTestEnvironment(TestCase): + """Test the returned node list matches values defined in the testing environment""" + + def test_returned_nodes(self) -> None: + """Test returned nodes match hose defined in the slurm test env""" + + self.assertSequenceEqual(TEST_NODES, utils.get_nodes(TEST_CLUSTER, [])) From f05ba7b0881dc8b76ff670318307d613d89ea4d1 Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Wed, 18 Oct 2023 13:54:01 -0400 Subject: [PATCH 4/5] Update permissions in PackagePublish.yml (#91) --- .github/workflows/PackagePublish.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/PackagePublish.yml b/.github/workflows/PackagePublish.yml index 7bb6460..25a88af 100644 --- a/.github/workflows/PackagePublish.yml +++ b/.github/workflows/PackagePublish.yml @@ -12,6 +12,9 @@ permissions: jobs: run-tests: name: Tests + permissions: + packages: read + contents: read uses: ./.github/workflows/PackageTest.yml 
secrets: inherit From dffe834632ae298101d857ff9f43f1f30b88322e Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Thu, 2 Nov 2023 11:15:13 -0400 Subject: [PATCH 5/5] Ignore entire node names instead of substrings (#93) * Make node whitelisting explicit instead of substring * Updates and consolidates tests for getting node names * Parser help text updates --- shinigami/cli.py | 2 +- shinigami/utils.py | 13 ++++-- tests/test_utils.py | 71 +++++++++++++++++++++++++++++ tests/utils/__init__.py | 0 tests/utils/test_get_nodes.py | 19 -------- tests/utils/test_id_in_whitelist.py | 29 ------------ 6 files changed, 80 insertions(+), 54 deletions(-) create mode 100644 tests/test_utils.py delete mode 100644 tests/utils/__init__.py delete mode 100644 tests/utils/test_get_nodes.py delete mode 100644 tests/utils/test_id_in_whitelist.py diff --git a/shinigami/cli.py b/shinigami/cli.py index bc3c9db..6295057 100644 --- a/shinigami/cli.py +++ b/shinigami/cli.py @@ -59,7 +59,7 @@ def __init__(self) -> None: scan = subparsers.add_parser('scan', parents=[common], help='terminate processes on one or more clusters') scan.set_defaults(callable=Application.scan) scan.add_argument('-c', '--clusters', nargs='+', required=True, help='cluster names to scan') - scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore given nodes') + scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore the given nodes') scan.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan') # Subparser for the `Application.terminate` method diff --git a/shinigami/utils.py b/shinigami/utils.py index dd15b29..afac877 100755 --- a/shinigami/utils.py +++ b/shinigami/utils.py @@ -14,7 +14,11 @@ def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, int]]]) -> bool: - """Return whether an ID is in a list of ID values + """Return whether an ID is in a list of ID value definitions + + The `whitelist` of ID values can 
contain a mix of integers and tuples + of integer ranges. For example, [0, 1, (2, 9), 10] includes all IDs from + zero through ten. Args: id_value: The ID value to check @@ -34,12 +38,12 @@ def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, i return False -def get_nodes(cluster: str, ignore_substring: Collection[str]) -> set: +def get_nodes(cluster: str, ignore_nodes: Collection[str] = tuple()) -> set: """Return a set of nodes included in a given Slurm cluster Args: cluster: Name of the cluster to fetch nodes for - ignore_substring: Do not return nodes containing any of the given substrings + ignore_nodes: Do not return nodes included in the provided list Returns: A set of cluster names @@ -52,8 +56,7 @@ def get_nodes(cluster: str, ignore_substring: Collection[str]) -> set: raise RuntimeError(stderr) all_nodes = stdout.decode().strip().split('\n') - is_valid = lambda node: not any(substring in node for substring in ignore_substring) - return set(filter(is_valid, all_nodes)) + return set(node for node in all_nodes if node not in ignore_nodes) async def terminate_errant_processes( diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d9866e1 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,71 @@ +"""Tests for the `utils` module.""" + +from unittest import TestCase + +from shinigami import utils +from shinigami.utils import id_in_whitelist + +# For information on resources defined in the testing environment, +# see https://github.com/pitt-crc/Slurm-Test-Environment/ +TEST_CLUSTER = 'development' +TEST_NODES = set(f'c{i}' for i in range(1, 11)) + + +class Whitelisting(TestCase): + """Tests for the ``id_in_whitelist`` function""" + + def test_empty_whitelist(self) -> None: + """Test the return value is ``False`` for all ID values when the whitelist is empty""" + + self.assertFalse(id_in_whitelist(0, [])) + self.assertFalse(id_in_whitelist(123, [])) + + def test_whitelisted_by_id(self) -> None: + """Test 
return values for a whitelist of explicit ID values""" + + whitelist = (123, 456, 789) + self.assertTrue(id_in_whitelist(456, whitelist)) + self.assertFalse(id_in_whitelist(0, whitelist)) + + def test_whitelisted_by_id_range(self) -> None: + """Test return values for a whitelist of ID ranges""" + + whitelist = (0, 1, 2, (100, 300)) + self.assertTrue(id_in_whitelist(123, whitelist)) + self.assertFalse(id_in_whitelist(301, whitelist)) + + +class GetNodes(TestCase): + """Tests for the ``get_nodes`` function""" + + def test_nodes_match_test_env(self) -> None: + """Test the returned node list matches values defined in the testing environment""" + + self.assertCountEqual(TEST_NODES, utils.get_nodes(TEST_CLUSTER)) + + def test_ignore_substring(self) -> None: + """Test nodes with the included substring are ignored""" + + exclude_node = 'c1' + + # Create a copy of the test nodes with on element missing + expected_nodes = TEST_NODES.copy() + expected_nodes.remove(exclude_node) + + returned_nodes = utils.get_nodes(TEST_CLUSTER, [exclude_node]) + self.assertCountEqual(expected_nodes, returned_nodes) + + def test_missing_cluster(self) -> None: + """Test an error is raised for a cluster name that does not exist""" + + with self.assertRaisesRegex(RuntimeError, 'No cluster \'fake_cluster\''): + utils.get_nodes('fake_cluster') + + def test_missing_node(self) -> None: + """Test no error is raised when an excuded node des not exist""" + + excluded_nodes = {'c1', 'fake_node'} + expected_nodes = TEST_NODES - excluded_nodes + + returned_nodes = utils.get_nodes(TEST_CLUSTER, excluded_nodes) + self.assertCountEqual(expected_nodes, returned_nodes) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/utils/test_get_nodes.py b/tests/utils/test_get_nodes.py deleted file mode 100644 index 74c31fb..0000000 --- a/tests/utils/test_get_nodes.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Tests for the ``utils.get_nodes`` 
function""" - -from unittest import TestCase - -from shinigami import utils - -# For information on resources defined in the testing environment, -# see https://github.com/pitt-crc/Slurm-Test-Environment/ -TEST_CLUSTER = 'development' -TEST_NODES = set(f'c{i}' for i in range(1, 11)) - - -class NodesMatchTestEnvironment(TestCase): - """Test the returned node list matches values defined in the testing environment""" - - def test_returned_nodes(self) -> None: - """Test returned nodes match hose defined in the slurm test env""" - - self.assertSequenceEqual(TEST_NODES, utils.get_nodes(TEST_CLUSTER, [])) diff --git a/tests/utils/test_id_in_whitelist.py b/tests/utils/test_id_in_whitelist.py deleted file mode 100644 index b762d6a..0000000 --- a/tests/utils/test_id_in_whitelist.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Tests for the ``utils.id_in_whitelist`` function""" - -from unittest import TestCase - -from shinigami.utils import id_in_whitelist - - -class Whitelisting(TestCase): - """Test ID values are correctly whitelisted""" - - def test_empty_whitelist(self) -> None: - """Test the return value is ``False`` for all ID values when the whitelist is empty""" - - self.assertFalse(id_in_whitelist(0, [])) - self.assertFalse(id_in_whitelist(123, [])) - - def test_whitelisted_by_id(self) -> None: - """Test return values for a whitelist of explicit ID values""" - - whitelist = (123, 456, 789) - self.assertTrue(id_in_whitelist(456, whitelist)) - self.assertFalse(id_in_whitelist(0, whitelist)) - - def test_whitelisted_by_id_range(self) -> None: - """Test return values for a whitelist of ID ranges""" - - whitelist = (0, 1, 2, (100, 300)) - self.assertTrue(id_in_whitelist(123, whitelist)) - self.assertFalse(id_in_whitelist(301, whitelist))