Skip to content

Commit

Permalink
Replace settings file with commandline arguments (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
djperrefort authored Oct 17, 2023
1 parent 4567201 commit 7b5291d
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 304 deletions.
206 changes: 138 additions & 68 deletions shinigami/cli.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,93 @@
"""The application command-line interface."""
"""The executable application and its command-line interface."""

import asyncio
import inspect
import logging
import logging.config
import logging.handlers
from argparse import RawTextHelpFormatter, ArgumentParser
from typing import List
import sys
from argparse import ArgumentParser
from json import loads
from typing import List, Collection, Union

from asyncssh import SSHClientConnectionOptions

from . import __version__, utils
from .settings import Settings, SETTINGS_PATH


class Parser(ArgumentParser):
class BaseParser(ArgumentParser):
"""Custom argument parser that prints help text on error"""

def error(self, message: str) -> None:
"""Prints a usage message and exit
Args:
message: The usage message
"""

if len(sys.argv) == 1:
self.print_help()
super().exit(1)

else:
super().error(message)


class Parser(BaseParser):
"""Defines the command-line interface and parses command-line arguments"""

def __init__(self) -> None:
"""Define the command-line interface"""

super().__init__(
prog='shinigami',
formatter_class=RawTextHelpFormatter, # Allow newlines in description text
description=(
'Scan Slurm compute nodes and terminate errant processes.\n\n'
f'See {SETTINGS_PATH} for the current application settings.'
))

# Configure the top level parser
super().__init__(prog='shinigami', description='Scan Slurm compute nodes and terminate orphan processes.')
subparsers = self.add_subparsers(required=True, parser_class=BaseParser)
self.add_argument('--version', action='version', version=__version__)
self.add_argument('--debug', action='store_true', help='force the application to run in debug mode')
self.add_argument('-v', action='count', dest='verbosity', default=0,
help='set output verbosity to warning (-v), info (-vv), or debug (-vvv)')

# This parser defines reusable arguments and is not exposed to the user
common = ArgumentParser(add_help=False)

class Application:
"""Entry point for instantiating and executing the application"""
ssh_group = common.add_argument_group('ssh options')
ssh_group.add_argument('-m', '--max-concurrent', type=int, default=1, help='maximum concurrent SSH connections')
ssh_group.add_argument('-t', '--ssh-timeout', type=int, default=120, help='SSH connection timeout in seconds')

def __init__(self, settings: Settings) -> None:
"""Instantiate a new instance of the application
debug_group = common.add_argument_group('debugging options')
debug_group.add_argument('--debug', action='store_true', help='run the application in debug mode')
debug_group.add_argument('-v', action='count', dest='verbosity', default=0,
help='set verbosity to warning (-v), info (-vv), or debug (-vvv)')

# Subparser for the `Application.scan` method
scan = subparsers.add_parser('scan', parents=[common], help='terminate processes on one or more clusters')
scan.set_defaults(callable=Application.scan)
scan.add_argument('-c', '--clusters', nargs='+', required=True, help='cluster names to scan')
scan.add_argument('-i', '--ignore-nodes', nargs='*', default=[], help='ignore given nodes')
scan.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan')

# Subparser for the `Application.terminate` method
terminate = subparsers.add_parser('terminate', parents=[common], help='terminate processes on a single node')
terminate.set_defaults(callable=Application.terminate)
terminate.add_argument('-n', '--nodes', nargs='+', required=True, help='the DNS name of the node to terminate')
terminate.add_argument('-u', '--uid-whitelist', nargs='+', type=loads, default=[0], help='user IDs to scan')

Args:
settings: Settings to use when configuring and executing the application
"""

self._settings = settings
self._configure_logging()
class Application:
"""Entry point for instantiating and executing the application"""

def _configure_logging(self) -> None:
@staticmethod
def _configure_logging(verbosity: int) -> None:
"""Configure Python logging
Configured loggers:
console_logger: For logging to the console only
file_logger: For logging to the log file only
root: For logging to the console and log file
Configured loggers include the following:
- console_logger: For logging to the console only
- file_logger: For logging to the log file only
- root: For logging to the console and log file
Args:
verbosity: The console verbosity defined as the count passed to the commandline
"""

verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG}
console_log_level = verbosity_to_log_level.get(verbosity, logging.DEBUG)

