Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates process whitelisting to exclude users running slurm #95

Merged
merged 18 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions shinigami/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
"""Shinigami is a command-line application for killing errant processes
on Slurm based compute nodes. The application scans for and terminates any
running processes not associated with a currently running Slurm job.

Individual users and groups can be whitelisted in the application settings file
via UID and GID values. Specific compute nodes can also be ignored using basic
string matching. See the ``settings`` module for more details.
"""

import importlib.metadata
Expand Down
8 changes: 4 additions & 4 deletions shinigami/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""The executable application and its command-line interface."""
"""The executable application and its command line interface."""

import asyncio
import inspect
Expand Down Expand Up @@ -192,14 +192,14 @@ def execute(cls, arg_list: List[str] = None) -> None:
cls._configure_logging(args.verbosity)

try:
# Extract the subset of arguments that are valid for the function ``args.callable``
# Extract the subset of arguments that are valid for the function `args.callable`
valid_params = inspect.signature(args.callable).parameters
valid_arguments = {key: value for key, value in vars(args).items() if key in valid_params}
asyncio.run(args.callable(**valid_arguments))

except KeyboardInterrupt:
except KeyboardInterrupt: # pragma: nocover
pass

except Exception as caught:
except Exception as caught: # pragma: nocover
logging.getLogger('file_logger').critical('Application crash', exc_info=caught)
logging.getLogger('console_logger').critical(str(caught))
111 changes: 94 additions & 17 deletions shinigami/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@
import asyncssh
import pandas as pd

# Technically the init process ID may vary with the system
# architecture, but 1 is an almost universal default
INIT_PROCESS_ID = 1

# Custom type hints
Whitelist = Collection[Union[int, Tuple[int, int]]]

def id_in_whitelist(id_value: int, whitelist: Collection[Union[int, Tuple[int, int]]]) -> bool:

def _id_in_whitelist(id_value: int, whitelist: Whitelist) -> bool:
"""Return whether an ID is in a list of ID value definitions

The `whitelist` of ID values can contain a mix of integers and tuples
The `whitelist` of ID values can contain a mix of integers and tuples
of integer ranges. For example, [0, 1, (2, 9), 10] includes all IDs from
zero through ten.

Expand Down Expand Up @@ -56,7 +61,83 @@ def get_nodes(cluster: str, ignore_nodes: Collection[str] = tuple()) -> set:
raise RuntimeError(stderr)

all_nodes = stdout.decode().strip().split('\n')
return set(node for node in all_nodes if node not in ignore_nodes)
return set(all_nodes) - set(ignore_nodes)


async def get_remote_processes(conn: asyncssh.SSHClientConnection) -> pd.DataFrame:
"""Fetch running process data from a remote machine

The returned DataFrame is guaranteed to have columns `PID`, `PPID`, `PGID`,
`UID`, and `CND`.

Args:
conn: Open SSH connection to the machine

Returns:
A pandas DataFrame with process data
"""

# Add 1 to column widths when parsing ps output to account for space between columns
ps_return = await conn.run('ps -eo pid:10,ppid:10,pgid:10,uid:10,cmd:500', check=True)
return pd.read_fwf(StringIO(ps_return.stdout), widths=[11, 11, 11, 11, 500])


def include_orphaned_processes(df: pd.DataFrame) -> pd.DataFrame:
"""Filter a DataFrame to only include orphaned processes

Given a DataFrame with system process data, return a subset of the data
containing processes parented by `INIT_PROCESS_ID`.

See the `get_remote_processes` function for the assumed DataFrame data model.

Args:
df: A DataFrame with process data

Returns:
A copy of the given DataFrame
"""

return df[df['PPID'] == INIT_PROCESS_ID]


def include_user_whitelist(df: pd.DataFrame, uid_whitelist: Whitelist) -> pd.DataFrame:
"""Filter a DataFrame to only include a subset of user IDs

Given a DataFrame with system process data, return a subset of the data
containing processes owned by the given user IDs.

See the `get_remote_processes` function for the assumed DataFrame data model.

Args:
df: A DataFrame with process data
uid_whitelist: List of user IDs to whitelist

Returns:
A copy of the given DataFrame
"""

whitelist_index = df['UID'].apply(_id_in_whitelist, whitelist=uid_whitelist)
return df[whitelist_index]


def exclude_active_slurm_users(df: pd.DataFrame) -> pd.DataFrame:
"""Filter a DataFrame to exclude user IDs tied to a running slurm job

Given a DataFrame with system process data, return a subset of the data
that excludes processes owned by users running a `slurmd` command.

See the `get_remote_processes` function for the assumed DataFrame data model.

Args:
df: A DataFrame with process data

Returns:
A copy of the given DataFrame
"""

is_slurm = df['CMD'].str.contains('slurmd')
slurm_uids = df['UID'][is_slurm].unique()
return df[~df['UID'].isin(slurm_uids)]


