Updates process whitelisting to exclude users running slurm (#95)
* Adds grep command for checking proc variables

* Breaks filtering logic into separate functions

* Minor docstring fix

* Updates type hints

* Outline logic for filtering valid slurm users

* Fills in missing filter logic

* Breaks utils tests into modules

* Debugging commit to test ssh support in CI

* Adds tests for filtering orphaned procs

* Renames filter functions for clarity

* Updates tests for include_user_whitelist

* Drops tests for remote ssh

* Moves base parser tests into dedicated module

* Doc formatting updates

* Adds tests for the execute method

* Adds tests for exclude_active_slurm_users

* Cleanup pass

* Minor func name edit
djperrefort authored Dec 7, 2023
1 parent 98cb0ce commit c33a749
Showing 11 changed files with 303 additions and 82 deletions.
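The net effect of the change, as reflected in the diffs below, is that process filtering now happens in three composable steps in shinigami/utils.py: keep orphaned processes, keep whitelisted UIDs, then drop any user who also owns a slurmd process. A minimal sketch of that flow, assuming the shinigami package from this commit is installed; the sample process table is hypothetical:

import pandas as pd

from shinigami.utils import (
    include_orphaned_processes, include_user_whitelist, exclude_active_slurm_users)

# Hypothetical data in the shape returned by get_remote_processes
procs = pd.DataFrame({
    'PID': [200, 201, 202], 'PPID': [1, 1, 1], 'PGID': [200, 201, 202],
    'UID': [1001, 1002, 1002], 'CMD': ['python train.py', 'slurmd', 'bash']})

candidates = include_orphaned_processes(procs)                    # keep processes parented by init
candidates = include_user_whitelist(candidates, [(1000, 2000)])   # keep only UIDs covered by the whitelist
candidates = exclude_active_slurm_users(candidates)               # drop users with an active slurmd process
print(candidates['PGID'].tolist())  # [200] - only UID 1001 survives the filters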
4 changes: 0 additions & 4 deletions shinigami/__init__.py
@@ -1,10 +1,6 @@
"""Shinigami is a command-line application for killing errant processes
on Slurm based compute nodes. The application scans for and terminates any
running processes not associated with a currently running Slurm job.
Individual users and groups can be whitelisted in the application settings file
via UID and GID values. Specific compute nodes can also be ignored using basic
string matching. See the ``settings`` module for more details.
"""

import importlib.metadata
8 changes: 4 additions & 4 deletions shinigami/cli.py
@@ -1,4 +1,4 @@
"""The executable application and its command-line interface."""
"""The executable application and its command line interface."""

import asyncio
import inspect
@@ -192,14 +192,14 @@ def execute(cls, arg_list: List[str] = None) -> None:
cls._configure_logging(args.verbosity)

try:
# Extract the subset of arguments that are valid for the function ``args.callable``
# Extract the subset of arguments that are valid for the function `args.callable`
valid_params = inspect.signature(args.callable).parameters
valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params}
asyncio.run(args.callable(**valid_arguments))

except KeyboardInterrupt:
except KeyboardInterrupt: # pragma: nocover
pass

except Exception as caught:
except Exception as caught: # pragma: nocover
logging.getLogger('file_logger').critical('Application crash', exc_info=caught)
logging.getLogger('console_logger').critical(str(caught))
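The argument-subsetting pattern in execute is worth calling out: it inspects the signature of the routed callable and forwards only the parsed arguments that callable accepts, so one parser can serve several commands. A standalone sketch of the same pattern; the helper and sample functions here are illustrative, not part of the repository:

import inspect

def call_with_matching_args(func, namespace: dict):
    """Call func with only the items of namespace that match its parameters."""
    valid_params = inspect.signature(func).parameters
    kwargs = {key: value for key, value in namespace.items() if key in valid_params}
    return func(**kwargs)

def scan(clusters, verbosity=0):
    return f'scanning {clusters} at verbosity {verbosity}'

# 'debug' is dropped silently because scan() does not accept it
print(call_with_matching_args(scan, {'clusters': ['smp'], 'verbosity': 2, 'debug': True}))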
111 changes: 94 additions & 17 deletions shinigami/utils.py
@@ -10,13 +10,18 @@
import asyncssh
import pandas as pd

# Technically the init process ID may vary with the system
# architecture, but 1 is an almost universal default
INIT_PROCESS_ID = 1

# Custom type hints
Whitelist = Collection[Union[int, Tuple[int, int]]]

def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, int]]]) -> bool:

def _id_in_whitelist(id_value: int, whitelist: Whitelist) -> bool:
"""Return whether an ID is in a list of ID value definitions
The `whitelist` of ID values can contain a mix of integers and tuples
of integer ranges. For example, [0, 1, (2, 9), 10] includes all IDs from
zero through ten.
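The body of _id_in_whitelist is collapsed in this hunk. A minimal implementation consistent with the documented behaviour might look like the following; this is a sketch, not necessarily the repository's exact code:

from typing import Collection, Tuple, Union

Whitelist = Collection[Union[int, Tuple[int, int]]]

def _id_in_whitelist(id_value: int, whitelist: Whitelist) -> bool:
    """Return whether an ID matches any integer or inclusive (low, high) range in the whitelist."""
    for entry in whitelist:
        if isinstance(entry, tuple):
            low, high = entry
            if low <= id_value <= high:
                return True
        elif id_value == entry:
            return True
    return False

assert _id_in_whitelist(5, [0, 1, (2, 9), 10])       # 5 falls inside the (2, 9) range
assert not _id_in_whitelist(11, [0, 1, (2, 9), 10])  # 11 is outside every entry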
@@ -56,7 +61,83 @@ def get_nodes(cluster: str, ignore_nodes: Collection[str] = tuple()) -> set:
raise RuntimeError(stderr)

all_nodes = stdout.decode().strip().split('\n')
return set(node for node in all_nodes if node not in ignore_nodes)
return set(all_nodes) - set(ignore_nodes)


async def get_remote_processes(conn: asyncssh.SSHClientConnection) -> pd.DataFrame:
"""Fetch running process data from a remote machine
The returned DataFrame is guaranteed to have columns `PID`, `PPID`, `PGID`,
`UID`, and `CMD`.
Args:
conn: Open SSH connection to the machine
Returns:
A pandas DataFrame with process data
"""

# Add 1 to column widths when parsing ps output to account for space between columns
ps_return = await conn.run('ps -eo pid:10,ppid:10,pgid:10,uid:10,cmd:500', check=True)
return pd.read_fwf(StringIO(ps_return.stdout), widths=[11, 11, 11, 11, 500])
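The widths passed to read_fwf are the ps field widths plus one column of separating whitespace, and the ps header row supplies the DataFrame column names. The same parsing can be exercised locally without SSH; a sketch assuming a procps-style ps is available:

import subprocess
from io import StringIO

import pandas as pd

# Same format string the diff sends over SSH, run against the local machine instead
ps_output = subprocess.run(
    ['ps', '-eo', 'pid:10,ppid:10,pgid:10,uid:10,cmd:500'],
    capture_output=True, text=True, check=True).stdout

df = pd.read_fwf(StringIO(ps_output), widths=[11, 11, 11, 11, 500])
print(df.columns.tolist())  # ['PID', 'PPID', 'PGID', 'UID', 'CMD']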


def include_orphaned_processes(df: pd.DataFrame) -> pd.DataFrame:
"""Filter a DataFrame to only include orphaned processes
Given a DataFrame with system process data, return a subset of the data
containing processes parented by `INIT_PROCESS_ID`.
See the `get_remote_processes` function for the assumed DataFrame data model.
Args:
df: A DataFrame with process data
Returns:
A copy of the given DataFrame
"""

return df[df['PPID'] == INIT_PROCESS_ID]


def include_user_whitelist(df: pd.DataFrame, uid_whitelist: Whitelist) -> pd.DataFrame:
"""Filter a DataFrame to only include a subset of user IDs
Given a DataFrame with system process data, return a subset of the data
containing processes owned by the given user IDs.
See the `get_remote_processes` function for the assumed DataFrame data model.
Args:
df: A DataFrame with process data
uid_whitelist: List of user IDs to whitelist
Returns:
A copy of the given DataFrame
"""

whitelist_index = df['UID'].apply(_id_in_whitelist, whitelist=uid_whitelist)
return df[whitelist_index]


def exclude_active_slurm_users(df: pd.DataFrame) -> pd.DataFrame:
"""Filter a DataFrame to exclude user IDs tied to a running slurm job
Given a DataFrame with system process data, return a subset of the data
that excludes processes owned by users running a `slurmd` command.
See the `get_remote_processes` function for the assumed DataFrame data model.
Args:
df: A DataFrame with process data
Returns:
A copy of the given DataFrame
"""

is_slurm = df['CMD'].str.contains('slurmd')
slurm_uids = df['UID'][is_slurm].unique()
return df[~df['UID'].isin(slurm_uids)]


async def terminate_errant_processes(
@@ -66,11 +147,11 @@ async def terminate_errant_processes(
ssh_options: asyncssh.SSHClientConnectionOptions = None,
debug: bool = False
) -> None:
"""Terminate non-Slurm processes on a given node
"""Terminate orphaned processes on a given node
Args:
node: The DNS resolvable name of the node to terminate processes on
uid_whitelist: Do not terminate processes owned by the given UID
uid_whitelist: Do not terminate processes owned by the given UIDs
ssh_limit: Semaphore object used to limit concurrent SSH connections
ssh_options: Options for configuring the outbound SSH connection
debug: Log which processes to terminate but do not terminate them
@@ -79,24 +160,20 @@ async def terminate_errant_processes(
logging.debug(f'[{node}] Waiting for SSH pool')
async with ssh_limit, asyncssh.connect(node, options=ssh_options) as conn:
logging.info(f'[{node}] Scanning for processes')
process_df = await get_remote_processes(conn)

# Fetch running process data from the remote machine
# Add 1 to column widths when parsing ps output to account for space between columns
ps_return = await conn.run('ps -eo pid:10,ppid:10,pgid:10,uid:10,cmd:500', check=True)
process_df = pd.read_fwf(StringIO(ps_return.stdout), widths=[11, 11, 11, 11, 500])

# Identify orphaned processes and filter them by the UID whitelist
orphaned = process_df[process_df.PPID == INIT_PROCESS_ID]
whitelist_index = orphaned['UID'].apply(id_in_whitelist, whitelist=uid_whitelist)
to_terminate = orphaned[whitelist_index]
# Filter them by various whitelist/blacklist criteria
process_df = include_orphaned_processes(process_df)
process_df = include_user_whitelist(process_df, uid_whitelist)
process_df = exclude_active_slurm_users(process_df)

for _, row in to_terminate.iterrows():
for _, row in process_df.iterrows(): # pragma: nocover
logging.info(f'[{node}] Marking for termination {dict(row)}')

if to_terminate.empty:
if process_df.empty: # pragma: nocover
logging.info(f'[{node}] no processes found')

elif not debug:
proc_id_str = ','.join(to_terminate.PGID.unique().astype(str))
proc_id_str = ','.join(process_df.PGID.unique().astype(str))
logging.info(f"[{node}] Sending termination signal for process groups {proc_id_str}")
await conn.run(f"pkill --signal 9 --pgroup {proc_id_str}", check=True)
24 changes: 24 additions & 0 deletions tests/cli/test_application.py
@@ -0,0 +1,24 @@
"""Tests for the `cli.Application` class"""

from unittest import TestCase
from unittest.mock import patch

from shinigami.cli import Application


class MethodRouting(TestCase):
"""Test CLI commands are routed to the correct callable objects"""

def test_scan_method(self) -> None:
"""Test the `scan` command routes to the `scan` method"""

with patch.object(Application, 'scan', autospec=True) as scan:
Application().execute(['scan', '-c', 'cluster1'])
scan.assert_called_once()

def test_terminate_method(self) -> None:
"""Test the `terminate` command routes to the `terminate` method"""

with patch.object(Application, 'terminate', autospec=True) as terminate:
Application().execute(['terminate', '-n', 'node1'])
terminate.assert_called_once()
22 changes: 22 additions & 0 deletions tests/cli/test_base_parser.py
@@ -0,0 +1,22 @@
"""Tests for the `cli.BaseParser` class"""

from unittest import TestCase

from shinigami.cli import BaseParser


class BaseParsing(TestCase):
"""Test custom parsing logic encapsulated by the `BaseParser` class"""

def test_errors_raise_system_exit(self) -> None:
"""Test error messages are raised as `SystemExit` instances"""

with self.assertRaises(SystemExit):
BaseParser().error("This is an error message")

def test_errors_include_message(self) -> None:
"""Test parser messages are included as error messages"""

msg = "This is an error message"
with self.assertRaisesRegex(SystemExit, msg):
BaseParser().error(msg)
44 changes: 14 additions & 30 deletions tests/cli/test_parser.py
@@ -1,30 +1,17 @@
"""Tests for the ``cli.Parser`` class"""
"""Tests for the `cli.Parser` class"""

from unittest import TestCase

from shinigami.cli import Parser, BaseParser


class BaseParsing(TestCase):
"""Test custom parsing login encapsulated by the `BaseParser` class"""

def test_error_handling(self) -> None:
"""Test error messages are raised as `SystemExit` instances"""

parser = BaseParser()
error_message = "This is an error message"
with self.assertRaises(SystemExit, msg=error_message):
parser.error(error_message)
from shinigami.cli import Parser


class ScanSubParser(TestCase):
"""Test the behavior of the ``scan`` subparser"""
"""Test the behavior of the `scan` subparser"""

def test_debug_option(self) -> None:
"""Test parsing of the ``debug`` argument"""
def test_debug_arg(self) -> None:
"""Test parsing of the `debug` argument"""

parser = Parser()

scan_command = ['scan', '-c', 'development', '-u' '100']
self.assertFalse(parser.parse_args(scan_command).debug)

@@ -36,15 +23,14 @@ def test_verbose_arg(self) -> None:

parser = Parser()
base_command = ['scan', '-c', 'development', '-u' '100']

self.assertEqual(0, parser.parse_args(base_command).verbosity)
self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity)
self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity)
self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity)
self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity)

def test_clusters_arg(self) -> None:
"""Test parsing of the ``clusters`` argument"""
"""Test parsing of the `clusters` argument"""

parser = Parser()

@@ -56,8 +42,8 @@ def test_clusters_arg(self) -> None:
multi_cluster_cmd = ['scan', '-c', *multi_cluster_out, '-u', '100']
self.assertSequenceEqual(multi_cluster_out, parser.parse_args(multi_cluster_cmd).clusters)

def test_ignore_nodes(self) -> None:
"""Test parsing of the ``ignore-nodes`` argument"""
def test_ignore_nodes_arg(self) -> None:
"""Test parsing of the `ignore-nodes` argument"""

parser = Parser()
base_command = ['scan', '-c', 'development', '-u' '100']
@@ -71,7 +57,7 @@ def test_ignore_nodes(self) -> None:
self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).ignore_nodes)

def test_uid_whitelist_arg(self) -> None:
"""Test parsing of the ``uid-whitelist`` argument"""
"""Test parsing of the `uid-whitelist` argument"""

parser = Parser()

@@ -97,13 +83,12 @@ def test_uid_whitelist_arg(self) -> None:


class TerminateSubParser(TestCase):
"""Test the behavior of the ``terminate`` subparser"""
"""Test the behavior of the `terminate` subparser"""

def test_debug_option(self) -> None:
"""Test the ``debug`` argument"""
def test_debug_arg(self) -> None:
"""Test the `debug` argument"""

parser = Parser()

terminate_command = ['terminate', '-n', 'node1', '-u', '100']
self.assertFalse(parser.parse_args(terminate_command).debug)

@@ -115,15 +100,14 @@ def test_verbose_arg(self) -> None:

parser = Parser()
base_command = ['terminate', '-n', 'node', '-u' '100']

self.assertEqual(0, parser.parse_args(base_command).verbosity)
self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity)
self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity)
self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity)
self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity)

def test_nodes_arg(self) -> None:
"""Test parsing of the ``nodes`` argument"""
"""Test parsing of the `nodes` argument"""

parser = Parser()

@@ -136,7 +120,7 @@ def test_nodes_arg(self) -> None:
self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).nodes)

def test_uid_whitelist_arg(self) -> None:
"""Test parsing of the ``uid-whitelist`` argument"""
"""Test parsing of the `uid-whitelist` argument"""

parser = Parser()

Empty file added tests/utils/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions tests/utils/test_exclude_active_slurm_users.py
@@ -0,0 +1,44 @@
"""Tests for the `utils.exclude_active_slurm_users` function"""

import unittest

import pandas as pd

from shinigami.utils import exclude_active_slurm_users


class ExcludeSlurmUsers(unittest.TestCase):
"""Test the identification of slurm users from a DataFrame of process data"""

def test_exclude_all_processes_for_user(self) -> None:
"""Test users with slurm processes are excluded from the returned DataFrame"""

# User 1002 has two processes.
# They BOTH should be excluded because ONE of them is a slurm process.
input_df = pd.DataFrame({
'UID': [1001, 1002, 1002, 1003, 1004],
'CMD': ['process 1', '... slurmd ...', 'process 3', 'process 4', 'process5']})

expected_df = input_df.loc[[0, 3, 4]]
returned_df = exclude_active_slurm_users(input_df)
pd.testing.assert_frame_equal(returned_df, expected_df)

def test_no_slurm_users(self) -> None:
"""Test the returned dataframe matches the input dataframe when there are no slurm processes"""

input_df = pd.DataFrame({
'UID': [1001, 1002, 1003, 1004, 1005],
'CMD': ['process1', 'process2', 'process3', 'process4', 'process5']})

returned_df = exclude_active_slurm_users(input_df)
pd.testing.assert_frame_equal(returned_df, input_df)

def test_all_slurm_users(self) -> None:
"""Test the returned dataframe is empty when all process container `slurmd`"""

input_df = pd.DataFrame({
'UID': [1001, 1002, 1003, 1004],
'CMD': ['slurmd', 'prefix slurmd', 'slurmd postfix', 'prfix slurmd postfix']})

returned_df = exclude_active_slurm_users(input_df)
self.assertTrue(returned_df.empty)