Skip to content

Commit

Permalink
[Fix][Core] Don't raise ValueError in DeferSigint context manager (ra…
Browse files Browse the repository at this point in the history
…y-project#48494)

See the description in the corresponding issue for details.
  • Loading branch information
MortalHappiness authored Nov 5, 2024
1 parent 81cf6d8 commit 736e120
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 33 deletions.
49 changes: 23 additions & 26 deletions python/ray/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,15 +1839,15 @@ def _get_pyarrow_version() -> Optional[str]:


class DeferSigint(contextlib.AbstractContextManager):
"""Context manager that defers SIGINT signals until the the context is left."""
"""Context manager that defers SIGINT signals until the context is left."""

# This is used by Ray's task cancellation to defer cancellation interrupts during
# problematic areas, e.g. task argument deserialization.
def __init__(self):
# Whether the task has been cancelled while in the context.
self.task_cancelled = False
# The original SIGINT handler.
self.orig_sigint_handler = None
# Whether a SIGINT signal was received during the context.
self.signal_received = False
# The overridden SIGINT handler
self.overridden_sigint_handler = None
# The original signal method.
self.orig_signal = None

Expand All @@ -1861,49 +1861,46 @@ def create_if_main_thread(cls) -> contextlib.AbstractContextManager:
else:
return contextlib.nullcontext()

def _set_task_cancelled(self, signum, frame):
def _set_signal_received(self, signum, frame):
"""SIGINT handler that defers the signal."""
self.task_cancelled = True
self.signal_received = True

def _signal_monkey_patch(self, signum, handler):
"""Monkey patch for signal.signal that raises an error if a SIGINT handler is
registered within the DeferSigint context.
"""
# Only raise an error if setting a SIGINT handler in the main thread; if setting
# a handler in a non-main thread, signal.signal will raise an error anyway
# indicating that Python does not allow that.
"""Monkey patch for signal.signal that defers the setting of new signal
handler after the DeferSigint context exits."""
# Only handle it in the main thread because if setting a handler in a non-main
# thread, signal.signal will raise an error because Python doesn't allow it.
if (
threading.current_thread() == threading.main_thread()
and signum == signal.SIGINT
):
raise ValueError(
"Can't set signal handler for SIGINT while SIGINT is being deferred "
"within a DeferSigint context."
)
orig_sigint_handler = self.overridden_sigint_handler
self.overridden_sigint_handler = handler
return orig_sigint_handler
return self.orig_signal(signum, handler)

def __enter__(self):
# Save original SIGINT handler for later restoration.
self.orig_sigint_handler = signal.getsignal(signal.SIGINT)
self.overridden_sigint_handler = signal.getsignal(signal.SIGINT)
# Set SIGINT signal handler that defers the signal.
signal.signal(signal.SIGINT, self._set_task_cancelled)
signal.signal(signal.SIGINT, self._set_signal_received)
# Monkey patch signal.signal to raise an error if a SIGINT handler is registered
# within the context.
self.orig_signal = signal.signal
signal.signal = self._signal_monkey_patch
return self

def __exit__(self, exc_type, exc, exc_tb):
assert self.orig_sigint_handler is not None
assert self.overridden_sigint_handler is not None
assert self.orig_signal is not None
# Restore original signal.signal function.
signal.signal = self.orig_signal
# Restore original SIGINT handler.
signal.signal(signal.SIGINT, self.orig_sigint_handler)
if exc_type is None and self.task_cancelled:
# No exception raised in context but task has been cancelled, so we raise
# KeyboardInterrupt to go through the task cancellation path.
raise KeyboardInterrupt
# Restore overridden SIGINT handler.
signal.signal(signal.SIGINT, self.overridden_sigint_handler)
if exc_type is None and self.signal_received:
# No exception raised in context, call the original SIGINT handler.
# By default, this means raising KeyboardInterrupt.
self.overridden_sigint_handler(signal.SIGINT, None)
else:
# If exception was raised in context, returning False will cause it to be
# reraised.
Expand Down
63 changes: 56 additions & 7 deletions python/ray/tests/test_cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_defer_sigint():
orig_sigint_handler = signal.getsignal(signal.SIGINT)
try:
with DeferSigint():
# Send singal to current process.
# Send signal to current process.
# NOTE: We use _thread.interrupt_main() instead of os.kill() in order to
# support Windows.
_thread.interrupt_main()
Expand All @@ -124,13 +124,62 @@ def test_defer_sigint():
pytest.fail("SIGINT signal was never sent in test")


def test_defer_sigint_monkey_patch():
# Tests that setting a SIGINT signal handler within a DeferSigint context is not
# allowed.
def test_defer_sigint_monkey_patch_handler_called_when_exit():
# Tests that the SIGINT signal handlers set within a DeferSigint
# is triggered at most once and only at context exit.
orig_sigint_handler = signal.getsignal(signal.SIGINT)
with pytest.raises(ValueError):
with DeferSigint():
signal.signal(signal.SIGINT, orig_sigint_handler)
handler_called_times = 0

def new_sigint_handler(signum, frame):
nonlocal handler_called_times
handler_called_times += 1

with DeferSigint():
signal.signal(signal.SIGINT, new_sigint_handler)
for _ in range(3):
_thread.interrupt_main()
time.sleep(1)
assert handler_called_times == 0

assert handler_called_times == 1

# Restore original SIGINT handler.
signal.signal(signal.SIGINT, orig_sigint_handler)


def test_defer_sigint_monkey_patch_only_last_handler_called():
# Tests that only the last SIGINT signal handler set within a DeferSigint
# is triggered at most once and only at context exit.
orig_sigint_handler = signal.getsignal(signal.SIGINT)

handler_1_called_times = 0
handler_2_called_times = 0

def sigint_handler_1(signum, frame):
nonlocal handler_1_called_times
handler_1_called_times += 1

def sigint_handler_2(signum, frame):
nonlocal handler_2_called_times
handler_2_called_times += 1

with DeferSigint():
signal.signal(signal.SIGINT, sigint_handler_1)
for _ in range(3):
_thread.interrupt_main()
time.sleep(1)
signal.signal(signal.SIGINT, sigint_handler_2)
for _ in range(3):
_thread.interrupt_main()
time.sleep(1)
assert handler_1_called_times == 0
assert handler_2_called_times == 0

assert handler_1_called_times == 0
assert handler_2_called_times == 1

# Restore original SIGINT handler.
signal.signal(signal.SIGINT, orig_sigint_handler)


def test_defer_sigint_noop_in_non_main_thread():
Expand Down

0 comments on commit 736e120

Please sign in to comment.