async def terminate_errant_processes(
Expand All @@ -66,11 +147,11 @@ async def terminate_errant_processes(
ssh_options: asyncssh.SSHClientConnectionOptions = None,
debug: bool = False
) -> None:
"""Terminate non-Slurm processes on a given node
"""Terminate orphaned processes on a given node

Args:
node: The DNS resolvable name of the node to terminate processes on
uid_whitelist: Do not terminate processes owned by the given UID
uid_whitelist: Do not terminate processes owned by the given UIDs
ssh_limit: Semaphore object used to limit concurrent SSH connections
ssh_options: Options for configuring the outbound SSH connection
debug: Log which process to terminate but do not terminate them
Expand All @@ -79,24 +160,20 @@ async def terminate_errant_processes(
logging.debug(f'[{node}] Waiting for SSH pool')
async with ssh_limit, asyncssh.connect(node, options=ssh_options) as conn:
logging.info(f'[{node}] Scanning for processes')
process_df = await get_remote_processes(conn)

# Fetch running process data from the remote machine
# Add 1 to column widths when parsing ps output to account for space between columns
ps_return = await conn.run('ps -eo pid:10,ppid:10,pgid:10,uid:10,cmd:500', check=True)
process_df = pd.read_fwf(StringIO(ps_return.stdout), widths=[11, 11, 11, 11, 500])

# Identify orphaned processes and filter them by the UID whitelist
orphaned = process_df[process_df.PPID == INIT_PROCESS_ID]
whitelist_index = orphaned['UID'].apply(id_in_whitelist, whitelist=uid_whitelist)
to_terminate = orphaned[whitelist_index]
# Filter them by various whitelist/blacklist criteria
process_df = include_orphaned_processes(process_df)
process_df = include_user_whitelist(process_df, uid_whitelist)
process_df = exclude_active_slurm_users(process_df)

for _, row in to_terminate.iterrows():
for _, row in process_df.iterrows(): # pragma: nocover
logging.info(f'[{node}] Marking for termination {dict(row)}')

if to_terminate.empty:
if process_df.empty: # pragma: nocover
logging.info(f'[{node}] no processes found')

elif not debug:
proc_id_str = ','.join(to_terminate.PGID.unique().astype(str))
proc_id_str = ','.join(process_df.PGID.unique().astype(str))
logging.info(f"[{node}] Sending termination signal for process groups {proc_id_str}")
await conn.run(f"pkill --signal 9 --pgroup {proc_id_str}", check=True)
24 changes: 24 additions & 0 deletions tests/cli/test_application.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Tests for the `cli.Application` class"""

from unittest import TestCase
from unittest.mock import patch

from shinigami.cli import Application


class MethodRouting(TestCase):
"""Test CLI commands are routed to the correct callable objects"""

def test_scan_method(self) -> None:
"""Test the `scan` command routes to the `scan` method"""

with patch.object(Application, 'scan', autospec=True) as scan:
Application().execute(['scan', '-c', 'cluster1'])
scan.assert_called_once()

def test_terminate_method(self) -> None:
"""Test the `terminate` command routes to the `terminate` method"""

with patch.object(Application, 'terminate', autospec=True) as scan:
Application().execute(['terminate', '-n', 'node1'])
scan.assert_called_once()
22 changes: 22 additions & 0 deletions tests/cli/test_base_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Tests for the `cli.BaseParser` class"""

from unittest import TestCase

from shinigami.cli import BaseParser


class BaseParsing(TestCase):
"""Test custom parsing logic encapsulated by the `BaseParser` class"""

def test_errors_raise_system_exit(self) -> None:
"""Test error messages are raised as `SystemExit` instances"""

with self.assertRaises(SystemExit):
BaseParser().error("This is an error message")

def test_errors_include_message(self) -> None:
"""Test parser messages are included as error messages"""

msg = "This is an error message"
with self.assertRaisesRegex(SystemExit, msg):
BaseParser().error(msg)
44 changes: 14 additions & 30 deletions tests/cli/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,17 @@
"""Tests for the ``cli.Parser`` class"""
"""Tests for the `cli.Parser` class"""

from unittest import TestCase

from shinigami.cli import Parser, BaseParser


class BaseParsing(TestCase):
"""Test custom parsing login encapsulated by the `BaseParser` class"""

def test_error_handling(self) -> None:
"""Test error messages are raised as `SystemExit` instances"""

parser = BaseParser()
error_message = "This is an error message"
with self.assertRaises(SystemExit, msg=error_message):
parser.error(error_message)
from shinigami.cli import Parser


class ScanSubParser(TestCase):
"""Test the behavior of the ``scan`` subparser"""
"""Test the behavior of the `scan` subparser"""

def test_debug_option(self) -> None:
"""Test parsing of the ``debug`` argument"""
def test_debug_arg(self) -> None:
"""Test parsing of the `debug` argument"""

parser = Parser()

scan_command = ['scan', '-c', 'development', '-u' '100']
self.assertFalse(parser.parse_args(scan_command).debug)

Expand All @@ -36,15 +23,14 @@ def test_verbose_arg(self) -> None:

parser = Parser()
base_command = ['scan', '-c', 'development', '-u' '100']

self.assertEqual(0, parser.parse_args(base_command).verbosity)
self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity)
self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity)
self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity)
self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity)

