Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement parent monitoring thread #8

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions src/jsi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

Less common options:
--output DIRECTORY directory where solver output files will be written
--supervisor run a supervisor process to avoid orphaned subprocesses
--reaper run a reaper process that kills orphaned solvers when jsi exits
--debug enable debug logging
--csv print solver results in CSV format (<output>/<input>.csv)
--perf print performance timers
Expand All @@ -53,6 +53,7 @@
import signal
import sys
import threading
import time
from functools import partial

from jsi.config.loader import Config, find_available_solvers, load_definitions
Expand Down Expand Up @@ -134,6 +135,38 @@ def cleanup():
atexit.register(cleanup)


def monitor_parent(poll_interval: float = 1.0) -> None:
    """
    Monitor the parent process and terminate this process if it goes away.

    The parent PID is recorded once, at call time. A background daemon thread
    then polls ``os.getppid()`` and, as soon as the value differs from the
    recorded one, prints a message and sends SIGTERM to the current process.

    Args:
        poll_interval: seconds to sleep between two checks (default: 1 second)

    Caveats:
    - only works on POSIX systems
    - only works if called early enough (before the original parent process
      exits): a parent that is already gone at call time cannot be told apart
      from a process legitimately started by PID 1
    """

    parent_pid = os.getppid()

    def check_parent():
        while True:
            try:
                # Compare against the recorded parent only. Treating PID 1 as
                # "orphaned" unconditionally would make us kill ourselves
                # right away when PID 1 *is* our legitimate parent (container
                # entrypoint, init-spawned service, ...). Re-parenting to
                # init is still detected since the PID changes in that case.
                if os.getppid() != parent_pid:
                    stderr.print("parent process died, exiting...")
                    os.kill(os.getpid(), signal.SIGTERM)
                    break
                time.sleep(poll_interval)
            except ProcessLookupError:
                # if we can't check parent PID, assume parent died
                os.kill(os.getpid(), signal.SIGTERM)
                break

    # Daemon thread: it must never keep the interpreter alive on its own.
    monitor_thread = threading.Thread(
        target=check_parent,
        name="parent-monitor",
        daemon=True,
    )
    monitor_thread.start()


class BadParameterError(Exception):
pass

Expand Down Expand Up @@ -174,8 +207,8 @@ def parse_args(args: list[str]) -> Config:
config.model = True
case "--csv":
config.csv = True
case "--supervisor":
config.supervisor = True
case "--reaper":
config.reaper = True
case "--daemon":
config.daemon = True
case "--timeout":
Expand Down Expand Up @@ -229,6 +262,9 @@ def main(args: list[str] | None = None) -> int:
global stdout
global stderr

# kick off the parent monitor in the background as early as possible
monitor_parent()

if args is None:
args = sys.argv[1:]

Expand Down Expand Up @@ -310,8 +346,8 @@ def main(args: list[str] | None = None) -> int:
controller.start()
status.start()

if config.supervisor:
from jsi.supervisor import Supervisor
if config.reaper:
from jsi.reaper import Reaper

# wait for the subprocesses to start, we need the PIDs for the reaper
while controller.task.status.value < TaskStatus.RUNNING.value:
Expand All @@ -320,9 +356,9 @@ def main(args: list[str] | None = None) -> int:
# start a reaper process in daemon mode so that it does not block
# the program from exiting
child_pids = [command.pid for command in controller.commands]
sv = Supervisor(os.getpid(), child_pids, config)
sv.daemon = True
sv.start()
reaper = Reaper(os.getpid(), child_pids, config.debug)
reaper.daemon = True
reaper.start()

# wait for the solvers to finish
controller.join()
Expand Down
4 changes: 2 additions & 2 deletions src/jsi/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
debug: bool = False,
input_file: str | None = None,
output_dir: str | None = None,
supervisor: bool = False,
reaper: bool = False,
sequence: Sequence[str] | None = None,
model: bool = False,
csv: bool = False,
Expand All @@ -36,7 +36,7 @@ def __init__(
self.debug = debug
self.input_file = input_file
self.output_dir = output_dir
self.supervisor = supervisor
self.reaper = reaper
self.sequence = sequence
self.model = model
self.csv = csv
Expand Down
24 changes: 12 additions & 12 deletions src/jsi/supervisor.py → src/jsi/reaper.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
import multiprocessing
import sys
import time

from jsi.core import Config
from jsi.utils import LogLevel, kill_process, logger, pid_exists
from jsi.utils import LogLevel, get_console, kill_process, logger, pid_exists


class Supervisor(multiprocessing.Process):
"""Supervisor process that monitors the parent process and its children."""
class Reaper(multiprocessing.Process):
"""Reaper process that monitors the parent process and its children."""

parent_pid: int
child_pids: list[int]
config: Config
debug: bool

def __init__(self, parent_pid: int, child_pids: list[int], config: Config):
def __init__(self, parent_pid: int, child_pids: list[int], debug: bool = False):
super().__init__()
self.parent_pid = parent_pid
self.child_pids = child_pids
self.config = config
self.debug = debug

def run(self):
if self.config.debug:
logger.enable(console=self.config.stderr, level=LogLevel.DEBUG)
level = LogLevel.DEBUG if self.debug else LogLevel.INFO
logger.enable(console=get_console(sys.stderr), level=level)

logger.debug(f"supervisor started (PID: {self.pid})")
logger.debug(f"watching parent (PID: {self.parent_pid})")
logger.debug(f"watching children (PID: {self.child_pids})")
logger.info(f"reaper started (PID: {self.pid})")
logger.info(f"watching parent (PID: {self.parent_pid})")
logger.info(f"watching children (PID: {self.child_pids})")

last_message_time = time.time()
try:
Expand Down
6 changes: 4 additions & 2 deletions src/jsi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,11 @@ def sync_solve(self, file: str) -> str:

listener = ResultListener()
controller = ProcessController(
task, commands, self.config,
task,
commands,
self.config,
start_callback=start_logger,
exit_callback=listener.exit_callback
exit_callback=listener.exit_callback,
)
controller.start()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_real_process():
def test_cmd():
command = cmd()
command.start()
stdout, stderr = command.communicate(timeout=0.1)
stdout, stderr = command.communicate(timeout=0.2)

assert command.returncode == 0
assert not stdout
Expand Down