diff --git a/python/ray/_private/ray_experimental_perf.py b/python/ray/_private/ray_experimental_perf.py index 2b07b71793ef..b46408c2abe1 100644 --- a/python/ray/_private/ray_experimental_perf.py +++ b/python/ray/_private/ray_experimental_perf.py @@ -180,7 +180,6 @@ async def _exec_async(): results += timeit( "[unstable] compiled single-actor DAG calls", lambda: _exec(compiled_dag) ) - compiled_dag.teardown() del a # Single-actor asyncio DAG calls @@ -194,10 +193,6 @@ async def _exec_async(): "[unstable] compiled single-actor asyncio DAG calls", ) ) - # TODO: Need to explicitly tear down DAGs with enable_asyncio=True because - # these DAGs create a background thread that can segfault if the CoreWorker - # is torn down first. - compiled_dag.teardown() del a # Scatter-gather DAG calls @@ -215,7 +210,6 @@ async def _exec_async(): f"[unstable] compiled scatter-gather DAG calls, n={n_cpu} actors", lambda: _exec(compiled_dag), ) - compiled_dag.teardown() # Scatter-gather asyncio DAG calls @@ -228,10 +222,6 @@ async def _exec_async(): f"[unstable] compiled scatter-gather asyncio DAG calls, n={n_cpu} actors", ) ) - # TODO: Need to explicitly tear down DAGs with enable_asyncio=True because - # these DAGs create a background thread that can segfault if the CoreWorker - # is torn down first. - compiled_dag.teardown() # Chain DAG calls @@ -249,7 +239,6 @@ async def _exec_async(): f"[unstable] compiled chain DAG calls, n={n_cpu} actors", lambda: _exec(compiled_dag), ) - compiled_dag.teardown() # Chain asyncio DAG calls @@ -262,10 +251,6 @@ async def _exec_async(): results += loop.run_until_complete( exec_async(f"[unstable] compiled chain asyncio DAG calls, n={n_cpu} actors") ) - # TODO: Need to explicitly tear down DAGs with enable_asyncio=True because - # these DAGs create a background thread that can segfault if the CoreWorker - # is torn down first. - compiled_dag.teardown() # Multiple args with small payloads @@ -288,7 +273,6 @@ async def _exec_async(): f"n={n_actors} actors", lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size), ) - compiled_dag.teardown() # Multiple args with medium payloads @@ -306,7 +290,6 @@ async def _exec_async(): f"n={n_actors} actors", lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size), ) - compiled_dag.teardown() # Multiple args with large payloads @@ -324,7 +307,6 @@ async def _exec_async(): f"n={n_actors} actors", lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size), ) - compiled_dag.teardown() # Worst case for multiple arguments: a single actor takes all the arguments # with small payloads. @@ -345,7 +327,6 @@ async def _exec_async(): "n=1 actors", lambda: _exec(compiled_dag, num_args=n_args, payload_size=payload_size), ) - compiled_dag.teardown() ray.shutdown() diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 0abfb5757692..b7a85221b391 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -1874,6 +1874,11 @@ def shutdown(_exiting_interpreter: bool = False): and false otherwise. If we are exiting the interpreter, we will wait a little while to print any extra error messages. """ + # Make sure to clean up any compiled DAGs that still exist. + from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags + + _shutdown_all_compiled_dags() + if _exiting_interpreter and global_worker.mode == SCRIPT_MODE: # This is a duration to sleep before shutting down everything in order # to make sure that log messages finish printing.
diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 692a155bc114..8049d22dfb3e 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -1,3 +1,4 @@ +import weakref import asyncio from collections import defaultdict from dataclasses import dataclass, asdict @@ -8,6 +9,7 @@ import uuid import traceback +import ray.exceptions from ray.experimental.channel.cached_channel import CachedChannel from ray.experimental.channel.gpu_communicator import GPUCommunicator import ray @@ -52,6 +54,21 @@ logger = logging.getLogger(__name__) +# Keep track of every compiled dag created during the lifetime of +# this process. They are tracked as weakrefs, meaning that when a compiled dag +# is GC'ed, it is automatically removed from here. This is used to tear down +# compiled dags at interpreter shutdown time. +_compiled_dags = weakref.WeakValueDictionary() + + +# Relying on __del__ doesn't work well at shutdown because +# the destructor order is not guaranteed. We call this function +# from `ray.worker.shutdown`, which is registered as an atexit handler, +# so that teardown is properly called before objects are destructed. +def _shutdown_all_compiled_dags(): + for _, compiled_dag in _compiled_dags.items(): + compiled_dag.teardown() + @DeveloperAPI def do_allocate_channel( @@ -1648,7 +1665,7 @@ def _is_same_actor(idx1: int, idx2: int) -> bool: return False def _monitor_failures(self): - outer = self + outer = weakref.proxy(self) class Monitor(threading.Thread): def __init__(self): @@ -1657,6 +1674,8 @@ def __init__(self): # Lock to make sure that we only perform teardown for this DAG # once. self.in_teardown_lock = threading.Lock() + self.name = "CompiledGraphMonitorThread" + self._teardown_done = False def wait_teardown(self): for actor, ref in outer.worker_task_refs.items(): @@ -1686,6 +1705,9 @@ def wait_teardown(self): def teardown(self, wait: bool): do_teardown = False with self.in_teardown_lock: + if self._teardown_done: + return + if not self.in_teardown: do_teardown = True self.in_teardown = True @@ -1709,9 +1731,11 @@ def teardown(self, wait: bool): ] for cancel_ref in cancel_refs: try: - # TODO(swang): Suppress exceptions from actors trying to - # read closed channels when DAG is being torn down. ray.get(cancel_ref, timeout=30) + except ray.exceptions.RayChannelError: + # A channel error happens when a channel is closed + # or times out. In this case, do not log it. + pass except Exception: logger.exception("Error cancelling worker task") pass @@ -1724,6 +1748,9 @@ def teardown(self, wait: bool): self.wait_teardown() logger.info("Teardown complete") + with self.in_teardown_lock: + self._teardown_done = True + def run(self): try: ray.get(list(outer.worker_task_refs.values())) @@ -2140,11 +2167,7 @@ def teardown(self): def __del__(self): monitor = getattr(self, "_monitor", None) if monitor is not None: - # Teardown asynchronously. - # NOTE(swang): Somehow, this can get called after the CoreWorker - # has already been destructed, so it is not safe to block in - # ray.get.
- monitor.teardown(wait=False) + monitor.teardown(wait=True) @DeveloperAPI @@ -2173,4 +2196,6 @@ def _build_compiled_dag(node): root = dag._find_root() root.traverse_and_apply(_build_compiled_dag) compiled_dag._get_or_compile() + global _compiled_dags + _compiled_dags[compiled_dag.get_id()] = compiled_dag return compiled_dag diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 7970f0dbf541..38661e1a73ad 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -13,6 +13,8 @@ import pytest + +from ray._private.test_utils import run_string_as_driver from ray.exceptions import RayChannelError, RayChannelTimeoutError import ray import ray._private @@ -171,10 +173,6 @@ def test_basic(ray_start_regular): # Delete the buffer so that the next DAG output can be written. del result - # Note: must teardown before starting a new Ray session, otherwise you'll get - # a segfault from the dangling monitor thread upon the new Ray init. - compiled_dag.teardown() - def test_two_returns_first(ray_start_regular): a = Actor.remote(0) @@ -187,8 +185,6 @@ def test_two_returns_first(ray_start_regular): res = ray.get(compiled_dag.execute(1)) assert res == 1 - compiled_dag.teardown() - def test_two_returns_second(ray_start_regular): a = Actor.remote(0) @@ -201,8 +197,6 @@ def test_two_returns_second(ray_start_regular): res = ray.get(compiled_dag.execute(1)) assert res == 2 - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_two_returns_one_reader(ray_start_regular, single_fetch): @@ -225,8 +219,6 @@ def test_two_returns_one_reader(ray_start_regular, single_fetch): res = ray.get(refs) assert res == [1, 2] - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_two_returns_two_readers(ray_start_regular, single_fetch): @@ -250,8 +242,6 @@ def test_two_returns_two_readers(ray_start_regular, single_fetch): res = ray.get(refs) assert res == [1, 2] - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_inc_two_returns(ray_start_regular, single_fetch): @@ -271,8 +261,6 @@ def test_inc_two_returns(ray_start_regular, single_fetch): res = ray.get(refs) assert res == [i + 1, i + 2] - compiled_dag.teardown() - def test_two_as_one_return(ray_start_regular): a = Actor.remote(0) @@ -285,8 +273,6 @@ def test_two_as_one_return(ray_start_regular): res = ray.get(compiled_dag.execute(1)) assert res == (1, 2) - compiled_dag.teardown() - def test_multi_output_get_exception(ray_start_regular): a = Actor.remote(0) @@ -308,8 +294,6 @@ def test_multi_output_get_exception(ray_start_regular): ): ray.get(refs) - compiled_dag.teardown() - # TODO(wxdeng): Fix segfault. If this test is run, the following tests # will segfault. 
@@ -366,8 +350,6 @@ def test_kwargs_not_supported(ray_start_regular): compiled_dag = dag.experimental_compile() assert ray.get(compiled_dag.execute(2)) == 3 - compiled_dag.teardown() - def test_out_of_order_get(ray_start_regular): c = Collector.remote() @@ -384,8 +366,6 @@ def test_out_of_order_get(ray_start_regular): result_a = ray.get(ref_a) assert result_a == ["a"] - compiled_dag.teardown() - def test_actor_multi_methods(ray_start_regular): a = Actor.remote(0) @@ -398,8 +378,6 @@ def test_actor_multi_methods(ray_start_regular): result = ray.get(ref) assert result == 1 - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_actor_methods_execution_order(ray_start_regular, single_fetch): @@ -421,8 +399,6 @@ def test_actor_methods_execution_order(ray_start_regular, single_fetch): else: assert ray.get(refs) == [4, 1] - compiled_dag.teardown() - def test_actor_method_multi_binds(ray_start_regular): a = Actor.remote(0) @@ -435,8 +411,6 @@ def test_actor_method_multi_binds(ray_start_regular): result = ray.get(ref) assert result == 2 - compiled_dag.teardown() - def test_actor_method_bind_same_constant(ray_start_regular): a = Actor.remote(0) @@ -451,8 +425,6 @@ def test_actor_method_bind_same_constant(ray_start_regular): result = ray.get(ref) assert result == 5 - compiled_dag.teardown() - def test_actor_method_bind_same_input(ray_start_regular): actor = Actor.remote(0) @@ -469,7 +441,6 @@ def test_actor_method_bind_same_input(ray_start_regular): ref = compiled_dag.execute(i) result = ray.get(ref) assert result == expected[i] - compiled_dag.teardown() def test_actor_method_bind_same_input_attr(ray_start_regular): @@ -487,7 +458,6 @@ def test_actor_method_bind_same_input_attr(ray_start_regular): ref = compiled_dag.execute(i) result = ray.get(ref) assert result == expected[i] - compiled_dag.teardown() def test_actor_method_bind_diff_input_attr_1(ray_start_regular): @@ -509,8 +479,6 @@ def test_actor_method_bind_diff_input_attr_1(ray_start_regular): ref = compiled_dag.execute(2, 3) assert ray.get(ref) == [0, 1, 2, 4, 6, 9] - compiled_dag.teardown() - def test_actor_method_bind_diff_input_attr_2(ray_start_regular): actor = Actor.remote(0) @@ -533,8 +501,6 @@ def test_actor_method_bind_diff_input_attr_2(ray_start_regular): ref = compiled_dag.execute(2, 3) assert ray.get(ref) == [0, 0, 1, 2, 3, 5, 7, 9, 12] - compiled_dag.teardown() - def test_actor_method_bind_diff_input_attr_3(ray_start_regular): actor = Actor.remote(0) @@ -552,8 +518,6 @@ def test_actor_method_bind_diff_input_attr_3(ray_start_regular): ref = compiled_dag.execute(2, 3) assert ray.get(ref) == 9 - compiled_dag.teardown() - def test_actor_method_bind_diff_input_attr_4(ray_start_regular): actor = Actor.remote(0) @@ -572,8 +536,6 @@ def test_actor_method_bind_diff_input_attr_4(ray_start_regular): ref = compiled_dag.execute(2, 3, 4) assert ray.get(ref) == [1, 3, 6, 9, 14, 18] - compiled_dag.teardown() - def test_actor_method_bind_diff_input_attr_5(ray_start_regular): actor = Actor.remote(0) @@ -592,8 +554,6 @@ def test_actor_method_bind_diff_input_attr_5(ray_start_regular): ref = compiled_dag.execute(2, 3, 4) assert ray.get(ref) == [1, 3, 6, 10, 15, 21] - compiled_dag.teardown() - def test_actor_method_bind_diff_kwargs_input_attr(ray_start_regular): actor = Actor.remote(0) @@ -614,8 +574,6 @@ def test_actor_method_bind_diff_kwargs_input_attr(ray_start_regular): ref = compiled_dag.execute(x=2, y=3) assert ray.get(ref) == [0, 1, 2, 4, 6, 9] - compiled_dag.teardown() - def 
test_actor_method_bind_same_arg(ray_start_regular): a1 = Actor.remote(0) @@ -634,7 +592,6 @@ def test_actor_method_bind_same_arg(ray_start_regular): ref = compiled_dag.execute(i) result = ray.get(ref) assert result == expected[i] - compiled_dag.teardown() def test_mixed_bind_same_input(ray_start_regular): @@ -654,7 +611,6 @@ def test_mixed_bind_same_input(ray_start_regular): ref = compiled_dag.execute(i) result = ray.get(ref) assert result == expected[i] - compiled_dag.teardown() def test_regular_args(ray_start_regular): @@ -670,8 +626,6 @@ def test_regular_args(ray_start_regular): result = ray.get(ref) assert result == (i + 1) * 3 - compiled_dag.teardown() - class TestMultiArgs: def test_multi_args_basic(self, ray_start_regular): @@ -689,8 +643,6 @@ def test_multi_args_basic(self, ray_start_regular): result = ray.get(ref) assert result == [3, 2] - compiled_dag.teardown() - def test_multi_args_single_actor(self, ray_start_regular): c = Collector.remote() with InputNode() as i: @@ -725,8 +677,6 @@ def test_multi_args_single_actor(self, ray_start_regular): ): compiled_dag.execute(args=(2, 3)) - compiled_dag.teardown() - def test_multi_args_branch(self, ray_start_regular): a = Actor.remote(0) c = Collector.remote() @@ -740,8 +690,6 @@ def test_multi_args_branch(self, ray_start_regular): result = ray.get(ref) assert result == [2, 3] - compiled_dag.teardown() - def test_kwargs_basic(self, ray_start_regular): a1 = Actor.remote(0) a2 = Actor.remote(0) @@ -757,8 +705,6 @@ def test_kwargs_basic(self, ray_start_regular): result = ray.get(ref) assert result == [3, 2] - compiled_dag.teardown() - def test_kwargs_single_actor(self, ray_start_regular): c = Collector.remote() with InputNode() as i: @@ -791,8 +737,6 @@ def test_kwargs_single_actor(self, ray_start_regular): ): compiled_dag.execute(x=3) - compiled_dag.teardown() - def test_kwargs_branch(self, ray_start_regular): a = Actor.remote(0) c = Collector.remote() @@ -806,8 +750,6 @@ def test_kwargs_branch(self, ray_start_regular): result = ray.get(ref) assert result == [3, 2] - compiled_dag.teardown() - def test_multi_args_and_kwargs(self, ray_start_regular): a1 = Actor.remote(0) a2 = Actor.remote(0) @@ -823,8 +765,6 @@ def test_multi_args_and_kwargs(self, ray_start_regular): result = ray.get(ref) assert result == [3, 4, 2] - compiled_dag.teardown() - def test_multi_args_and_torch_type(self, ray_start_regular): a1 = Actor.remote(0) a2 = Actor.remote(0) @@ -848,8 +788,6 @@ def test_multi_args_and_torch_type(self, ray_start_regular): assert torch.equal(tensors[0], cpu_tensors[1]) assert torch.equal(tensors[1], cpu_tensors[0]) - compiled_dag.teardown() - def test_mix_entire_input_and_args(self, ray_start_regular): """ It is not allowed to consume both the entire input and a partial @@ -883,8 +821,6 @@ def test_multi_args_same_actor(self, ray_start_regular): result = ray.get(ref) assert result == [1, 3] - compiled_dag.teardown() - def test_multi_args_basic_asyncio(self, ray_start_regular): a1 = Actor.remote(0) a2 = Actor.remote(0) @@ -902,7 +838,6 @@ async def main(): loop = get_or_create_event_loop() loop.run_until_complete(asyncio.gather(main())) - compiled_dag.teardown() def test_multi_args_branch_asyncio(self, ray_start_regular): a = Actor.remote(0) @@ -920,7 +855,6 @@ async def main(): loop = get_or_create_event_loop() loop.run_until_complete(asyncio.gather(main())) - compiled_dag.teardown() def test_kwargs_basic_asyncio(self, ray_start_regular): a1 = Actor.remote(0) @@ -940,7 +874,6 @@ async def main(): loop = get_or_create_event_loop() 
loop.run_until_complete(asyncio.gather(main())) - compiled_dag.teardown() def test_kwargs_branch_asyncio(self, ray_start_regular): a = Actor.remote(0) @@ -958,7 +891,6 @@ async def main(): loop = get_or_create_event_loop() loop.run_until_complete(asyncio.gather(main())) - compiled_dag.teardown() def test_multi_args_and_kwargs_asyncio(self, ray_start_regular): a1 = Actor.remote(0) @@ -978,7 +910,6 @@ async def main(): loop = get_or_create_event_loop() loop.run_until_complete(asyncio.gather(main())) - compiled_dag.teardown() @pytest.mark.parametrize("num_actors", [1, 4]) @@ -1002,8 +933,6 @@ def test_scatter_gather_dag(ray_start_regular, num_actors, single_fetch): results = ray.get(refs) assert results == [i + 1] * num_actors - compiled_dag.teardown() - @pytest.mark.parametrize("num_actors", [1, 4]) def test_chain_dag(ray_start_regular, num_actors): @@ -1020,8 +949,6 @@ def test_chain_dag(ray_start_regular, num_actors): result = ray.get(ref) assert result == list(range(num_actors)) - compiled_dag.teardown() - def test_get_timeout(ray_start_regular): a = Actor.remote(0) @@ -1043,8 +970,6 @@ def test_get_timeout(ray_start_regular): timed_out = True assert timed_out - compiled_dag.teardown() - def test_buffered_get_timeout(ray_start_regular): a = Actor.remote(0) @@ -1065,8 +990,6 @@ def test_buffered_get_timeout(ray_start_regular): # be raised. ray.get(refs[-1], timeout=3.5) - compiled_dag.teardown() - def test_get_with_zero_timeout(ray_start_regular): a = Actor.remote(0) @@ -1081,8 +1004,6 @@ def test_get_with_zero_timeout(ray_start_regular): result = ray.get(ref, timeout=0) assert result == 1 - compiled_dag.teardown() - def test_dag_exception_basic(ray_start_regular, capsys): # Test application throwing exceptions with a single task. @@ -1108,8 +1029,6 @@ def test_dag_exception_basic(ray_start_regular, capsys): # Can use the DAG after exceptions are thrown. assert ray.get(compiled_dag.execute(1)) == 1 - compiled_dag.teardown() - def test_dag_exception_chained(ray_start_regular, capsys): # Test application throwing exceptions with a task that depends on another @@ -1137,8 +1056,6 @@ def test_dag_exception_chained(ray_start_regular, capsys): # Can use the DAG after exceptions are thrown. assert ray.get(compiled_dag.execute(1)) == 2 - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_dag_exception_multi_output(ray_start_regular, single_fetch, capsys): @@ -1186,8 +1103,6 @@ def test_dag_exception_multi_output(ray_start_regular, single_fetch, capsys): else: assert ray.get(refs) == [1, 1] - compiled_dag.teardown() - def test_dag_errors(ray_start_regular): a = Actor.remote(0) @@ -1281,7 +1196,6 @@ def f(x): ), ): ray.get(ref) - compiled_dag.teardown() class TestDAGExceptionCompileMultipleTimes: @@ -1340,8 +1254,7 @@ def test_compile_twice_with_multioutputnode_without_teardown( "object multiple times no matter whether `teardown` is called or not. 
" "Please reuse the existing compiled DAG or create a new one.", ): - compiled_dag = dag.experimental_compile() - compiled_dag.teardown() + compiled_dag = dag.experimental_compile() # noqa def test_compile_twice_with_different_nodes(self, ray_start_regular): a = Actor.remote(0) @@ -1386,7 +1299,6 @@ def test_exceed_max_buffered_results(ray_start_regular): ray.get(ref) del refs - compiled_dag.teardown() @pytest.mark.parametrize("single_fetch", [True, False]) @@ -1425,7 +1337,6 @@ def test_exceed_max_buffered_results_multi_output(ray_start_regular, single_fetc ray.get(ref) del refs - compiled_dag.teardown() def test_compiled_dag_ref_del(ray_start_regular): @@ -1441,8 +1352,6 @@ def test_compiled_dag_ref_del(ray_start_regular): ref = compiled_dag.execute(1) del ref - compiled_dag.teardown() - def test_dag_fault_tolerance_chain(ray_start_regular): actors = [ @@ -1484,8 +1393,6 @@ def test_dag_fault_tolerance_chain(ray_start_regular): results = ray.get(ref) assert results == i - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_dag_fault_tolerance(ray_start_regular, single_fetch): @@ -1536,8 +1443,6 @@ def test_dag_fault_tolerance(ray_start_regular, single_fetch): else: ray.get(refs) - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_dag_fault_tolerance_sys_exit(ray_start_regular, single_fetch): @@ -1588,8 +1493,6 @@ def test_dag_fault_tolerance_sys_exit(ray_start_regular, single_fetch): else: ray.get(refs) - compiled_dag.teardown() - def test_dag_teardown_while_running(ray_start_regular): a = Actor.remote(0) @@ -1614,8 +1517,6 @@ def test_dag_teardown_while_running(ray_start_regular): result = ray.get(ref) assert result == 0.1 - compiled_dag.teardown() - @pytest.mark.parametrize("max_queue_size", [None, 2]) def test_asyncio(ray_start_regular, max_queue_size): @@ -1638,9 +1539,6 @@ async def main(i): assert (result == val).all() loop.run_until_complete(asyncio.gather(*[main(i) for i in range(10)])) - # Note: must teardown before starting a new Ray session, otherwise you'll get - # a segfault from the dangling monitor thread upon the new Ray init. - compiled_dag.teardown() @pytest.mark.parametrize("max_queue_size", [None, 2]) @@ -1664,7 +1562,6 @@ async def main(): assert result_a == ["a"] loop.run_until_complete(main()) - compiled_dag.teardown() @pytest.mark.parametrize("max_queue_size", [None, 2]) @@ -1699,9 +1596,6 @@ async def main(i): assert (result == val).all() loop.run_until_complete(asyncio.gather(*[main(i) for i in range(10)])) - # Note: must teardown before starting a new Ray session, otherwise you'll get - # a segfault from the dangling monitor thread upon the new Ray init. - compiled_dag.teardown() @pytest.mark.parametrize("max_queue_size", [None, 2]) @@ -1739,9 +1633,6 @@ async def main(): assert result == 2 loop.run_until_complete(main()) - # Note: must teardown before starting a new Ray session, otherwise you'll get - # a segfault from the dangling monitor thread upon the new Ray init. 
- compiled_dag.teardown() class TestCompositeChannel: @@ -1777,8 +1668,6 @@ def test_composite_channel_one_actor(self, ray_start_regular): ref = compiled_dag.execute(3) assert ray.get(ref) == 108 - compiled_dag.teardown() - def test_composite_channel_two_actors(self, ray_start_regular): """ In this test, there are three 'inc' tasks on the two Ray actors, chained @@ -1811,8 +1700,6 @@ def test_composite_channel_two_actors(self, ray_start_regular): ref = compiled_dag.execute(3) assert ray.get(ref) == 829 - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_composite_channel_multi_output(self, ray_start_regular, single_fetch): """ @@ -1847,8 +1734,6 @@ def test_composite_channel_multi_output(self, ray_start_regular, single_fetch): else: assert ray.get(refs) == [10, 106] - compiled_dag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_intra_process_channel_with_multi_readers( self, ray_start_regular, single_fetch @@ -1895,8 +1780,6 @@ def test_intra_process_channel_with_multi_readers( else: assert ray.get(refs) == [3, 3] - compiled_dag.teardown() - class TestLeafNode: def test_leaf_node_one_actor(self, ray_start_regular): @@ -1920,7 +1803,6 @@ def test_leaf_node_one_actor(self, ray_start_regular): ref = compiled_dag.execute(10) assert ray.get(ref) == [20] - compiled_dag.teardown() def test_leaf_node_two_actors(self, ray_start_regular): """ @@ -1943,7 +1825,6 @@ def test_leaf_node_two_actors(self, ray_start_regular): ref = compiled_dag.execute(10) assert ray.get(ref) == [120, 220] - compiled_dag.teardown() def test_output_node(ray_start_regular): @@ -1997,7 +1878,6 @@ def echo(self, data): ref = compiled_dag.execute(x=1, y=2) assert ray.get(ref) == [1, 2, 1] - compiled_dag.teardown() @pytest.mark.parametrize("single_fetch", [True, False]) @@ -2079,7 +1959,6 @@ def read_input(self, input): "BWD rank-1, batch-1", "BWD rank-1, batch-2", ] - output_dag.teardown() def test_channel_read_after_close(ray_start_regular): @@ -2231,7 +2110,6 @@ async def main(): loop = get_or_create_event_loop() loop.run_until_complete(main()) - async_dag.teardown() def test_event_profiling(ray_start_regular, monkeypatch): @@ -2261,8 +2139,6 @@ def test_event_profiling(ray_start_regular, monkeypatch): assert event.method_name == "inc" assert event.operation in ["READ", "COMPUTE", "WRITE"] - adag.teardown() - @ray.remote class TestWorker: @@ -2465,6 +2341,43 @@ def call(self, value): assert torch.equal(ray.get(ref), torch.tensor([5, 5, 5, 5, 5])) +def test_async_shutdown(shutdown_only): + """Verify that when async API is used, shutdown doesn't hang + because of threads joining at exit. 
+ """ + + script = """ +import asyncio +import ray +from ray.dag import InputNode, MultiOutputNode + +async def main(): + @ray.remote + class A: + def f(self, i): + return i + + a = A.remote() + b = A.remote() + + with InputNode() as inp: + x = a.f.bind(inp) + y = b.f.bind(inp) + dag = MultiOutputNode([x, y]) + + adag = dag.experimental_compile(enable_asyncio=True) + refs = await adag.execute_async(1) + outputs = [] + for ref in refs: + outputs.append(await ref) + print(outputs) + +asyncio.run(main()) + """ + + print(run_string_as_driver(script)) + + def test_multi_arg_exception(shutdown_only): a = Actor.remote(0) with InputNode() as i: @@ -2479,8 +2392,6 @@ def test_multi_arg_exception(shutdown_only): with pytest.raises(RuntimeError): ray.get(y) - compiled_dag.teardown() - def test_multi_arg_exception_async(shutdown_only): a = Actor.remote(0) @@ -2501,8 +2412,6 @@ async def main(): loop = get_or_create_event_loop() loop.run_until_complete(main()) - compiled_dag.teardown() - class TestVisualization: diff --git a/python/ray/dag/tests/experimental/test_detect_deadlock_dag.py b/python/ray/dag/tests/experimental/test_detect_deadlock_dag.py index bedfb2701ba5..42ac5a2dc672 100644 --- a/python/ray/dag/tests/experimental/test_detect_deadlock_dag.py +++ b/python/ray/dag/tests/experimental/test_detect_deadlock_dag.py @@ -64,8 +64,7 @@ def test_invalid_graph_1_actor(ray_start_regular, tensor_transport): dag = a.no_op.bind(dag) if tensor_transport == TorchTensorType.AUTO: - compiled_graph = dag.experimental_compile() - compiled_graph.teardown() + dag.experimental_compile() elif tensor_transport == TorchTensorType.NCCL: with pytest.raises(ValueError, match=INVALID_GRAPH): dag.experimental_compile() @@ -140,8 +139,7 @@ def test_valid_graph_2_actors_1(ray_start_regular, tensor_transport): ] ) - compiled_graph = dag.experimental_compile() - compiled_graph.teardown() + dag.experimental_compile() @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 2}], indirect=True) @@ -169,8 +167,7 @@ def test_valid_graph_2_actors_2(ray_start_regular): dag.with_type_hint(TorchTensorType(transport="nccl")) dag = b.no_op.bind(dag) - compiled_dag = dag.experimental_compile() - compiled_dag.teardown() + dag.experimental_compile() @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 2}], indirect=True) @@ -207,8 +204,7 @@ def test_invalid_graph_2_actors_1(ray_start_regular, tensor_transport): ) if tensor_transport == TorchTensorType.AUTO: - compiled_graph = dag.experimental_compile() - compiled_graph.teardown() + dag.experimental_compile() elif tensor_transport == TorchTensorType.NCCL: with pytest.raises(ValueError, match=INVALID_GRAPH): dag.experimental_compile() @@ -245,8 +241,7 @@ def test_invalid_graph_2_actors_2(ray_start_regular, tensor_transport): ) if tensor_transport == TorchTensorType.AUTO: - compiled_graph = dag.experimental_compile() - compiled_graph.teardown() + dag.experimental_compile() elif tensor_transport == TorchTensorType.NCCL: with pytest.raises(ValueError, match=INVALID_GRAPH): dag.experimental_compile() @@ -278,8 +273,7 @@ def test_valid_graph_3_actors_1(ray_start_regular, tensor_transport): ] ) - compiled_graph = dag.experimental_compile() - compiled_graph.teardown() + dag.experimental_compile() @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 3}], indirect=True) @@ -304,8 +298,7 @@ def test_valid_graph_3_actors_2(ray_start_regular): branch2.with_type_hint(TorchTensorType(transport="nccl")) dag = a.no_op_two.bind(branch1, branch2) - compiled_dag = dag.experimental_compile() - 
compiled_dag.teardown() + dag.experimental_compile() if __name__ == "__main__": diff --git a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py index 7bfd84502901..bdb9dafbdd62 100644 --- a/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py +++ b/python/ray/dag/tests/experimental/test_execution_schedule_gpu.py @@ -194,8 +194,6 @@ def test_simulate_pp_2workers_2batches_1f1b( for tensor in tensors: assert torch.equal(tensor, tensor_cpu) - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 4}], indirect=True) def test_simulate_pp_4workers_8batches_1f1b(ray_start_regular, monkeypatch): @@ -218,7 +216,6 @@ def test_simulate_pp_4workers_8batches_1f1b(ray_start_regular, monkeypatch): assert len(tensors) == num_microbatches for t in tensors: assert torch.equal(t, tensor_cpu) - compiled_dag.teardown() @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 3}], indirect=True) @@ -286,8 +283,6 @@ def test_three_actors_with_nccl_1(ray_start_regular): for t in tensors: assert torch.equal(t, tensor_cpu) - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 3}], indirect=True) @pytest.mark.parametrize("single_fetch", [True, False]) @@ -369,8 +364,6 @@ def test_three_actors_with_nccl_2(ray_start_regular, single_fetch, monkeypatch): for tensor in tensors: assert torch.equal(tensor, tensor_cpu) - compiled_dag.teardown() - if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): diff --git a/python/ray/dag/tests/experimental/test_multi_args_gpu.py b/python/ray/dag/tests/experimental/test_multi_args_gpu.py index 4d9dcf8e11a8..d484132f4998 100644 --- a/python/ray/dag/tests/experimental/test_multi_args_gpu.py +++ b/python/ray/dag/tests/experimental/test_multi_args_gpu.py @@ -68,8 +68,6 @@ def backward(self, data): assert torch.equal(tensors[2], tensor_cpu_list[2]) assert torch.equal(tensors[3], tensor_cpu_list[2]) - compiled_dag.teardown() - if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): diff --git a/python/ray/dag/tests/experimental/test_multi_node_dag.py b/python/ray/dag/tests/experimental/test_multi_node_dag.py index b8cb128eab27..822b0b17cfd9 100644 --- a/python/ray/dag/tests/experimental/test_multi_node_dag.py +++ b/python/ray/dag/tests/experimental/test_multi_node_dag.py @@ -102,8 +102,6 @@ def _get_node_id(self) -> "ray.NodeID": for i in range(1, 10): assert ray.get(adag.execute(1)) == [i, i, i] - adag.teardown() - def test_bunch_readers_on_different_nodes(ray_start_cluster): cluster = ray_start_cluster @@ -143,8 +141,6 @@ def _get_node_id(self) -> "ray.NodeID": i for _ in range(ACTORS_PER_NODE * (NUM_REMOTE_NODES + 1)) ] - adag.teardown() - @pytest.mark.parametrize("single_fetch", [True, False]) def test_pp(ray_start_cluster, single_fetch): @@ -190,8 +186,6 @@ def execute_model(self, val): # So that raylets' error messages are printed to the driver time.sleep(2) - compiled_dag.teardown() - def test_payload_large(ray_start_cluster, monkeypatch): GRPC_MAX_SIZE = 1024 * 1024 * 5 @@ -241,10 +235,6 @@ def get_node_id(self): result = ray.get(ref) assert result == val - # Note: must teardown before starting a new Ray session, otherwise you'll get - # a segfault from the dangling monitor thread upon the new Ray init. 
- compiled_dag.teardown() - @pytest.mark.parametrize("num_actors", [1, 4]) @pytest.mark.parametrize("num_nodes", [1, 4]) @@ -295,8 +285,6 @@ def _get_node_id(self) -> "ray.NodeID": result = ray.get(ref) assert result == [val for _ in range(ACTORS_PER_NODE * (NUM_REMOTE_NODES + 1))] - compiled_dag.teardown() - def test_multi_node_dag_from_actor(ray_start_cluster): cluster = ray_start_cluster diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index 70e9a14296b9..3cd8bc8765d1 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -1,9 +1,9 @@ # coding: utf-8 import logging +import time import os import socket import sys -import time from typing import List, Optional, Tuple import pytest @@ -19,6 +19,7 @@ TorchTensorAllocator, ) from ray.experimental.channel.nccl_group import _NcclGroup + from ray.experimental.channel.torch_tensor_type import TorchTensorType from ray.tests.conftest import * # noqa from ray.experimental.util.types import ReduceOp @@ -103,6 +104,74 @@ def forward(self, inp): return torch.randn(10, 10) +class TestNcclGroup(GPUCommunicator): + """ + A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. + """ + + def __init__(self, world_size, comm_id, actor_handles): + self._world_size = world_size + self._comm_id = comm_id + self._actor_handles = actor_handles + self._inner = None + + def initialize(self, rank: int) -> None: + self._inner = _NcclGroup( + self._world_size, + self._comm_id, + rank, + self._actor_handles, + torch.cuda.current_stream().cuda_stream, + ) + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the NCCL group.") + return rank + + def get_world_size(self) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. 
+ return self._world_size + + def get_self_rank(self) -> Optional[int]: + if self._inner is None: + return None + return self._inner.get_self_rank() + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + return self._inner.send(value, peer_rank) + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ) -> None: + self._inner.allreduce(send_buf, recv_buf, op) + recv_buf += 1 + + def destroy(self) -> None: + return self._inner.destroy() + + @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_p2p(ray_start_regular): if USE_GPU: @@ -165,7 +234,6 @@ def test_torch_tensor_p2p(ray_start_regular): ref = compiled_dag.execute((shape, dtype, 1)) ray.get(ref) - compiled_dag.teardown() @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) @@ -203,8 +271,6 @@ def test_torch_tensor_as_dag_input(ray_start_regular): result = ray.get(ref) assert result == (i, (20,), dtype) - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_nccl(ray_start_regular): @@ -266,7 +332,6 @@ def test_torch_tensor_nccl(ray_start_regular): ref = compiled_dag.execute(i) result = ray.get(ref) assert result == (i, shape, dtype) - compiled_dag.teardown() # TODO(swang): Check that actors are still alive. Currently this fails due # to a ref counting assertion error. @@ -305,8 +370,6 @@ def test_torch_tensor_nccl_dynamic(ray_start_regular): result = ray.get(ref) assert result == (i, shape, dtype) - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_custom_comm(ray_start_regular): @@ -324,77 +387,6 @@ def test_torch_tensor_custom_comm(ray_start_regular): from cupy.cuda import nccl - class TestNcclGroup(GPUCommunicator): - """ - A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. - """ - - def __init__(self, world_size, comm_id, actor_handles): - self._world_size = world_size - self._comm_id = comm_id - self._actor_handles = actor_handles - self._inner = None - - def initialize(self, rank: int) -> None: - print(f"initializing rank {rank}") - try: - self._inner = _NcclGroup( - self._world_size, - self._comm_id, - rank, - self._actor_handles, - torch.cuda.current_stream().cuda_stream, - ) - except Exception as e: - print(f"Got {e}") - - def get_rank(self, actor: ray.actor.ActorHandle) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. - actor_ids = [a._ray_actor_id for a in self._actor_handles] - try: - rank = actor_ids.index(actor._ray_actor_id) - except ValueError: - raise ValueError("Actor is not in the NCCL group.") - return rank - - def get_world_size(self) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. 
- return self._world_size - - def get_self_rank(self) -> Optional[int]: - if self._inner is None: - return None - return self._inner.get_self_rank() - - def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: - return self._actor_handles - - def send(self, value: "torch.Tensor", peer_rank: int) -> None: - return self._inner.send(value, peer_rank) - - def recv( - self, - shape: Tuple[int], - dtype: "torch.dtype", - peer_rank: int, - allocator: Optional[TorchTensorAllocator] = None, - ) -> "torch.Tensor": - return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) - - def allreduce( - self, - send_buf: "torch.Tensor", - recv_buf: "torch.Tensor", - op: ReduceOp = ReduceOp.SUM, - ) -> None: - self._inner.allreduce(send_buf, recv_buf, op) - recv_buf += 1 - - def destroy(self) -> None: - return self._inner.destroy() - comm_id = nccl.get_unique_id() nccl_group = TestNcclGroup(2, comm_id, [sender, receiver]) with InputNode() as inp: @@ -412,8 +404,6 @@ def destroy(self) -> None: result = ray.get(ref) assert result == (i, shape, dtype) - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_custom_comm_invalid(ray_start_regular): @@ -645,8 +635,6 @@ def destroy(self) -> None: result = ray.get(ref) assert result == (i, shape, dtype) - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_nccl_wrong_shape(ray_start_regular): @@ -693,8 +681,6 @@ def test_torch_tensor_nccl_wrong_shape(ray_start_regular): with pytest.raises(RayChannelError): ref = compiled_dag.execute(shape=(20,), dtype=dtype, value=1) - compiled_dag.teardown() - # TODO(swang): This currently requires time.sleep to avoid some issue with # following tests. time.sleep(3) @@ -738,8 +724,6 @@ def test_torch_tensor_nccl_nested(ray_start_regular): expected_result = {0: (0, shape, dtype)} assert result == expected_result - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_nccl_nested_dynamic(ray_start_regular): @@ -778,8 +762,6 @@ def test_torch_tensor_nccl_nested_dynamic(ray_start_regular): expected_result = {j: (j, shape, dtype) for j in range(i)} assert result == expected_result - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_nccl_direct_return_error(ray_start_regular): @@ -826,12 +808,6 @@ def test_torch_tensor_nccl_direct_return_error(ray_start_regular): with pytest.raises(RayChannelError): ref = compiled_dag.execute(shape=shape, dtype=dtype, value=1, send_tensor=True) - compiled_dag.teardown() - - # TODO(swang): This currently requires time.sleep to avoid some issue with - # following tests. - time.sleep(3) - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_exceptions(ray_start_regular): @@ -896,8 +872,6 @@ def test_torch_tensor_exceptions(ray_start_regular): result = ray.get(ref) assert result == (i, shape, dtype) - compiled_dag.teardown() - @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_nccl_all_reduce(ray_start_regular): @@ -1062,77 +1036,6 @@ def test_torch_tensor_nccl_all_reduce_custom_comm(ray_start_regular): num_workers = 2 workers = [actor_cls.remote() for _ in range(num_workers)] - class TestNcclGroup(GPUCommunicator): - """ - A custom NCCL group for testing. 
This is a simple wrapper around `_NcclGroup`. - """ - - def __init__(self, world_size, comm_id, actor_handles): - self._world_size = world_size - self._comm_id = comm_id - self._actor_handles = actor_handles - self._inner = None - - def initialize(self, rank: int) -> None: - print(f"initializing rank {rank}") - try: - self._inner = _NcclGroup( - self._world_size, - self._comm_id, - rank, - self._actor_handles, - torch.cuda.current_stream().cuda_stream, - ) - except Exception as e: - print(f"Got {e}") - - def get_rank(self, actor: ray.actor.ActorHandle) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. - actor_ids = [a._ray_actor_id for a in self._actor_handles] - try: - rank = actor_ids.index(actor._ray_actor_id) - except ValueError: - raise ValueError("Actor is not in the NCCL group.") - return rank - - def get_world_size(self) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. - return self._world_size - - def get_self_rank(self) -> Optional[int]: - if self._inner is None: - return None - return self._inner.get_self_rank() - - def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: - return self._actor_handles - - def send(self, value: "torch.Tensor", peer_rank: int) -> None: - return self._inner.send(value, peer_rank) - - def recv( - self, - shape: Tuple[int], - dtype: "torch.dtype", - peer_rank: int, - allocator: Optional[TorchTensorAllocator] = None, - ) -> "torch.Tensor": - return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) - - def allreduce( - self, - send_buf: "torch.Tensor", - recv_buf: "torch.Tensor", - op: ReduceOp = ReduceOp.SUM, - ) -> None: - self._inner.allreduce(send_buf, recv_buf, op) - recv_buf += 1 - - def destroy(self) -> None: - return self._inner.destroy() - from cupy.cuda import nccl comm_id = nccl.get_unique_id() diff --git a/python/ray/experimental/channel/common.py b/python/ray/experimental/channel/common.py index 84d5d5a6c111..0abdd0deb670 100644 --- a/python/ray/experimental/channel/common.py +++ b/python/ray/experimental/channel/common.py @@ -1,12 +1,14 @@ import asyncio import concurrent import copy +import sys import threading import time from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union import ray +import ray.exceptions from ray.experimental.channel.gpu_communicator import GPUCommunicator from ray.experimental.channel.serialization_context import _SerializationContext from ray.util.annotations import DeveloperAPI, PublicAPI @@ -19,6 +21,29 @@ import torch +def retry_and_check_interpreter_exit(f) -> bool: + """This function is only useful when `f` contains a channel read or write. + + Keep retrying the channel read/write inside `f` and check whether the + interpreter is exiting. This matters when the read/write happens in a + separate thread pool. + See https://github.com/ray-project/ray/pull/47702 + """ + exiting = False + while True: + try: + # `f` performs the channel read/write with a short timeout, + # e.g. lambda: results.append(c.read(timeout=1)). + f() + break + except ray.exceptions.RayChannelTimeoutError: + if sys.is_finalizing(): + # The interpreter is exiting. Ignore the error and stop + # retrying so that the thread can join. + exiting = True + break + + return exiting + + # Holds the input arguments for an accelerated DAG node.
@PublicAPI(stability="alpha") class RayDAGArgs(NamedTuple): @@ -356,7 +381,15 @@ def start(self): self._background_task = asyncio.ensure_future(self.run()) def _run(self): - return [c.read() for c in self._input_channels] + results = [] + for c in self._input_channels: + exiting = retry_and_check_interpreter_exit( + lambda: results.append(c.read(timeout=1)) + ) + if exiting: + break + + return results async def run(self): loop = asyncio.get_running_loop() @@ -378,6 +411,7 @@ async def run(self): def close(self): super().close() self._background_task_executor.shutdown(cancel_futures=True) + self._background_task.cancel() @DeveloperAPI @@ -540,7 +574,11 @@ def _run(self, res): for i, channel in enumerate(self._output_channels): idx = self._output_idxs[i] res_i = _adapt(res, idx, self._is_input) - channel.write(res_i) + exiting = retry_and_check_interpreter_exit( + lambda: channel.write(res_i, timeout=1) + ) + if exiting: + break async def run(self): loop = asyncio.get_event_loop()
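A minimal usage sketch (illustrative only, not part of the patch) of the behavior this change enables: callers no longer need to invoke compiled_dag.teardown() themselves, because each compiled DAG is registered in the _compiled_dags WeakValueDictionary at compile time and _shutdown_all_compiled_dags() runs from ray.shutdown(), which is also invoked at interpreter exit via atexit. The Echo actor below is a hypothetical example and is not taken from the patch.

import ray
from ray.dag import InputNode

@ray.remote
class Echo:
    def echo(self, x):
        return x

a = Echo.remote()
with InputNode() as inp:
    dag = a.echo.bind(inp)

# Compiling registers the DAG in the weakref registry (_compiled_dags).
compiled_dag = dag.experimental_compile()
assert ray.get(compiled_dag.execute(42)) == 42

# No explicit compiled_dag.teardown() is needed anymore: ray.shutdown(),
# whether called here or via the atexit handler, tears down any live
# compiled DAGs through _shutdown_all_compiled_dags().
ray.shutdown()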