def test_clusters_arg(self) -> None:
"""Test parsing of the ``clusters`` argument"""
"""Test parsing of the `clusters` argument"""

parser = Parser()

Expand All @@ -56,8 +42,8 @@ def test_clusters_arg(self) -> None:
multi_cluster_cmd = ['scan', '-c', *multi_cluster_out, '-u', '100']
self.assertSequenceEqual(multi_cluster_out, parser.parse_args(multi_cluster_cmd).clusters)

def test_ignore_nodes(self) -> None:
"""Test parsing of the ``ignore-nodes`` argument"""
def test_ignore_nodes_arg(self) -> None:
"""Test parsing of the `ignore-nodes` argument"""

parser = Parser()
base_command = ['scan', '-c', 'development', '-u' '100']
Expand All @@ -71,7 +57,7 @@ def test_ignore_nodes(self) -> None:
self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).ignore_nodes)

def test_uid_whitelist_arg(self) -> None:
"""Test parsing of the ``uid-whitelist`` argument"""
"""Test parsing of the `uid-whitelist` argument"""

parser = Parser()

Expand All @@ -97,13 +83,12 @@ def test_uid_whitelist_arg(self) -> None:


class TerminateSubParser(TestCase):
"""Test the behavior of the ``terminate`` subparser"""
"""Test the behavior of the `terminate` subparser"""

def test_debug_option(self) -> None:
"""Test the ``debug`` argument"""
def test_debug_arg(self) -> None:
"""Test the `debug` argument"""

parser = Parser()

terminate_command = ['terminate', '-n', 'node1', '-u', '100']
self.assertFalse(parser.parse_args(terminate_command).debug)

Expand All @@ -115,15 +100,14 @@ def test_verbose_arg(self) -> None:

parser = Parser()
base_command = ['terminate', '-n', 'node', '-u' '100']

self.assertEqual(0, parser.parse_args(base_command).verbosity)
self.assertEqual(1, parser.parse_args(base_command + ['-v']).verbosity)
self.assertEqual(2, parser.parse_args(base_command + ['-vv']).verbosity)
self.assertEqual(3, parser.parse_args(base_command + ['-vvv']).verbosity)
self.assertEqual(5, parser.parse_args(base_command + ['-vvvvv']).verbosity)

def test_nodes_arg(self) -> None:
"""Test parsing of the ``nodes`` argument"""
"""Test parsing of the `nodes` argument"""

parser = Parser()

Expand All @@ -136,7 +120,7 @@ def test_nodes_arg(self) -> None:
self.assertSequenceEqual(multi_node_out, parser.parse_args(multi_node_cmd).nodes)

def test_uid_whitelist_arg(self) -> None:
"""Test parsing of the ``uid-whitelist`` argument"""
"""Test parsing of the `uid-whitelist` argument"""

parser = Parser()

Expand Down
Empty file added tests/utils/__init__.py
Empty file.
44 changes: 44 additions & 0 deletions tests/utils/test_exclude_active_slurm_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Tests for the `utils.exclude_active_slurm_users` function"""

import unittest

import pandas as pd

from shinigami.utils import exclude_active_slurm_users


class ExcludeSlurmUsers(unittest.TestCase):
"""Test the identification of slurm users from a DataFrame of process data"""

def test_exclude_al_processes_for_user(self) -> None:
"""Test users with slurm processes are excluded from the returned DataFrame"""

# User 1002 has two processes.
# They BOTH should be excluded because ONE of them is a slurm process.
input_df = pd.DataFrame({
'UID': [1001, 1002, 1002, 1003, 1004],
'CMD': ['process 1', '... slurmd ...', 'process 3', 'process 4', 'process5']})

expected_df = input_df.loc[[0, 3, 4]]
returned_df = exclude_active_slurm_users(input_df)
pd.testing.assert_frame_equal(returned_df, expected_df)

def test_no_slurm_users(self) -> None:
"""Test the returned dataframe matches the input dataframe when there are no slurm processes"""

input_df = pd.DataFrame({
'UID': [1001, 1002, 1003, 1004, 1005],
'CMD': ['process1', 'process2', 'process3', 'process4', 'process5']})

returned_df = exclude_active_slurm_users(input_df)
pd.testing.assert_frame_equal(returned_df, input_df)

def test_all_slurm_users(self) -> None:
"""Test the returned dataframe is empty when all process container `slurmd`"""

input_df = pd.DataFrame({
'UID': [1001, 1002, 1003, 1004],
'CMD': ['slurmd', 'prefix slurmd', 'slurmd postfix', 'prfix slurmd postfix']})

returned_df = exclude_active_slurm_users(input_df)
self.assertTrue(returned_df.empty)
Loading