[core][compiled graphs] Fix test_torch_tensor_dag_gpu CI failure (ray-project#48204)

## Why are these changes needed?

This PR fixes test_torch_tensor_dag_gpu with the following quick patches:
1. Revert ray-project#47702; otherwise there is a segfault.
2. Move TestNcclGroup to be an inner class inside the tests; otherwise the following error occurs (a toy sketch of why this helps follows the traceback):

```
(TorchTensorWorker pid=2261373) No module named 'test_torch_tensor_dag'
(TorchTensorWorker pid=2261373) Traceback (most recent call last):
(TorchTensorWorker pid=2261373)   File "/home/ubuntu/ray/python/ray/_private/serialization.py", line 460, in deserialize_objects
(TorchTensorWorker pid=2261373)     obj = self._deserialize_object(data, metadata, object_ref)
(TorchTensorWorker pid=2261373)   File "/home/ubuntu/ray/python/ray/_private/serialization.py", line 317, in _deserialize_object
(TorchTensorWorker pid=2261373)     return self._deserialize_msgpack_data(data, metadata_fields)
(TorchTensorWorker pid=2261373)   File "/home/ubuntu/ray/python/ray/_private/serialization.py", line 272, in _deserialize_msgpack_data
(TorchTensorWorker pid=2261373)     python_objects = self._deserialize_pickle5_data(pickle5_data)
(TorchTensorWorker pid=2261373)   File "/home/ubuntu/ray/python/ray/_private/serialization.py", line 262, in _deserialize_pickle5_data
(TorchTensorWorker pid=2261373)     obj = pickle.loads(in_band)
(TorchTensorWorker pid=2261373) ModuleNotFoundError: No module named 'test_torch_tensor_dag'
```
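
The traceback shows the worker failing to import the test module while unpickling the NCCL group class: a module-level class is pickled by reference, so the remote worker would have to be able to import `test_torch_tensor_dag`. Defining the class inside the test makes cloudpickle ship the class definition itself. A toy sketch of the mechanism (assumed names, not the actual test code):

```
# Toy illustration: why a class defined inside the test avoids the
# ModuleNotFoundError above. cloudpickle serializes a module-level class by
# reference (module + qualified name), so the deserializing worker must be
# able to import that module; a class defined inside a function is serialized
# by value, so the worker needs no access to the defining test module.
from ray import cloudpickle


def make_inner_group():
    class TestNcclGroup:  # defined inside a function -> pickled by value
        def name(self) -> str:
            return "inner nccl group"

    return TestNcclGroup


InnerGroup = make_inner_group()
payload = cloudpickle.dumps(InnerGroup)  # carries the full class definition
Restored = cloudpickle.loads(payload)    # no test module import required
assert Restored().name() == "inner nccl group"
```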

Signed-off-by: JP-sDEV <[email protected]>
ruisearch42 authored and JP-sDEV committed Nov 14, 2024
1 parent f11dadd commit b7b5a32
Showing 10 changed files with 362 additions and 195 deletions.
19 changes: 19 additions & 0 deletions python/ray/_private/ray_experimental_perf.py
@@ -181,6 +181,7 @@ async def _exec_async():
results += timeit(
"[unstable] compiled single-actor DAG calls", lambda: _exec(compiled_dag)
)
compiled_dag.teardown()
del a

# Single-actor asyncio DAG calls
@@ -194,6 +195,10 @@ async def _exec_async():
"[unstable] compiled single-actor asyncio DAG calls",
)
)
# TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
# these DAGs create a background thread that can segfault if the CoreWorker
# is torn down first.
compiled_dag.teardown()
del a

# Scatter-gather DAG calls
@@ -211,6 +216,7 @@ async def _exec_async():
f"[unstable] compiled scatter-gather DAG calls, n={n_cpu} actors",
lambda: _exec(compiled_dag),
)
compiled_dag.teardown()

# Scatter-gather asyncio DAG calls

@@ -223,6 +229,10 @@ async def _exec_async():
f"[unstable] compiled scatter-gather asyncio DAG calls, n={n_cpu} actors",
)
)
# TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
# these DAGs create a background thread that can segfault if the CoreWorker
# is torn down first.
compiled_dag.teardown()

# Chain DAG calls

@@ -240,6 +250,7 @@ async def _exec_async():
f"[unstable] compiled chain DAG calls, n={n_cpu} actors",
lambda: _exec(compiled_dag),
)
compiled_dag.teardown()

# Chain asyncio DAG calls

@@ -252,6 +263,10 @@ async def _exec_async():
results += loop.run_until_complete(
exec_async(f"[unstable] compiled chain asyncio DAG calls, n={n_cpu} actors")
)
# TODO: Need to explicitly tear down DAGs with enable_asyncio=True because
# these DAGs create a background thread that can segfault if the CoreWorker
# is torn down first.
compiled_dag.teardown()

# Multiple args with small payloads

@@ -274,6 +289,7 @@ async def _exec_async():
f"n={n_actors} actors",
lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
)
compiled_dag.teardown()

# Multiple args with medium payloads

@@ -291,6 +307,7 @@ async def _exec_async():
f"n={n_actors} actors",
lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
)
compiled_dag.teardown()

# Multiple args with large payloads

@@ -308,6 +325,7 @@ async def _exec_async():
f"n={n_actors} actors",
lambda: _exec(compiled_dag, num_args=n_actors, payload_size=payload_size),
)
compiled_dag.teardown()

# Worst case for multiple arguments: a single actor takes all the arguments
# with small payloads.
@@ -328,6 +346,7 @@ async def _exec_async():
"n=1 actors",
lambda: _exec(compiled_dag, num_args=n_args, payload_size=payload_size),
)
compiled_dag.teardown()

ray.shutdown()
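
Taken together, the hunks above apply one pattern: each compiled DAG in the microbenchmark is torn down explicitly, and DAGs compiled with `enable_asyncio=True` are torn down before `ray.shutdown()` because their background thread can segfault if the CoreWorker goes away first. A minimal sketch of that pattern, assuming a trivial actor (`EchoActor` is illustrative, not one of the benchmark's actors):

```
# Minimal sketch of the explicit-teardown pattern; EchoActor and the DAG shape
# are illustrative, not the benchmark's actual workload.
import ray
from ray.dag import InputNode


@ray.remote
class EchoActor:
    def echo(self, x):
        return x


ray.init()
a = EchoActor.remote()
b = EchoActor.remote()

# Synchronous compiled DAG: tear it down explicitly when done.
with InputNode() as inp:
    dag = a.echo.bind(inp)
compiled_dag = dag.experimental_compile()
assert ray.get(compiled_dag.execute(1)) == 1
compiled_dag.teardown()

# Asyncio compiled DAG: tear it down before ray.shutdown(), since its
# background thread can segfault if the CoreWorker is destroyed first.
with InputNode() as inp:
    async_dag = b.echo.bind(inp)
compiled_async_dag = async_dag.experimental_compile(enable_asyncio=True)
# ... drive it with `await compiled_async_dag.execute_async(...)` inside an event loop ...
compiled_async_dag.teardown()

ray.shutdown()
```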

5 changes: 0 additions & 5 deletions python/ray/_private/worker.py
@@ -1874,11 +1874,6 @@ def shutdown(_exiting_interpreter: bool = False):
and false otherwise. If we are exiting the interpreter, we will
wait a little while to print any extra error messages.
"""
# Make sure to clean up compiled dag node if exists.
from ray.dag.compiled_dag_node import _shutdown_all_compiled_dags

_shutdown_all_compiled_dags()

if _exiting_interpreter and global_worker.mode == SCRIPT_MODE:
# This is a duration to sleep before shutting down everything in order
# to make sure that log messages finish printing.
41 changes: 8 additions & 33 deletions python/ray/dag/compiled_dag_node.py
@@ -1,4 +1,3 @@
import weakref
import asyncio
from collections import defaultdict
from dataclasses import dataclass, asdict
@@ -9,7 +8,6 @@
import uuid
import traceback

import ray.exceptions
from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
@@ -54,21 +52,6 @@

logger = logging.getLogger(__name__)

# Keep tracking of every compiled dag created during the lifetime of
# this process. It tracks them as weakref meaning when the compiled dag
# is GC'ed, it is automatically removed from here. It is used to teardown
# compiled dags at interpret shutdown time.
_compiled_dags = weakref.WeakValueDictionary()


# Relying on __del__ doesn't work well upon shutdown because
# the destructor order is not guaranteed. We call this function
# upon `ray.worker.shutdown` which is registered to atexit handler
# so that teardown is properly called before objects are destructed.
def _shutdown_all_compiled_dags():
for _, compiled_dag in _compiled_dags.items():
compiled_dag.teardown()


@DeveloperAPI
def do_allocate_channel(
@@ -1665,7 +1648,7 @@ def _is_same_actor(idx1: int, idx2: int) -> bool:
return False

def _monitor_failures(self):
outer = weakref.proxy(self)
outer = self

class Monitor(threading.Thread):
def __init__(self):
@@ -1674,8 +1657,6 @@ def __init__(self):
# Lock to make sure that we only perform teardown for this DAG
# once.
self.in_teardown_lock = threading.Lock()
self.name = "CompiledGraphMonitorThread"
self._teardown_done = False

def wait_teardown(self):
for actor, ref in outer.worker_task_refs.items():
@@ -1705,9 +1686,6 @@ def wait_teardown(self):
def teardown(self, wait: bool):
do_teardown = False
with self.in_teardown_lock:
if self._teardown_done:
return

if not self.in_teardown:
do_teardown = True
self.in_teardown = True
@@ -1731,11 +1709,9 @@ def teardown(self, wait: bool):
]
for cancel_ref in cancel_refs:
try:
# TODO(swang): Suppress exceptions from actors trying to
# read closed channels when DAG is being torn down.
ray.get(cancel_ref, timeout=30)
except ray.exceptions.RayChannelError:
# Channel error happens when a channel is closed
# or timed out. In this case, do not log.
pass
except Exception:
logger.exception("Error cancelling worker task")
pass
@@ -1748,9 +1724,6 @@ def teardown(self, wait: bool):
self.wait_teardown()
logger.info("Teardown complete")

with self.in_teardown_lock:
self._teardown_done = True

def run(self):
try:
ray.get(list(outer.worker_task_refs.values()))
@@ -2038,7 +2011,11 @@ def teardown(self):
def __del__(self):
monitor = getattr(self, "_monitor", None)
if monitor is not None:
monitor.teardown(wait=True)
# Teardown asynchronously.
# NOTE(swang): Somehow, this can get called after the CoreWorker
# has already been destructed, so it is not safe to block in
# ray.get.
monitor.teardown(wait=False)


@DeveloperAPI
@@ -2067,6 +2044,4 @@ def _build_compiled_dag(node):
root = dag._find_root()
root.traverse_and_apply(_build_compiled_dag)
compiled_dag._get_or_compile()
global _compiled_dags
_compiled_dags[compiled_dag.get_id()] = compiled_dag
return compiled_dag
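
Among the reverted pieces above, the monitor thread goes back to holding a strong reference (`outer = self`) instead of `weakref.proxy(self)`. A proxy does not keep its target alive, so the thread can hit a `ReferenceError` once the compiled DAG is garbage collected, while a strong reference keeps the DAG alive until the monitor releases it, at the cost of relying on an explicit `teardown()`. A toy illustration of the difference (not Ray code):

```
import weakref


class Target:
    def ping(self) -> str:
        return "alive"


t = Target()
proxy = weakref.proxy(t)   # does not keep Target alive
print(proxy.ping())        # works while a strong reference still exists

del t                      # drop the only strong reference
try:
    proxy.ping()           # the proxied object is gone
except ReferenceError:
    print("proxy target was collected")

# Holding a strong reference (outer = self) instead keeps the object alive
# until the thread itself releases it.
```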