diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 679920880286..4414d3f78ad2 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -1839,15 +1839,15 @@ def _get_pyarrow_version() -> Optional[str]: class DeferSigint(contextlib.AbstractContextManager): - """Context manager that defers SIGINT signals until the the context is left.""" + """Context manager that defers SIGINT signals until the context is left.""" # This is used by Ray's task cancellation to defer cancellation interrupts during # problematic areas, e.g. task argument deserialization. def __init__(self): - # Whether the task has been cancelled while in the context. - self.task_cancelled = False - # The original SIGINT handler. - self.orig_sigint_handler = None + # Whether a SIGINT signal was received during the context. + self.signal_received = False + # The overridden SIGINT handler + self.overridden_sigint_handler = None # The original signal method. self.orig_signal = None @@ -1861,32 +1861,29 @@ def create_if_main_thread(cls) -> contextlib.AbstractContextManager: else: return contextlib.nullcontext() - def _set_task_cancelled(self, signum, frame): + def _set_signal_received(self, signum, frame): """SIGINT handler that defers the signal.""" - self.task_cancelled = True + self.signal_received = True def _signal_monkey_patch(self, signum, handler): - """Monkey patch for signal.signal that raises an error if a SIGINT handler is - registered within the DeferSigint context. - """ - # Only raise an error if setting a SIGINT handler in the main thread; if setting - # a handler in a non-main thread, signal.signal will raise an error anyway - # indicating that Python does not allow that. + """Monkey patch for signal.signal that defers the setting of new signal + handler after the DeferSigint context exits.""" + # Only handle it in the main thread because if setting a handler in a non-main + # thread, signal.signal will raise an error because Python doesn't allow it. if ( threading.current_thread() == threading.main_thread() and signum == signal.SIGINT ): - raise ValueError( - "Can't set signal handler for SIGINT while SIGINT is being deferred " - "within a DeferSigint context." - ) + orig_sigint_handler = self.overridden_sigint_handler + self.overridden_sigint_handler = handler + return orig_sigint_handler return self.orig_signal(signum, handler) def __enter__(self): # Save original SIGINT handler for later restoration. - self.orig_sigint_handler = signal.getsignal(signal.SIGINT) + self.overridden_sigint_handler = signal.getsignal(signal.SIGINT) # Set SIGINT signal handler that defers the signal. - signal.signal(signal.SIGINT, self._set_task_cancelled) + signal.signal(signal.SIGINT, self._set_signal_received) # Monkey patch signal.signal to raise an error if a SIGINT handler is registered # within the context. self.orig_signal = signal.signal @@ -1894,16 +1891,16 @@ def __enter__(self): return self def __exit__(self, exc_type, exc, exc_tb): - assert self.orig_sigint_handler is not None + assert self.overridden_sigint_handler is not None assert self.orig_signal is not None # Restore original signal.signal function. signal.signal = self.orig_signal - # Restore original SIGINT handler. - signal.signal(signal.SIGINT, self.orig_sigint_handler) - if exc_type is None and self.task_cancelled: - # No exception raised in context but task has been cancelled, so we raise - # KeyboardInterrupt to go through the task cancellation path. - raise KeyboardInterrupt + # Restore overridden SIGINT handler. + signal.signal(signal.SIGINT, self.overridden_sigint_handler) + if exc_type is None and self.signal_received: + # No exception raised in context, call the original SIGINT handler. + # By default, this means raising KeyboardInterrupt. + self.overridden_sigint_handler(signal.SIGINT, None) else: # If exception was raised in context, returning False will cause it to be # reraised. diff --git a/python/ray/tests/test_cancel.py b/python/ray/tests/test_cancel.py index 5cdf44218db9..c46175670e60 100644 --- a/python/ray/tests/test_cancel.py +++ b/python/ray/tests/test_cancel.py @@ -106,7 +106,7 @@ def test_defer_sigint(): orig_sigint_handler = signal.getsignal(signal.SIGINT) try: with DeferSigint(): - # Send singal to current process. + # Send signal to current process. # NOTE: We use _thread.interrupt_main() instead of os.kill() in order to # support Windows. _thread.interrupt_main() @@ -124,13 +124,62 @@ def test_defer_sigint(): pytest.fail("SIGINT signal was never sent in test") -def test_defer_sigint_monkey_patch(): - # Tests that setting a SIGINT signal handler within a DeferSigint context is not - # allowed. +def test_defer_sigint_monkey_patch_handler_called_when_exit(): + # Tests that the SIGINT signal handlers set within a DeferSigint + # is triggered at most once and only at context exit. orig_sigint_handler = signal.getsignal(signal.SIGINT) - with pytest.raises(ValueError): - with DeferSigint(): - signal.signal(signal.SIGINT, orig_sigint_handler) + handler_called_times = 0 + + def new_sigint_handler(signum, frame): + nonlocal handler_called_times + handler_called_times += 1 + + with DeferSigint(): + signal.signal(signal.SIGINT, new_sigint_handler) + for _ in range(3): + _thread.interrupt_main() + time.sleep(1) + assert handler_called_times == 0 + + assert handler_called_times == 1 + + # Restore original SIGINT handler. + signal.signal(signal.SIGINT, orig_sigint_handler) + + +def test_defer_sigint_monkey_patch_only_last_handler_called(): + # Tests that only the last SIGINT signal handler set within a DeferSigint + # is triggered at most once and only at context exit. + orig_sigint_handler = signal.getsignal(signal.SIGINT) + + handler_1_called_times = 0 + handler_2_called_times = 0 + + def sigint_handler_1(signum, frame): + nonlocal handler_1_called_times + handler_1_called_times += 1 + + def sigint_handler_2(signum, frame): + nonlocal handler_2_called_times + handler_2_called_times += 1 + + with DeferSigint(): + signal.signal(signal.SIGINT, sigint_handler_1) + for _ in range(3): + _thread.interrupt_main() + time.sleep(1) + signal.signal(signal.SIGINT, sigint_handler_2) + for _ in range(3): + _thread.interrupt_main() + time.sleep(1) + assert handler_1_called_times == 0 + assert handler_2_called_times == 0 + + assert handler_1_called_times == 0 + assert handler_2_called_times == 1 + + # Restore original SIGINT handler. + signal.signal(signal.SIGINT, orig_sigint_handler) def test_defer_sigint_noop_in_non_main_thread():