Skip to content

Commit

Permalink
Improves CLI help text with full length descriptions (#96)
Browse files Browse the repository at this point in the history
* Simplify CLI options and help text

* Adds descriptions to subparsers
  • Loading branch information
djperrefort authored Dec 20, 2023
1 parent c33a749 commit 0310e0c
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 93 deletions.
104 changes: 65 additions & 39 deletions shinigami/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import logging.config
import sys
from argparse import ArgumentParser
from argparse import ArgumentParser, RawTextHelpFormatter
from json import loads
from typing import List, Collection, Union

Expand All @@ -14,57 +14,72 @@
from . import __version__, utils


class BaseParser(ArgumentParser):
"""Custom argument parser that prints help text on error"""

def error(self, message: str) -> None:
"""Prints a usage message and exit
Args:
message: The usage message
"""

if len(sys.argv) == 1:
self.print_help()

raise SystemExit(message)


class Parser(BaseParser):
class Parser(ArgumentParser):
"""Defines the command-line interface and parses command-line arguments"""

def __init__(self) -> None:
"""Define the command-line interface"""

# Configure the top level parser
super().__init__(prog='shinigami', description='Scan Slurm compute nodes and terminate orphan processes.')
subparsers = self.add_subparsers(required=True, parser_class=BaseParser)
subparsers = self.add_subparsers(required=True, parser_class=ArgumentParser)
self.add_argument('--version', action='version', version=__version__)

# This parser defines reusable arguments and is not exposed to the user
# The `common` parser holds reusable argument definitions
common = ArgumentParser(add_help=False)

ssh_group = common.add_argument_group('ssh options')
ssh_group.add_argument('-m', '--max-concurrent', type=int, default=1, help='maximum concurrent SSH connections')
ssh_group.add_argument('-t', '--ssh-timeout', type=int, default=120, help='SSH connection timeout in seconds')
ssh_group.add_argument('-m', dest='max_concurrent', type=int, default=1, help='maximum concurrent SSH connections (Default: 1)')
ssh_group.add_argument('-t', dest='ssh_timeout', type=int, default=120, help='SSH connection timeout in seconds (Default: 120)')

debug_group = common.add_argument_group('debugging options')
debug_group.add_argument('--debug', action='store_true', help='run the application in debug mode')
debug_group.add_argument('-v', action='count', dest='verbosity', default=0,
help='set verbosity to warning (-v), info (-vv), or debug (-vvv)')
debug_group.add_argument('-v', action='count', dest='verbosity', default=0, help='set verbosity to warning (-v), info (-vv), or debug (-vvv)')

# Subparser for the `Application.scan` method
scan = subparsers.add_parser('scan', parents=[common], help='terminate processes on one or more clusters')
scan = subparsers.add_parser(
'scan', parents=[common], formatter_class=RawTextHelpFormatter,
help='terminate processes on one or more clusters',
description=(
"The `scan` function automatically terminates orphaned processes on all compute nodes in a Slurm cluster.\n"
"It is provided as a shorthand alternative to calling the `terminate` command with manually defined node names.\n\n"
"Slurm nodes are identified using the slurm installation on the current machine.\n"
"User IDs can be specified individually (e.g. `-u 1000 1001 1002 1003`) or as ranges (e.g. `-u 1000 [1001,1003]`)."))

scan.set_defaults(callable=Application.scan)
scan.add_argument('-c', '--clusters', nargs='+', required=True, help='cluster names to scan')
scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore the given nodes')
scan.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan')
scan_group = scan.add_argument_group('scanning options')
scan_group.add_argument('-c', dest='clusters', metavar='CLUS', nargs='+', required=True, help='Slurm cluster name(s) to scan')
scan_group.add_argument('-i', dest='ignore_nodes', metavar='NODE', nargs='*', default=[], help='ignore the given node(s)')
scan_group.add_argument('-u', dest='uid_whitelist', metavar='UID', nargs='+', type=loads, default=[0], help='only terminate processes owned by the given user IDs')

# Subparser for the `Application.terminate` method
terminate = subparsers.add_parser('terminate', parents=[common], help='terminate processes on a single node')
terminate = subparsers.add_parser(
'terminate', parents=[common], formatter_class=RawTextHelpFormatter,
help='terminate processes on one or more compute nodes',
description=(
"Automatically terminate orphaned processes on one or more Slurm compute nodes.\n\n"
"Processes are only terminated under the following conditions:\n"
f" 1. The process belongs to a process tree parented by init (PID {utils.INIT_PROCESS_ID})\n"
" 2. The associated user ID is in the given UID whitelist\n"
" 3. The user is not running any Slurm jobs on the parent machine\n\n"
"User IDs can be specified individually (e.g. `-u 1000 1001 1002 1003`) or as ranges (e.g. `-u 1000 [1001,1003]`)."))

terminate.set_defaults(callable=Application.terminate)
terminate.add_argument('-n', '--nodes', nargs='+', required=True, help='the DNS name of the node to terminate')
terminate.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan')
terminate_group = terminate.add_argument_group('termination options')
terminate_group.add_argument('-n', dest='nodes', metavar='NODE', nargs='+', required=True, help='the DNS name(s) of the node(s) to terminate')
terminate_group.add_argument('-u', dest='uid_whitelist', metavar='UID', nargs='+', type=loads, default=[0], help='only terminate processes owned by the given user IDs')

def error(self, message: str) -> None:
"""Print a usage message and exit the application
Args:
message: The usage message
"""

if len(sys.argv) == 1:
self.print_help()
raise SystemExit

super().error(message)


class Application:
Expand All @@ -79,12 +94,23 @@ def _configure_logging(verbosity: int) -> None:
- file_logger: For logging to the log file only
- root: For logging to the console and log file
Console verbosity levels are defined as follows:
- 0: ERROR
- 1: WARNING
- 2: INFO
- 3: DEBUG
- Any other value: DEBUG
Args:
verbosity: The console verbosity defined as the count passed to the commandline
verbosity: The console verbosity level
"""

verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG}
console_log_level = verbosity_to_log_level.get(verbosity, logging.DEBUG)
console_log_level = {
0: logging.ERROR,
1: logging.WARNING,
2: logging.INFO,
3: logging.DEBUG
}.get(verbosity, logging.DEBUG)

logging.config.dictConfig({
'version': 1,
Expand Down Expand Up @@ -122,7 +148,7 @@ async def scan(
clusters: Collection[str],
ignore_nodes: Collection[str],
uid_whitelist: Collection[Union[int, List[int]]],
max_concurrent: asyncio.Semaphore,
max_concurrent: int,
ssh_timeout: int,
debug: bool
) -> None:
Expand All @@ -147,7 +173,7 @@ async def scan(
async def terminate(
nodes: Collection[str],
uid_whitelist: Collection[Union[int, List[int]]],
max_concurrent: asyncio.Semaphore,
max_concurrent: int,
ssh_timeout: int,
debug: bool
) -> None:
Expand Down Expand Up @@ -182,7 +208,7 @@ async def terminate(

@classmethod
def execute(cls, arg_list: List[str] = None) -> None:
"""Parse command-line arguments and execute the application
"""Parse command line arguments and execute the application
Args:
arg_list: Optionally parse the given arguments instead of the command line
Expand All @@ -192,7 +218,7 @@ def execute(cls, arg_list: List[str] = None) -> None:
cls._configure_logging(args.verbosity)

try:
# Extract the subset of arguments that are valid for the function `args.callable`
# Extract the subset of arguments that are valid for the `args.callable` function
valid_params = inspect.signature(args.callable).parameters
valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params}
asyncio.run(args.callable(**valid_arguments))
Expand Down
44 changes: 12 additions & 32 deletions shinigami/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,6 @@
# architecture, but 1 is an almost universal default
INIT_PROCESS_ID = 1

# Custom type hints
Whitelist = Collection[Union[int, Tuple[int, int]]]


def _id_in_whitelist(id_value: int, whitelist: Whitelist) -> bool:
"""Return whether an ID is in a list of ID value definitions
The `whitelist` of ID values can contain a mix of integers and tuples
of integer ranges. For example, [0, 1, (2, 9), 10] includes all IDs from
zero through ten.
Args:
id_value: The ID value to check
whitelist: A collection of ID values and ID ranges
Returns:
Whether the ID is in the whitelist
"""

for id_def in whitelist:
if hasattr(id_def, '__getitem__') and (id_def[0] <= id_value <= id_def[1]):
return True

elif id_value == id_def:
return True

return False


def get_nodes(cluster: str, ignore_nodes: Collection[str] = tuple()) -> set:
"""Return a set of nodes included in a given Slurm cluster
Expand Down Expand Up @@ -100,7 +72,7 @@ def include_orphaned_processes(df: pd.DataFrame) -> pd.DataFrame:
return df[df['PPID'] == INIT_PROCESS_ID]


def include_user_whitelist(df: pd.DataFrame, uid_whitelist: Collection[Union[int, Tuple[int, int]]]) -> pd.DataFrame:
    """Filter a DataFrame to only include a subset of user IDs

    Given a DataFrame with system process data, return the subset of rows
    whose `UID` value is whitelisted. Whitelist entries are either a single
    integer ID or a two-element `(min, max)` range. Ranges are inclusive on
    both ends, so `[0, 1, (2, 9), 10]` covers every ID from zero through ten.

    Args:
        df: A DataFrame with a `UID` column
        uid_whitelist: Individual user IDs and/or inclusive ID ranges

    Returns:
        A new DataFrame holding only the rows with a whitelisted `UID`
    """

    whitelisted_uid_values = set()
    for elt in uid_whitelist:
        if isinstance(elt, int):
            whitelisted_uid_values.add(elt)

        else:
            # `range` excludes its stop value, so add one to keep the upper
            # bound of the user-supplied range inside the whitelist
            umin, umax = elt
            whitelisted_uid_values.update(range(umin, umax + 1))

    return df[df['UID'].isin(whitelisted_uid_values)]


def exclude_active_slurm_users(df: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -162,7 +142,7 @@ async def terminate_errant_processes(
logging.info(f'[{node}] Scanning for processes')
process_df = await get_remote_processes(conn)

# Filter them by various whitelist/blacklist criteria
# Filter process data by various whitelist/blacklist criteria
process_df = include_orphaned_processes(process_df)
process_df = include_user_whitelist(process_df, uid_whitelist)
process_df = exclude_active_slurm_users(process_df)
Expand Down
22 changes: 0 additions & 22 deletions tests/cli/test_base_parser.py

This file was deleted.

10 changes: 10 additions & 0 deletions tests/cli/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
from shinigami.cli import Parser


class ErrorHandling(TestCase):
"""Test custom error handling logic encapsulated by the `Parser` class"""

def test_errors_raise_system_exit(self) -> None:
"""Test error messages are raised as `SystemExit` instances"""

with self.assertRaises(SystemExit):
Parser().error("This is an error message")


class ScanSubParser(TestCase):
"""Test the behavior of the `scan` subparser"""

Expand Down

0 comments on commit 0310e0c

Please sign in to comment.