diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..c423fae
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,21 @@
+version: 2
+updates:
+  - package-ecosystem: "pip"
+    directory: "/"
+    schedule:
+      interval: "monthly"
+    open-pull-requests-limit: 100
+    groups:
+      python-dependencies:
+        patterns:
+          - "*"
+
+  - package-ecosystem: "github-actions"
+    directory: "/"
+    schedule:
+      interval: "monthly"
+    open-pull-requests-limit: 100
+    groups:
+      actions-dependencies:
+        patterns:
+          - "*"
diff --git a/.github/workflows/CodeQL.yml b/.github/workflows/CodeQL.yml
new file mode 100644
index 0000000..b46e0a7
--- /dev/null
+++ b/.github/workflows/CodeQL.yml
@@ -0,0 +1,48 @@
+name: CodeQL
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+  schedule:
+    - cron: '0 7 1 * *'
+
+jobs:
+  analyze:
+    name: Analyze
+    runs-on: ubuntu-latest
+    permissions:
+      actions: read
+      contents: read
+      security-events: write
+
+    strategy:
+      fail-fast: false
+      matrix:
+        language: [ python ]
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Initialize CodeQL
+        uses: github/codeql-action/init@v2
+        with:
+          languages: ${{ matrix.language }}
+
+      - name: Perform CodeQL Analysis
+        uses: github/codeql-action/analyze@v2
+        with:
+          category: /language:${{ matrix.language }}
+
+  # Use this job for branch protection rules
+  report-codeql-status:
+    name: Report CodeQL Status
+    if: always()
+    needs: analyze
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check build status
+        if: ${{ contains(needs.*.result, 'failure') }}
+        run: exit 1
diff --git a/.github/workflows/PackagePublish.yml b/.github/workflows/PackagePublish.yml
index 7af07f6..eb783e3 100644
--- a/.github/workflows/PackagePublish.yml
+++ b/.github/workflows/PackagePublish.yml
@@ -31,6 +31,9 @@ jobs:
     environment: publish-pypi

     steps:
+      - name: Checkout source
+        uses: actions/checkout@v4
+
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
@@ -41,6 +44,3 @@ jobs:
         with:
           virtualenvs-create: false

-      - name: Checkout source
-        uses: actions/checkout@v3
-
@@ -50,6 +50,9 @@ jobs:
-        run: |
-          release_tag=${{github.ref}}
-          poetry version "${release_tag:11}"
+      # Get the new package version from the release tag
+      # Git release tags are expected to start with "refs/tags/v"
+      - name: Set package version
+        run: |
+          release_tag=${{github.ref}}
+          poetry version "${release_tag#refs/tags/v}"

       - name: Build package
         run: poetry build -v
@@ -57,7 +60,7 @@ jobs:
       - name: Publish package
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
-          verbose: true
+          print-hash: true
           repository-url: ${{ matrix.host }}
           user: ${{ secrets.REPO_USER }}
           password: ${{ secrets.REPO_PASSWORD }}
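The `${release_tag#refs/tags/v}` expansion strips the literal `refs/tags/v` prefix from the ref, which is more self-documenting than the old `${release_tag:11}` offset (11 happens to be the length of that prefix, so both produce the same version string). A minimal sketch of the equivalent logic, using a hypothetical tag value in place of `${{ github.ref }}`:

```python
# Mirrors the workflow's `poetry version "${release_tag#refs/tags/v}"` step.
# 'refs/tags/v0.4.2' is a hypothetical example value of ${{ github.ref }}.
def version_from_ref(ref: str) -> str:
    """Strip the 'refs/tags/v' prefix from a Git ref, like ${ref#refs/tags/v}."""
    prefix = 'refs/tags/v'
    return ref[len(prefix):] if ref.startswith(prefix) else ref

assert version_from_ref('refs/tags/v0.4.2') == '0.4.2'
```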
diff --git a/.github/workflows/PackageTest.yml b/.github/workflows/PackageTest.yml
index c8298ca..11ecc03 100644
--- a/.github/workflows/PackageTest.yml
+++ b/.github/workflows/PackageTest.yml
@@ -4,6 +4,8 @@ on:
   workflow_dispatch:
   workflow_call:
   push:
+  schedule:
+    - cron: '0 7 1,15 * *'

 jobs:
   run-tests:
@@ -12,6 +14,3 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-      slurm_version:
-        - 20.02.5.1
-        - 20.11.9.1
@@ -37,6 +36,18 @@ jobs:
-      run: |
-        pip install poetry
-        poetry env use python${{ matrix.python_version }}
+        python-version: [ "3.8", "3.9", "3.10", "3.11" ]
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
+        with:
+          virtualenvs-create: false

       - name: Install dependencies
         run: poetry install --with tests
@@ -50,7 +61,6 @@ jobs:
       # Report partial coverage results to codacy for the current python version
       - name: Report partial coverage results
         if: github.event_name != 'release'
-        shell: bash
         run: bash <(curl -Ls https://coverage.codacy.com/get.sh) report --partial -l Python -r coverage.xml
         env:
           CODACY_PROJECT_TOKEN: ${{ secrets.CODACY_PROJECT_TOKEN }}
@@ -63,7 +73,6 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Finish reporting coverage
-        shell: bash
         run: bash <(curl -Ls https://coverage.codacy.com/get.sh) final
         env:
           CODACY_PROJECT_TOKEN: ${{ secrets.CODACY_PROJECT_TOKEN }}
diff --git a/pyproject.toml b/pyproject.toml
index 2b3df0e..41460fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,10 @@ requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"

 [tool.poetry]
-name = "shinigami"
+name = "crc-shinigami"
+packages = [
+    { include = "shinigami" },
+]
 version = "0.0.0" # Version is set dynamically by the CI tool on publication
 authors = ["Pitt Center for Research Computing"]
 readme = "README.md"
@@ -28,7 +31,7 @@ shinigami = "shinigami.cli:Application.execute"
 python = ">=3.8"
 pydantic = "^2.0.3"
 pydantic-settings = "^2.0.2"
-asyncssh = {extras = ["bcrypt", "fido2"], version = "^2.13.2"}
+asyncssh = { extras = ["bcrypt", "fido2"], version = "^2.13.2" }
 pandas = "1.5.3" # Newer versions are incompatible with Python 3.8

 [tool.poetry.group.tests]
diff --git a/shinigami/__init__.py b/shinigami/__init__.py
index d9a4d52..1c0f3a5 100644
--- a/shinigami/__init__.py
+++ b/shinigami/__init__.py
@@ -10,7 +10,7 @@
 import importlib.metadata

 try:
-    __version__ = importlib.metadata.version('shinigami')
+    __version__ = importlib.metadata.version('crc-shinigami')

 except importlib.metadata.PackageNotFoundError:  # pragma: no cover
     __version__ = '0.0.0'
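Renaming the distribution to `crc-shinigami` while keeping the `shinigami` import package means the two names now diverge: Poetry needs the explicit `packages` entry to locate the code, and metadata lookups must use the distribution name. A minimal sketch of the distinction, assuming the package is installed:

```python
# The import package keeps its original name...
import shinigami

# ...but installed metadata is registered under the new distribution name
import importlib.metadata

print(importlib.metadata.version('crc-shinigami'))  # e.g. '0.4.0'
print(shinigami.__version__)                        # resolved via the same lookup
```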
diff --git a/shinigami/cli.py b/shinigami/cli.py
index 75772eb..1dfb152 100644
--- a/shinigami/cli.py
+++ b/shinigami/cli.py
@@ -1,58 +1,93 @@
-"""The application command-line interface."""
+"""The executable application and its command-line interface."""

 import asyncio
+import inspect
 import logging
 import logging.config
-import logging.handlers
-from argparse import RawTextHelpFormatter, ArgumentParser
-from typing import List
+import sys
+from argparse import ArgumentParser
+from json import loads
+from typing import List, Collection, Union
+
+from asyncssh import SSHClientConnectionOptions

 from . import __version__, utils
-from .settings import Settings, SETTINGS_PATH


-class Parser(ArgumentParser):
+class BaseParser(ArgumentParser):
+    """Custom argument parser that prints help text on error"""
+
+    def error(self, message: str) -> None:
+        """Print a usage message and exit the application
+
+        Args:
+            message: The usage message
+        """
+
+        if len(sys.argv) == 1:
+            self.print_help()
+            super().exit(1)
+
+        else:
+            super().error(message)
+
+
+class Parser(BaseParser):
     """Defines the command-line interface and parses command-line arguments"""

     def __init__(self) -> None:
         """Define the command-line interface"""

-        super().__init__(
-            prog='shinigami',
-            formatter_class=RawTextHelpFormatter,  # Allow newlines in description text
-            description=(
-                'Scan Slurm compute nodes and terminate errant processes.\n\n'
-                f'See {SETTINGS_PATH} for the current application settings.'
-            ))
-
+        # Configure the top level parser
+        super().__init__(prog='shinigami', description='Scan Slurm compute nodes and terminate orphaned processes.')
+        subparsers = self.add_subparsers(required=True, parser_class=BaseParser)
         self.add_argument('--version', action='version', version=__version__)
-        self.add_argument('--debug', action='store_true', help='force the application to run in debug mode')
-        self.add_argument('-v', action='count', dest='verbosity', default=0,
-                          help='set output verbosity to warning (-v), info (-vv), or debug (-vvv)')

+        # This parser defines reusable arguments and is not exposed to the user
+        common = ArgumentParser(add_help=False)

-class Application:
-    """Entry point for instantiating and executing the application"""
+        ssh_group = common.add_argument_group('ssh options')
+        ssh_group.add_argument('-m', '--max-concurrent', type=int, default=1, help='maximum concurrent SSH connections')
+        ssh_group.add_argument('-t', '--ssh-timeout', type=int, default=120, help='SSH connection timeout in seconds')

-    def __init__(self, settings: Settings) -> None:
-        """Instantiate a new instance of the application
+        debug_group = common.add_argument_group('debugging options')
+        debug_group.add_argument('--debug', action='store_true', help='run the application in debug mode')
+        debug_group.add_argument('-v', action='count', dest='verbosity', default=0,
+                                 help='set verbosity to warning (-v), info (-vv), or debug (-vvv)')

-        Args:
-            settings: Settings to use when configuring and executing the application
-        """
+        # Subparser for the `Application.scan` method
+        scan = subparsers.add_parser('scan', parents=[common], help='terminate processes on one or more clusters')
+        scan.set_defaults(callable=Application.scan)
+        scan.add_argument('-c', '--clusters', nargs='+', required=True, help='cluster names to scan')
+        scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore the given nodes')
+        scan.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, required=True, help='user IDs to terminate processes for')
+
+        # Subparser for the `Application.terminate` method
+        terminate = subparsers.add_parser('terminate', parents=[common], help='terminate processes on one or more nodes')
+        terminate.set_defaults(callable=Application.terminate)
+        terminate.add_argument('-n', '--nodes', nargs='+', required=True, help='DNS names of the nodes to terminate processes on')
+        terminate.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, required=True, help='user IDs to terminate processes for')

-        self._settings = settings
-        self._configure_logging()

-    def _configure_logging(self) -> None:
+
+class Application:
+    """Entry point for instantiating and executing the application"""
+
+    @staticmethod
+    def _configure_logging(verbosity: int) -> None:
         """Configure Python logging

-        Configured loggers:
-            console_logger: For logging to the console only
-            file_logger: For logging to the log file only
-            root: For logging to the console and log file
+        Configured loggers include the following:
+          - console_logger: For logging to the console only
+          - file_logger: For logging to the log file only
+          - root: For logging to the console and log file
+
+        Args:
+            verbosity: Console verbosity, given as the number of ``-v`` flags passed on the command line
         """

+        verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG}
+        console_log_level = verbosity_to_log_level.get(verbosity, logging.DEBUG)
+
         logging.config.dictConfig({
             'version': 1,
             'disable_existing_loggers': True,
@@ -61,7 +96,7 @@ def _configure_logging(self) -> None:
                     'format': '%(levelname)8s: %(message)s'
                 },
                 'log_file_formatter': {
-                    'format': '%(levelname)8s | %(asctime)s | %(message)s'
+                    'format': '%(asctime)s | %(levelname)8s | %(message)s'
                 },
             },
             'handlers': {
@@ -69,13 +104,12 @@ def _configure_logging(self) -> None:
                     'class': 'logging.StreamHandler',
                     'stream': 'ext://sys.stdout',
                     'formatter': 'console_formatter',
-                    'level': self._settings.verbosity
+                    'level': console_log_level
                 },
                 'log_file_handler': {
-                    'class': 'logging.FileHandler',
+                    'class': 'logging.handlers.SysLogHandler',
                     'formatter': 'log_file_formatter',
-                    'level': self._settings.log_level,
-                    'filename': self._settings.log_path
+                    'level': 'DEBUG',
                 },
             },
             'loggers': {
@@ -85,50 +119,89 @@ def _configure_logging(self) -> None:
             }
         })

-    async def run(self) -> None:
-        """Terminate errant processes on all clusters/nodes configured in application settings."""
+    @staticmethod
+    async def scan(
+            clusters: Collection[str],
+            ignore_nodes: Collection[str],
+            uid_whitelist: Collection[Union[int, List[int]]],
+            max_concurrent: int,
+            ssh_timeout: int,
+            debug: bool
+    ) -> None:
+        """Terminate orphaned processes on the given clusters

-        if not self._settings.clusters:
-            logging.warning('No cluster names configured in application settings.')
+        Args:
+            clusters: Slurm cluster names
+            ignore_nodes: List of nodes to ignore
+            uid_whitelist: UID values to terminate orphaned processes for
+            max_concurrent: Maximum number of concurrent SSH connections
+            ssh_timeout: Timeout for SSH connections in seconds
+            debug: Log processes without terminating them
+        """

-        ssh_limit = asyncio.Semaphore(self._settings.max_concurrent)
-        for cluster in self._settings.clusters:
+        # Clusters are handled synchronously, nodes are handled asynchronously
+        for cluster in clusters:
             logging.info(f'Starting scan for nodes in cluster {cluster}')
+            nodes = utils.get_nodes(cluster, ignore_nodes)
+            await Application.terminate(nodes, uid_whitelist, max_concurrent, ssh_timeout, debug)
+
+    @staticmethod
+    async def terminate(
+            nodes: Collection[str],
+            uid_whitelist: Collection[Union[int, List[int]]],
+            max_concurrent: int,
+            ssh_timeout: int,
+            debug: bool
+    ) -> None:
+        """Terminate orphaned processes on the given nodes
+
+        Args:
+            nodes: DNS resolvable names of the nodes to terminate processes on
+            uid_whitelist: UID values to terminate orphaned processes for
+            max_concurrent: Maximum number of concurrent SSH connections
+            ssh_timeout: Timeout for SSH connections in seconds
+            debug: Log processes without terminating them
+        """

-            # Launch a concurrent job for each node in the cluster
-            nodes = utils.get_nodes(cluster, self._settings.ignore_nodes)
-            coroutines = [
-                utils.terminate_errant_processes(
-                    node=node,
-                    ssh_limit=ssh_limit,
-                    uid_blacklist=self._settings.uid_blacklist,
-                    timeout=self._settings.ssh_timeout,
-                    debug=self._settings.debug)
-                for node in nodes
-            ]
-
-            # Gather results from each concurrent run and check for errors
-            results = await asyncio.gather(*coroutines, return_exceptions=True)
-            for node, result in zip(nodes, results):
-                if isinstance(result, Exception):
-                    logging.error(f'Error with node {node}: {result}')
+        ssh_options = SSHClientConnectionOptions(connect_timeout=ssh_timeout)
+        ssh_limit = asyncio.Semaphore(max_concurrent)
+
+        # Launch a concurrent job for each node in the cluster
+        coroutines = [
+            utils.terminate_errant_processes(
+                node=node,
+                uid_whitelist=uid_whitelist,
+                ssh_limit=ssh_limit,
+                ssh_options=ssh_options,
+                debug=debug)
+            for node in nodes
+        ]
+
+        # Gather results from each concurrent run and check for errors
+        results = await asyncio.gather(*coroutines, return_exceptions=True)
+        for node, result in zip(nodes, results):
+            if isinstance(result, Exception):
+                logging.error(f'Error with node {node}: {result}')

     @classmethod
     def execute(cls, arg_list: List[str] = None) -> None:
-        """Parse command-line arguments and execute the application"""
+        """Parse command-line arguments and execute the application
+
+        Args:
+            arg_list: Optionally parse the given arguments instead of the command line
+        """

         args = Parser().parse_args(arg_list)
-        verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG}
-
-        # Load application settings - override defaults using parsed arguments
-        settings = Settings.load()
-        settings.verbosity = verbosity_to_log_level.get(args.verbosity, logging.DEBUG)
-        settings.debug = settings.debug or args.debug
+        cls._configure_logging(args.verbosity)

         try:
-            application = cls(settings)
-            asyncio.run(application.run())
+            # Extract the subset of arguments that are valid for the function ``args.callable``
+            valid_params = inspect.signature(args.callable).parameters
+            valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params}
+            asyncio.run(args.callable(**valid_arguments))
+
+        except KeyboardInterrupt:
+            pass

-        except Exception as caught:  # pragma: no cover
+        except Exception as caught:
             logging.getLogger('file_logger').critical('Application crash', exc_info=caught)
             logging.getLogger('console_logger').critical(str(caught))
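The `execute` method dispatches to whichever handler the subparser attached via `set_defaults(callable=...)`, forwarding only the parsed arguments that appear in that handler's signature. The filtering pattern is generic; here is a minimal sketch with a hypothetical handler:

```python
# Filter a parsed-argument namespace down to the parameters a callable accepts,
# mirroring the dispatch logic in Application.execute. `handler` is hypothetical.
import inspect
from argparse import Namespace

def handler(clusters, debug):
    print(clusters, debug)

args = Namespace(clusters=['development'], debug=True, verbosity=2)  # verbosity is extra
valid_params = inspect.signature(handler).parameters
valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params}
handler(**valid_arguments)  # prints: ['development'] True
```

One caveat on the logging change: `logging.handlers.SysLogHandler` defaults to sending records over UDP to `localhost:514` when no `address` is given, and records are silently dropped if nothing is listening there. On a typical Linux host the local syslog socket would need to be named explicitly; a sketch of that variant, where the `/dev/log` path is an assumption about the deployment host:

```python
import logging
import logging.config

logging.config.dictConfig({
    'version': 1,
    'handlers': {
        'syslog_handler': {
            'class': 'logging.handlers.SysLogHandler',
            'address': '/dev/log',  # assumes a Linux host exposing the standard syslog socket
            'level': 'DEBUG',
        },
    },
    'root': {'handlers': ['syslog_handler'], 'level': 'DEBUG'},
})
logging.getLogger().info('routed to syslog')
```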
diff --git a/shinigami/utils.py b/shinigami/utils.py
index 0c2943c..dd15b29 100755
--- a/shinigami/utils.py
+++ b/shinigami/utils.py
@@ -5,24 +5,26 @@
 from io import StringIO
 from shlex import split
 from subprocess import Popen, PIPE
-from typing import Union, Tuple, Collection
+from typing import Union, Tuple, Collection, List

 import asyncssh
 import pandas as pd

+INIT_PROCESS_ID = 1

-def id_in_blacklist(id_value: int, blacklist: Collection[Union[int, Tuple[int, int]]]) -> bool:
+
+def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, int]]]) -> bool:
     """Return whether an ID is in a list of ID values

     Args:
         id_value: The ID value to check
-        blacklist: A collection of ID values and ID ranges
+        whitelist: A collection of ID values and ID ranges

     Returns:
-        Whether the ID is in the blacklist
+        Whether the ID is in the whitelist
     """

-    for id_def in blacklist:
+    for id_def in whitelist:
         if hasattr(id_def, '__getitem__') and (id_def[0] <= id_value <= id_def[1]):
             return True
@@ -46,7 +48,6 @@ def get_nodes(cluster: str, ignore_substring: Collection[str]) -> set:
     logging.debug(f'Fetching node list for cluster {cluster}')
     sub_proc = Popen(split(f"sinfo -M {cluster} -N -o %N -h"), stdout=PIPE, stderr=PIPE)
     stdout, stderr = sub_proc.communicate()
-
     if stderr:
         raise RuntimeError(stderr)

@@ -57,41 +58,41 @@
 async def terminate_errant_processes(
         node: str,
-        ssh_limit: asyncio.Semaphore,
-        uid_blacklist,
-        timeout: int = 120,
+        uid_whitelist: Collection[Union[int, List[int]]],
+        ssh_limit: asyncio.Semaphore = asyncio.Semaphore(1),
+        ssh_options: asyncssh.SSHClientConnectionOptions = None,
         debug: bool = False
 ) -> None:
     """Terminate non-Slurm processes on a given node

     Args:
         node: The DNS resolvable name of the node to terminate processes on
+        uid_whitelist: Only terminate processes owned by the given UIDs
         ssh_limit: Semaphore object used to limit concurrent SSH connections
-        uid_blacklist: Do not terminate processes owned by the given UID
-        timeout: Maximum time in seconds to complete an outbound SSH connection
+        ssh_options: Options for configuring the outbound SSH connection
         debug: Log which process to terminate but do not terminate them
     """

-    # Define SSH connection settings
-    ssh_options = asyncssh.SSHClientConnectionOptions(connect_timeout=timeout)
-
-    logging.debug(f'Waiting to connect to {node}')
+    logging.debug(f'[{node}] Waiting for SSH pool')
     async with ssh_limit, asyncssh.connect(node, options=ssh_options) as conn:
+        logging.info(f'[{node}] Scanning for processes')

         # Fetch running process data from the remote machine
-        logging.info(f'[{node}] Scanning for processes')
-        ps_data = await conn.run('ps -eo pid,pgid,uid', check=True)
-        process_df = pd.read_fwf(StringIO(ps_data.stdout))
+        # Add 1 to column widths when parsing ps output to account for the space between columns
+        ps_return = await conn.run('ps -eo pid:10,ppid:10,pgid:10,uid:10,cmd:500', check=True)
+        process_df = pd.read_fwf(StringIO(ps_return.stdout), widths=[11, 11, 11, 11, 500])
+
+        # Identify orphaned processes and filter them by the UID whitelist
+        orphaned = process_df[process_df.PPID == INIT_PROCESS_ID]
+        terminate = orphaned[orphaned['UID'].apply(id_in_whitelist, whitelist=uid_whitelist)]

-        # Identify orphaned processes and filter them by the UID blacklist
-        orphaned = process_df[process_df.PPID == 1]
-        terminate = orphaned[orphaned['UID'].apply(id_in_blacklist, blacklist=uid_blacklist)]
         for _, row in terminate.iterrows():
-            logging.debug(f'[{node}] Marking for termination {dict(row)}')
+            logging.info(f'[{node}] Marking for termination {dict(row)}')

-        if debug:
-            return
+        if terminate.empty:
+            logging.info(f'[{node}] No orphaned processes found')

-        proc_id_str = ' '.join(terminate.PGID)
-        logging.info(f"[{node}] Sending termination signal for process groups {proc_id_str}")
-        await conn.run(f"pkill --signal -9 --pgroup {proc_id_str}", check=True)
+        elif not debug:
+            proc_id_str = ','.join(terminate.PGID.unique().astype(str))
+            logging.info(f"[{node}] Sending termination signal for process groups {proc_id_str}")
+            await conn.run(f"pkill --signal 9 --pgroup {proc_id_str}", check=True)
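The `pid:10` style format specifiers pin each `ps` column to a fixed width, which is what lets `pandas.read_fwf` split the output reliably even when commands contain spaces. A standalone sketch of the same parsing step, using a hand-written sample in place of real SSH output:

```python
# Parse fixed-width `ps` output the same way terminate_errant_processes does.
# The sample text stands in for the stdout of:
#   ps -eo pid:10,ppid:10,pgid:10,uid:10,cmd:500
from io import StringIO

import pandas as pd

sample = (
    '       PID       PPID       PGID        UID CMD\n'
    '         1          0          1          0 /sbin/init\n'
    '      4242          1       4242       1001 python long_job.py --input data.csv\n'
)

# Column widths are the ps field widths plus 1 for the separating space
process_df = pd.read_fwf(StringIO(sample), widths=[11, 11, 11, 11, 500])
orphaned = process_df[process_df.PPID == 1]  # processes reparented to init
print(orphaned[['PID', 'PGID', 'UID']])
```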
diff --git a/tests/cli/test_parser.py b/tests/cli/test_parser.py
index c2499f1..c3e4243 100644
--- a/tests/cli/test_parser.py
+++ b/tests/cli/test_parser.py
@@ -5,31 +5,145 @@
 from shinigami.cli import Parser


-class DebugOption(TestCase):
-    """Test the behavior of the ``debug`` option"""
+class ScanParser(TestCase):
+    """Test the behavior of the ``scan`` subparser"""

-    def test_default_is_false(self) -> None:
-        """Test the ``debug`` argument defaults to ``False``"""
+    def test_debug_option(self) -> None:
+        """Test the ``debug`` argument"""

-        args = Parser().parse_args([])
-        self.assertFalse(args.debug)
+        parser = Parser()
+
+        scan_command = ['scan', '-c', 'development', '-u', '100']
+        self.assertFalse(parser.parse_args(scan_command).debug)
+
+        scan_command_debug = ['scan', '-c', 'development', '-u', '100', '--debug']
+        self.assertTrue(parser.parse_args(scan_command_debug).debug)
+
+    def test_verbose_arg(self) -> None:
+        """Test the verbosity argument counts the number of provided flags"""
+
+        parser = Parser()
+        base_command = ['scan', '-c', 'development', '-u', '100']
+
+        self.assertEqual(0, parser.parse_args(base_command).verbosity)
+        self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity)
+        self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity)
+        self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity)
+        self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity)
+
+    def test_clusters_arg(self) -> None:
+        """Test parsing of the ``clusters`` argument"""
+
+        parser = Parser()
+
+        single_cluster_out = ['development']
+        single_cluster_cmd = ['scan', '-c', *single_cluster_out, '-u', '100']
+        self.assertSequenceEqual(single_cluster_out, parser.parse_args(single_cluster_cmd).clusters)
+
+        multi_cluster_out = ['dev1', 'dev2', 'dev3']
+        multi_cluster_cmd = ['scan', '-c', *multi_cluster_out, '-u', '100']
+        self.assertSequenceEqual(multi_cluster_out, parser.parse_args(multi_cluster_cmd).clusters)
+
+    def test_ignore_nodes(self) -> None:
+        """Test parsing of the ``ignore-nodes`` argument"""
+
+        parser = Parser()
+        base_command = ['scan', '-c', 'development', '-u', '100']
+
+        single_node_out = ['node1']
+        single_node_cmd = base_command + ['-i', 'node1']
+        self.assertSequenceEqual(single_node_out, parser.parse_args(single_node_cmd).ignore_nodes)
+
+        multi_node_out = ['node1', 'node2']
+        multi_node_cmd = base_command + ['-i', 'node1', 'node2']
+        self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).ignore_nodes)
+    def test_uid_whitelist_arg(self) -> None:
+        """Test parsing of the ``uid-whitelist`` argument"""
+
+        parser = Parser()

-    def test_enabled_is_true(self) -> None:
-        """Test the ``debug`` flag stores a ``True`` value"""
+        # Test for a single integer
+        single_int_command = 'scan -c development -u 100'.split()
+        single_int_out = [100]
+        self.assertSequenceEqual(single_int_out, parser.parse_args(single_int_command).uid_whitelist)

-        args = Parser().parse_args(['--debug'])
-        self.assertTrue(args.debug)
+        # Test for multiple integers
+        multi_int_command = 'scan -c development -u 100 200'.split()
+        multi_int_out = [100, 200]
+        self.assertSequenceEqual(multi_int_out, parser.parse_args(multi_int_command).uid_whitelist)

+        # Test for a list type
+        single_list_command = 'scan -c development -u [100,200]'.split()
+        single_list_out = [[100, 200]]
+        self.assertSequenceEqual(single_list_out, parser.parse_args(single_list_command).uid_whitelist)

-class VerboseOption(TestCase):
-    """Test the verbosity flag"""
+        # Test for a mix of types
+        mixed_command = 'scan -c development -u 100 [200,300] 400 [500,600]'.split()
+        mixed_out = [100, [200, 300], 400, [500, 600]]
+        self.assertSequenceEqual(mixed_out, parser.parse_args(mixed_command).uid_whitelist)

-    def test_counts_instances(self) -> None:
-        """Test the parser counts the number of provided flags"""
+
+class TerminateParser(TestCase):
+    """Test the behavior of the ``terminate`` subparser"""
+
+    def test_debug_option(self) -> None:
+        """Test the ``debug`` argument"""
+
+        parser = Parser()
+
+        terminate_command = ['terminate', '-n', 'node1', '-u', '100']
+        self.assertFalse(parser.parse_args(terminate_command).debug)
+
+        terminate_command_debug = ['terminate', '-n', 'node1', '-u', '100', '--debug']
+        self.assertTrue(parser.parse_args(terminate_command_debug).debug)
+
+    def test_verbose_arg(self) -> None:
+        """Test the verbosity argument counts the number of provided flags"""
+
+        parser = Parser()
+        base_command = ['terminate', '-n', 'node', '-u', '100']
+
+        self.assertEqual(0, parser.parse_args(base_command).verbosity)
+        self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity)
+        self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity)
+        self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity)
+        self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity)
+
+    def test_nodes_arg(self) -> None:
+        """Test parsing of the ``nodes`` argument"""
+
+        parser = Parser()
+
+        single_node_out = ['development']
+        single_node_cmd = ['terminate', '-n', *single_node_out, '-u', '100']
+        self.assertSequenceEqual(single_node_out, parser.parse_args(single_node_cmd).nodes)
+
+        multi_node_out = ['dev1', 'dev2', 'dev3']
+        multi_node_cmd = ['terminate', '-n', *multi_node_out, '-u', '100']
+        self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).nodes)
+
+    def test_uid_whitelist_arg(self) -> None:
+        """Test parsing of the ``uid-whitelist`` argument"""

         parser = Parser()
-        self.assertEqual(0, parser.parse_args([]).verbosity)
-        self.assertEqual(1, parser.parse_args(['-v']).verbosity)
-        self.assertEqual(2, parser.parse_args(['-vv']).verbosity)
-        self.assertEqual(3, parser.parse_args(['-vvv']).verbosity)
-        self.assertEqual(5, parser.parse_args(['-vvvvv']).verbosity)
+
+        # Test for a single integer
+        single_int_command = 'terminate -n node -u 100'.split()
+        single_int_out = [100]
+        self.assertSequenceEqual(single_int_out, parser.parse_args(single_int_command).uid_whitelist)
+
+        # Test for multiple integers
+        multi_int_command = 'terminate -n node -u 100 200'.split()
+        multi_int_out = [100, 200]
+        self.assertSequenceEqual(multi_int_out, parser.parse_args(multi_int_command).uid_whitelist)
+
+        # Test for a list type
+        single_list_command = 'terminate -n node -u [100,200]'.split()
+        single_list_out = [[100, 200]]
+        self.assertSequenceEqual(single_list_out, parser.parse_args(single_list_command).uid_whitelist)
+
+        # Test for a mix of types
+        mixed_command = 'terminate -n node -u 100 [200,300] 400 [500,600]'.split()
+        mixed_out = [100, [200, 300], 400, [500, 600]]
+        self.assertSequenceEqual(mixed_out, parser.parse_args(mixed_command).uid_whitelist)
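The mixed `-u` formats exercised above all come from the `type=loads` converter on the parser: argparse hands each whitespace-separated token to `json.loads`, so bare integers and bracketed ranges can be freely combined. A one-line sketch of what the converter produces per token:

```python
# argparse applies json.loads to each token of `-u 100 [200,300]`
from json import loads

print([loads(token) for token in ['100', '[200,300]']])  # [100, [200, 300]]
```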
diff --git a/tests/utils/test_id_in_whitelist.py b/tests/utils/test_id_in_whitelist.py
index 4d26478..b762d6a 100644
--- a/tests/utils/test_id_in_whitelist.py
+++ b/tests/utils/test_id_in_whitelist.py
@@ -2,7 +2,7 @@

 from unittest import TestCase

-from shinigami.utils import id_in_blacklist
+from shinigami.utils import id_in_whitelist


 class Whitelisting(TestCase):
@@ -11,19 +11,19 @@ class Whitelisting(TestCase):
     def test_empty_whitelist(self) -> None:
         """Test the return value is ``False`` for all ID values when the whitelist is empty"""

-        self.assertFalse(id_in_blacklist(0, []))
-        self.assertFalse(id_in_blacklist(123, []))
+        self.assertFalse(id_in_whitelist(0, []))
+        self.assertFalse(id_in_whitelist(123, []))

     def test_whitelisted_by_id(self) -> None:
         """Test return values for a whitelist of explicit ID values"""

         whitelist = (123, 456, 789)
-        self.assertTrue(id_in_blacklist(456, whitelist))
-        self.assertFalse(id_in_blacklist(0, whitelist))
+        self.assertTrue(id_in_whitelist(456, whitelist))
+        self.assertFalse(id_in_whitelist(0, whitelist))

     def test_whitelisted_by_id_range(self) -> None:
         """Test return values for a whitelist of ID ranges"""

         whitelist = (0, 1, 2, (100, 300))
-        self.assertTrue(id_in_blacklist(123, whitelist))
-        self.assertFalse(id_in_blacklist(301, whitelist))
+        self.assertTrue(id_in_whitelist(123, whitelist))
+        self.assertFalse(id_in_whitelist(301, whitelist))
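For reference, the range handling these tests exercise relies only on indexing (`id_def[0]`/`id_def[1]`), so the two-element JSON lists produced by the command line behave the same as the tuples used above. A short usage sketch:

```python
# Scalar entries match exactly; any indexable pair is an inclusive (min, max) range,
# so lists parsed from the CLI work just like the tuples in the tests above.
from shinigami.utils import id_in_whitelist

whitelist = [1001, [2000, 2999]]  # same shape the CLI produces for `-u 1001 [2000,2999]`
print(id_in_whitelist(1001, whitelist))  # True
print(id_in_whitelist(2500, whitelist))  # True
print(id_in_whitelist(3000, whitelist))  # False
```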