diff --git a/shinigami/cli.py b/shinigami/cli.py index 6295057..ab6437c 100644 --- a/shinigami/cli.py +++ b/shinigami/cli.py @@ -26,10 +26,8 @@ def error(self, message: str) -> None: if len(sys.argv) == 1: self.print_help() - super().exit(1) - else: - super().error(message) + raise SystemExit(message) class Parser(BaseParser): diff --git a/shinigami/utils.py b/shinigami/utils.py index afac877..32f7ac6 100755 --- a/shinigami/utils.py +++ b/shinigami/utils.py @@ -87,15 +87,16 @@ async def terminate_errant_processes( # Identify orphaned processes and filter them by the UID whitelist orphaned = process_df[process_df.PPID == INIT_PROCESS_ID] - terminate = orphaned[orphaned['UID'].apply(id_in_whitelist, whitelist=uid_whitelist)] + whitelist_index = orphaned['UID'].apply(id_in_whitelist, whitelist=uid_whitelist) + to_terminate = orphaned[whitelist_index] - for _, row in terminate.iterrows(): + for _, row in to_terminate.iterrows(): logging.info(f'[{node}] Marking for termination {dict(row)}') - if terminate.empty: + if to_terminate.empty: logging.info(f'[{node}] no processes found') elif not debug: - proc_id_str = ','.join(terminate.PGID.unique().astype(str)) + proc_id_str = ','.join(to_terminate.PGID.unique().astype(str)) logging.info(f"[{node}] Sending termination signal for process groups {proc_id_str}") await conn.run(f"pkill --signal 9 --pgroup {proc_id_str}", check=True) diff --git a/tests/cli/test_parser.py b/tests/cli/test_parser.py index c3e4243..39a2f83 100644 --- a/tests/cli/test_parser.py +++ b/tests/cli/test_parser.py @@ -2,14 +2,26 @@ from unittest import TestCase -from shinigami.cli import Parser +from shinigami.cli import Parser, BaseParser -class ScanParser(TestCase): +class BaseParsing(TestCase): + """Test custom parsing login encapsulated by the `BaseParser` class""" + + def test_error_handling(self) -> None: + """Test error messages are raised as `SystemExit` instances""" + + parser = BaseParser() + error_message = "This is an error message" + with self.assertRaises(SystemExit, msg=error_message): + parser.error(error_message) + + +class ScanSubParser(TestCase): """Test the behavior of the ``scan`` subparser""" def test_debug_option(self) -> None: - """Test the ``debug`` argument""" + """Test parsing of the ``debug`` argument""" parser = Parser() @@ -84,8 +96,8 @@ def test_uid_whitelist_arg(self) -> None: self.assertSequenceEqual(mixed_out, parser.parse_args(mixed_command).uid_whitelist) -class TerminateParser(TestCase): - """Test the behavior of the ``terminate`` subparser""" +class TerminateSubParser(TestCase): + """Test parsing of the behavior of the ``terminate`` subparser""" def test_debug_option(self) -> None: """Test the ``debug`` argument"""