Commit d2826f1

Revert "[core][compiled graphs] Fix test_torch_tensor_dag_gpu CI failure (ray-project#48204)" (ray-project#48250)

This reverts commit 23bb654.

This is a revert of the revert, i.e. it re-applies ray-project#47702.
Signed-off-by: JP-sDEV <[email protected]>
rkooo567 authored and JP-sDEV committed Nov 14, 2024
1 parent 0075558 commit d2826f1
Showing 10 changed files with 195 additions and 362 deletions.
python/ray/_private/ray_experimental_perf.py (19 changes: 0 additions & 19 deletions)
@@ -180,7 +180,6 @@ async def _exec_async():
     results += timeit(
         "[unstable] compiled single-actor DAG calls", lambda: _exec(compiled_dag)
     )
-    compiled_dag.teardown()
     del a

     # Single-actor asyncio DAG calls
@@ -194,10 +193,6 @@ async def _exec_async():
             "[unstable] compiled single-actor asyncio DAG calls",
         )
     )
-    # TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
-    # these DAGs create a background thread that can segfault if the CoreWorker
-    # is torn down first.
-    compiled_dag.teardown()
     del a

     # Scatter-gather DAG calls
@@ -215,7 +210,6 @@ async def _exec_async():
         f"[unstable] compiled scatter-gather DAG calls, n={n_cpu} actors",
         lambda: _exec(compiled_dag),
     )
-    compiled_dag.teardown()

     # Scatter-gather asyncio DAG calls

@@ -228,10 +222,6 @@ async def _exec_async():
             f"[unstable] compiled scatter-gather asyncio DAG calls, n={n_cpu} actors",
         )
     )
-    # TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
-    # these DAGs create a background thread that can segfault if the CoreWorker
-    # is torn down first.
-    compiled_dag.teardown()

     # Chain DAG calls

@@ -249,7 +239,6 @@ async def _exec_async():
         f"[unstable] compiled chain DAG calls, n={n_cpu} actors",
         lambda: _exec(compiled_dag),
     )
-    compiled_dag.teardown()

     # Chain asyncio DAG calls

@@ -262,10 +251,6 @@ async def _exec_async():
     results += loop.run_until_complete(
         exec_async(f"[unstable] compiled chain asyncio DAG calls, n={n_cpu} actors")
     )
-    # TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
-    # these DAGs create a background thread that can segfault if the CoreWorker
-    # is torn down first.
-    compiled_dag.teardown()

     # Multiple args with small payloads

@@ -288,7 +273,6 @@ async def _exec_async():
         f"n={n_actors} actors",
         lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
     )
-    compiled_dag.teardown()

     # Multiple args with medium payloads

@@ -306,7 +290,6 @@ async def _exec_async():
         f"n={n_actors} actors",
         lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
     )
-    compiled_dag.teardown()

     # Multiple args with large payloads

@@ -324,7 +307,6 @@ async def _exec_async():
         f"n={n_actors} actors",
         lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
     )
-    compiled_dag.teardown()

     # Worst case for multiple arguments: a single actor takes all the arguments
     # with small payloads.
@@ -345,7 +327,6 @@ async def _exec_async():
         "n=1 actors",
         lambda: _exec(compiled_dag, num_args=n_args, payload_size=payload_size),
     )
-    compiled_dag.teardown()

     ray.shutdown()

python/ray/_private/worker.py (5 changes: 5 additions & 0 deletions)
@@ -1874,6 +1874,11 @@ def shutdown(_exiting_interpreter: bool = False):
     and false otherwise. If we are exiting the interpreter, we will
     wait a little while to print any extra error messages.
     """
+    # Make sure to clean up compiled dag nodes if any exist.
+    from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags
+
+    _shutdown_all_compiled_dags()
+
     if _exiting_interpreter and global_worker.mode == SCRIPT_MODE:
         # This is a duration to sleep before shutting down everything in order
         # to make sure that log messages finish printing.
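The import happens inside shutdown(), presumably to avoid a circular import at module load time. The ordering this change relies on is that shutdown() runs via atexit while the CoreWorker is still alive, so teardown's blocking calls are safe. A simplified sketch of the wiring, with hypothetical structure rather than the actual worker.py code:

import atexit


def shutdown():
    # Tear down compiled DAGs first, while the CoreWorker can still
    # serve the blocking calls that teardown performs.
    from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags

    _shutdown_all_compiled_dags()
    # ... then disconnect and stop the CoreWorker ...


# Registered to run at interpreter exit, before module globals are
# destructed in unpredictable order.
atexit.register(shutdown)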
python/ray/dag/compiled_dag_node.py (41 changes: 33 additions & 8 deletions)
@@ -1,3 +1,4 @@
+import weakref
 import asyncio
 from collections import defaultdict
 from dataclasses import dataclass, asdict
@@ -8,6 +9,7 @@
 import uuid
 import traceback

+import ray.exceptions
 from ray.experimental.channel.cached_channel import CachedChannel
 from ray.experimental.channel.gpu_communicator import GPUCommunicator
 import ray
@@ -52,6 +54,21 @@

 logger = logging.getLogger(__name__)

+# Keep track of every compiled dag created during the lifetime of this
+# process. They are tracked as weakrefs, meaning that when a compiled dag
+# is GC'ed, it is automatically removed from here. It is used to tear
+# down compiled dags at interpreter shutdown time.
+_compiled_dags = weakref.WeakValueDictionary()
+
+
+# Relying on __del__ doesn't work well upon shutdown because the
+# destructor order is not guaranteed. Instead we call this function from
+# `ray.worker.shutdown`, which is registered with an atexit handler, so
+# that teardown is properly performed before objects are destructed.
+def _shutdown_all_compiled_dags():
+    for _, compiled_dag in _compiled_dags.items():
+        compiled_dag.teardown()
+

 @DeveloperAPI
 def do_allocate_channel(
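A weakref.WeakValueDictionary holds its values only through weak references, so registration does not extend a DAG's lifetime. A small self-contained demonstration of that behavior; CompiledDagStub is a stand-in invented for illustration:

import weakref


class CompiledDagStub:
    """Stand-in for a compiled DAG; exists only to illustrate the registry."""

    def teardown(self):
        print("torn down")


registry = weakref.WeakValueDictionary()
stub = CompiledDagStub()
registry["dag-1"] = stub
assert len(registry) == 1

# Dropping the last strong reference removes the entry automatically
# (immediately under CPython's reference counting), so shutdown never
# iterates over DAGs that have already been garbage collected.
del stub
assert len(registry) == 0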
@@ -1648,7 +1665,7 @@ def _is_same_actor(idx1: int, idx2: int) -> bool:
         return False

     def _monitor_failures(self):
-        outer = self
+        outer = weakref.proxy(self)

         class Monitor(threading.Thread):
             def __init__(self):
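Capturing self strongly inside the monitor thread would keep the CompiledDAG alive for as long as the thread runs, so __del__ could never fire; the weak proxy breaks that cycle, and the thread observes a ReferenceError once the DAG is collected. A structural sketch with hypothetical names (the thread is shown unstarted):

import threading
import weakref


class Owner:
    def __init__(self):
        # Weak proxy instead of a strong capture: the thread does not
        # keep Owner alive, so Owner can still be garbage collected.
        outer = weakref.proxy(self)

        class Watcher(threading.Thread):
            def run(self):
                try:
                    while True:
                        outer.poll()
                except ReferenceError:
                    # Owner was garbage collected; exit quietly.
                    pass

        self._watcher = Watcher()

    def poll(self):
        pass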
@@ -1657,6 +1674,8 @@ def __init__(self):
                 # Lock to make sure that we only perform teardown for this DAG
                 # once.
                 self.in_teardown_lock = threading.Lock()
+                self.name = "CompiledGraphMonitorThread"
+                self._teardown_done = False

             def wait_teardown(self):
                 for actor, ref in outer.worker_task_refs.items():
@@ -1686,6 +1705,9 @@ def wait_teardown(self):
             def teardown(self, wait: bool):
                 do_teardown = False
                 with self.in_teardown_lock:
+                    if self._teardown_done:
+                        return
+
                     if not self.in_teardown:
                         do_teardown = True
                         self.in_teardown = True
@@ -1709,9 +1731,11 @@ def teardown(self, wait: bool):
                     ]
                     for cancel_ref in cancel_refs:
                         try:
-                            # TODO(swang): Suppress exceptions from actors trying to
-                            # read closed channels when DAG is being torn down.
                             ray.get(cancel_ref, timeout=30)
+                        except ray.exceptions.RayChannelError:
+                            # Channel error happens when a channel is closed
+                            # or timed out. In this case, do not log.
+                            pass
                         except Exception:
                             logger.exception("Error cancelling worker task")
                             pass
@@ -1724,6 +1748,9 @@ def teardown(self, wait: bool):
                     self.wait_teardown()
                     logger.info("Teardown complete")

+                with self.in_teardown_lock:
+                    self._teardown_done = True
+
             def run(self):
                 try:
                     ray.get(list(outer.worker_task_refs.values()))
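With the lock, the existing in_teardown flag, and the new _teardown_done flag, teardown becomes idempotent: it can be invoked from __del__, from ray.shutdown(), and from the monitor thread itself without running twice. A condensed sketch of the pattern in a hypothetical class, not the Monitor code itself:

import threading


class TeardownOnce:
    def __init__(self):
        self._lock = threading.Lock()
        self._in_teardown = False
        self._done = False

    def teardown(self, wait: bool) -> None:
        with self._lock:
            if self._done:
                return  # a previous call already completed teardown
            do_teardown = not self._in_teardown
            self._in_teardown = True

        if do_teardown:
            ...  # cancel worker tasks and close channels (elided)

        if wait:
            ...  # block until all worker tasks have exited (elided)

        with self._lock:
            self._done = True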
@@ -2140,11 +2167,7 @@ def teardown(self):
     def __del__(self):
         monitor = getattr(self, "_monitor", None)
         if monitor is not None:
-            # Teardown asynchronously.
-            # NOTE(swang): Somehow, this can get called after the CoreWorker
-            # has already been destructed, so it is not safe to block in
-            # ray.get.
-            monitor.teardown(wait=False)
+            monitor.teardown(wait=True)


 @DeveloperAPI
@@ -2173,4 +2196,6 @@ def _build_compiled_dag(node):
     root = dag._find_root()
     root.traverse_and_apply(_build_compiled_dag)
     compiled_dag._get_or_compile()
+    global _compiled_dags
+    _compiled_dags[compiled_dag.get_id()] = compiled_dag
     return compiled_dag