logging.config.dictConfig({
'version': 1,
'disable_existing_loggers': True,
Expand All @@ -61,21 +96,20 @@ def _configure_logging(self) -> None:
'format': '%(levelname)8s: %(message)s'
},
'log_file_formatter': {
'format': '%(levelname)8s | %(asctime)s | %(message)s'
'format': '%(asctime)s | %(levelname)8s | %(message)s'
},
},
'handlers': {
'console_handler': {
'class': 'logging.StreamHandler',
'stream': 'ext://sys.stdout',
'formatter': 'console_formatter',
'level': self._settings.verbosity
'level': console_log_level
},
'log_file_handler': {
'class': 'logging.FileHandler',
'class': 'logging.handlers.SysLogHandler',
'formatter': 'log_file_formatter',
'level': self._settings.log_level,
'filename': self._settings.log_path
'level': 'DEBUG',
},
},
'loggers': {
Expand All @@ -85,49 +119,85 @@ def _configure_logging(self) -> None:
}
})

async def run(self) -> None:
"""Terminate errant processes on all clusters/nodes configured in application settings."""
@staticmethod
async def scan(
clusters: Collection[str],
ignore_nodes: Collection[str],
uid_whitelist: Collection[Union[int, List[int]]],
max_concurrent: asyncio.Semaphore,
ssh_timeout: int,
debug: bool
) -> None:
"""Terminate orphaned processes on all clusters/nodes configured in application settings.
if not self._settings.clusters:
logging.warning('No cluster names configured in application settings.')
Args:
clusters: Slurm cluster names
ignore_nodes: List of nodes to ignore
uid_whitelist: UID values to terminate orphaned processes for
max_concurrent: Maximum number of concurrent ssh connections
ssh_timeout: Timeout for SSH connections
debug: Optionally log but do not terminate processes
"""

ssh_limit = asyncio.Semaphore(self._settings.max_concurrent)
for cluster in self._settings.clusters:
# Clusters are handled synchronously, nodes are handled asynchronously
for cluster in clusters:
logging.info(f'Starting scan for nodes in cluster {cluster}')
nodes = utils.get_nodes(cluster, ignore_nodes)
await Application.terminate(nodes, uid_whitelist, max_concurrent, ssh_timeout, debug)

@staticmethod
async def terminate(
nodes: Collection[str],
uid_whitelist: Collection[Union[int, List[int]]],
max_concurrent: asyncio.Semaphore,
ssh_timeout: int,
debug: bool
) -> None:
"""Terminate processes on a given node
# Launch a concurrent job for each node in the cluster
nodes = utils.get_nodes(cluster, self._settings.ignore_nodes)
coroutines = [
utils.terminate_errant_processes(
node=node,
ssh_limit=ssh_limit,
uid_whitelist=self._settings.uid_whitelist,
timeout=self._settings.ssh_timeout,
debug=self._settings.debug)
for node in nodes
]

# Gather results from each concurrent run and check for errors
results = await asyncio.gather(*coroutines, return_exceptions=True)
for node, result in zip(nodes, results):
if isinstance(result, Exception):
logging.error(f'Error with node {node}: {result}')
Args:
nodes:
uid_whitelist: UID values to terminate orphaned processes for
max_concurrent: Maximum number of concurrent ssh connections
ssh_timeout: Timeout for SSH connections
debug: Optionally log but do not terminate processes
"""

ssh_options = SSHClientConnectionOptions(connect_timeout=ssh_timeout)

# Launch a concurrent job for each node in the cluster
coroutines = [
utils.terminate_errant_processes(
node=node,
uid_whitelist=uid_whitelist,
ssh_limit=asyncio.Semaphore(max_concurrent),
ssh_options=ssh_options,
debug=debug)
for node in nodes
]

# Gather results from each concurrent run and check for errors
results = await asyncio.gather(*coroutines, return_exceptions=True)
for node, result in zip(nodes, results):
if isinstance(result, Exception):
logging.error(f'Error with node {node}: {result}')

@classmethod
def execute(cls, arg_list: List[str] = None) -> None:
"""Parse command-line arguments and execute the application"""
"""Parse command-line arguments and execute the application
args = Parser().parse_args(arg_list)
verbosity_to_log_level = {0: logging.ERROR, 1: logging.WARNING, 2: logging.INFO, 3: logging.DEBUG}
Args:
arg_list: Optionally parse the given arguments instead of the command line
"""

# Load application settings - override defaults using parsed arguments
settings = Settings.load()
settings.verbosity = verbosity_to_log_level.get(args.verbosity, logging.DEBUG)
settings.debug = settings.debug or args.debug
args = Parser().parse_args(arg_list)
cls._configure_logging(args.verbosity)

try:
application = cls(settings)
asyncio.run(application.run())
# Extract the subset of arguments that are valid for the function ``args.callable``
valid_params = inspect.signature(args.callable).parameters
valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params}
asyncio.run(args.callable(**valid_arguments))

except KeyboardInterrupt:
pass
Expand Down
79 changes: 0 additions & 79 deletions shinigami/settings.py

This file was deleted.

Loading

0 comments on commit 7b5291d

Please sign in to comment.