Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "[core][compiled graphs] Fix test_torch_tensor_dag_gpu… #48275

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/ray/_private/ray_experimental_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ async def _exec_async():
results += timeit(
"[unstable] compiled single-actor DAG calls", lambda: _exec(compiled_dag)
)
compiled_dag.teardown()
del a

# Single-actor asyncio DAG calls
Expand All @@ -193,6 +194,10 @@ async def _exec_async():
"[unstable] compiled single-actor asyncio DAG calls",
)
)
# TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
# these DAGs create a background thread that can segfault if the CoreWorker
# is torn down first.
compiled_dag.teardown()
del a

# Scatter-gather DAG calls
Expand All @@ -210,6 +215,7 @@ async def _exec_async():
f"[unstable] compiled scatter-gather DAG calls, n={n_cpu} actors",
lambda: _exec(compiled_dag),
)
compiled_dag.teardown()

# Scatter-gather asyncio DAG calls

Expand All @@ -222,6 +228,10 @@ async def _exec_async():
f"[unstable] compiled scatter-gather asyncio DAG calls, n={n_cpu} actors",
)
)
# TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
# these DAGs create a background thread that can segfault if the CoreWorker
# is torn down first.
compiled_dag.teardown()

# Chain DAG calls

Expand All @@ -239,6 +249,7 @@ async def _exec_async():
f"[unstable] compiled chain DAG calls, n={n_cpu} actors",
lambda: _exec(compiled_dag),
)
compiled_dag.teardown()

# Chain asyncio DAG calls

Expand All @@ -251,6 +262,10 @@ async def _exec_async():
results += loop.run_until_complete(
exec_async(f"[unstable] compiled chain asyncio DAG calls, n={n_cpu} actors")
)
# TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
# these DAGs create a background thread that can segfault if the CoreWorker
# is torn down first.
compiled_dag.teardown()

# Multiple args with small payloads

Expand All @@ -273,6 +288,7 @@ async def _exec_async():
f"n={n_actors} actors",
lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
)
compiled_dag.teardown()

# Multiple args with medium payloads

Expand All @@ -290,6 +306,7 @@ async def _exec_async():
f"n={n_actors} actors",
lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
)
compiled_dag.teardown()

# Multiple args with large payloads

Expand All @@ -307,6 +324,7 @@ async def _exec_async():
f"n={n_actors} actors",
lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
)
compiled_dag.teardown()

# Worst case for multiple arguments: a single actor takes all the arguments
# with small payloads.
Expand All @@ -327,6 +345,7 @@ async def _exec_async():
"n=1 actors",
lambda: _exec(compiled_dag, num_args=n_args, payload_size=payload_size),
)
compiled_dag.teardown()

ray.shutdown()

Expand Down
5 changes: 0 additions & 5 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,11 +1874,6 @@ def shutdown(_exiting_interpreter: bool = False):
and false otherwise. If we are exiting the interpreter, we will
wait a little while to print any extra error messages.
"""
# Make sure to clean up compiled dag node if exists.
from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags

_shutdown_all_compiled_dags()

if _exiting_interpreter and global_worker.mode == SCRIPT_MODE:
# This is a duration to sleep before shutting down everything in order
# to make sure that log messages finish printing.
Expand Down
41 changes: 8 additions & 33 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import weakref
import asyncio
from collections import defaultdict
from dataclasses import dataclass, asdict
Expand All @@ -9,7 +8,6 @@
import uuid
import traceback

import ray.exceptions
from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
Expand Down Expand Up @@ -54,21 +52,6 @@

logger = logging.getLogger(__name__)

# Keep tracking of every compiled dag created during the lifetime of
# this process. It tracks them as weakref meaning when the compiled dag
# is GC'ed, it is automatically removed from here. It is used to teardown
# compiled dags at interpret shutdown time.
_compiled_dags = weakref.WeakValueDictionary()


# Relying on __del__ doesn't work well upon shutdown because
# the destructor order is not guaranteed. We call this function
# upon `ray.worker.shutdown` which is registered to atexit handler
# so that teardown is properly called before objects are destructed.
def _shutdown_all_compiled_dags():
for _, compiled_dag in _compiled_dags.items():
compiled_dag.teardown()


@DeveloperAPI
def do_allocate_channel(
Expand Down Expand Up @@ -1665,7 +1648,7 @@ def _is_same_actor(idx1: int, idx2: int) -> bool:
return False

def _monitor_failures(self):
outer = weakref.proxy(self)
outer = self

class Monitor(threading.Thread):
def __init__(self):
Expand All @@ -1674,8 +1657,6 @@ def __init__(self):
# Lock to make sure that we only perform teardown for this DAG
# once.
self.in_teardown_lock = threading.Lock()
self.name = "CompiledGraphMonitorThread"
self._teardown_done = False

def wait_teardown(self):
for actor, ref in outer.worker_task_refs.items():
Expand Down Expand Up @@ -1705,9 +1686,6 @@ def wait_teardown(self):
def teardown(self, wait: bool):
do_teardown = False
with self.in_teardown_lock:
if self._teardown_done:
return

if not self.in_teardown:
do_teardown = True
self.in_teardown = True
Expand All @@ -1731,11 +1709,9 @@ def teardown(self, wait: bool):
]
for cancel_ref in cancel_refs:
try:
# TODO(swang): Suppress exceptions from actors trying to
# read closed channels when DAG is being torn down.
ray.get(cancel_ref, timeout=30)
except ray.exceptions.RayChannelError:
# Channel error happens when a channel is closed
# or timed out. In this case, do not log.
pass
except Exception:
logger.exception("Error cancelling worker task")
pass
Expand All @@ -1748,9 +1724,6 @@ def teardown(self, wait: bool):
self.wait_teardown()
logger.info("Teardown complete")

with self.in_teardown_lock:
self._teardown_done = True

def run(self):
try:
ray.get(list(outer.worker_task_refs.values()))
Expand Down Expand Up @@ -2167,7 +2140,11 @@ def teardown(self):
def __del__(self):
monitor = getattr(self, "_monitor", None)
if monitor is not None:
monitor.teardown(wait=True)
# Teardown asynchronously.
# NOTE(swang): Somehow, this can get called after the CoreWorker
# has already been destructed, so it is not safe to block in
# ray.get.
monitor.teardown(wait=False)


@DeveloperAPI
Expand Down Expand Up @@ -2196,6 +2173,4 @@ def _build_compiled_dag(node):
root = dag._find_root()
root.traverse_and_apply(_build_compiled_dag)
compiled_dag._get_or_compile()
global _compiled_dags
_compiled_dags[compiled_dag.get_id()] = compiled_dag
return compiled_dag
Loading