Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] recursive cancellation #47873

Merged
merged 19 commits into from
Oct 29, 2024
Merged
10 changes: 10 additions & 0 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def _wrap_user_method_call(self, request_metadata: RequestMetadata):
ray.serve.context._RequestContext(
route=request_metadata.route,
request_id=request_metadata.request_id,
_internal_request_id=request_metadata.internal_request_id,
app_name=self._deployment_id.app_name,
multiplexed_model_id=request_metadata.multiplexed_model_id,
grpc_context=request_metadata.grpc_context,
Expand All @@ -359,6 +360,15 @@ def _wrap_user_method_call(self, request_metadata: RequestMetadata):
yield
except asyncio.CancelledError as e:
user_exception = e

# Recursively cancel child requests
requests_pending_assignment = (
ray.serve.context._get_requests_pending_assignment(
request_metadata.internal_request_id
)
)
for task in requests_pending_assignment.values():
task.cancel()
except Exception as e:
user_exception = e
logger.error(f"Request failed:\n{e}")
Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/_private/replica_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def __anext__(self):
raise NotImplementedError

@abstractmethod
def add_callback(self, callback: Callable):
def add_done_callback(self, callback: Callable):
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -133,7 +133,7 @@ async def __anext__(self):
next_obj_ref = await self._obj_ref_gen.__anext__()
return await next_obj_ref

def add_callback(self, callback: Callable):
def add_done_callback(self, callback: Callable):
if self._obj_ref_gen is not None:
self._obj_ref_gen.completed()._on_completed(callback)
else:
Expand Down
39 changes: 31 additions & 8 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import threading
import time
import uuid
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
Expand Down Expand Up @@ -500,7 +501,11 @@ async def _resolve_deployment_responses(
return new_args, new_kwargs

def _process_finished_request(
self, replica_id: ReplicaID, result: Union[Any, RayError]
self,
replica_id: ReplicaID,
parent_request_id: str,
response_id: str,
result: Union[Any, RayError],
):
self._metrics_manager.dec_num_running_requests_for_replica(replica_id)
if isinstance(result, ActorDiedError):
Expand Down Expand Up @@ -594,6 +599,17 @@ async def assign_request(
) -> ReplicaResult:
"""Assign a request to a replica and return the resulting object_ref."""

response_id = uuid.uuid4()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocker: there is already a `generate_request_id()` helper; maybe we can rename it and reuse it here?

assign_request_task = asyncio.current_task()
ray.serve.context._add_request_pending_assignment(
request_meta.internal_request_id, response_id, assign_request_task
)
assign_request_task.add_done_callback(
lambda _: ray.serve.context._remove_request_pending_assignment(
request_meta.internal_request_id, response_id
)
)

with self._metrics_manager.wrap_request_assignment(request_meta):
# Optimization: if there are currently zero replicas for a deployment,
# push handle metric to controller to allow for fast cold start time.
Expand All @@ -602,12 +618,12 @@ async def assign_request(
):
self._metrics_manager.push_autoscaling_metrics_to_controller()

ref = None
replica_result = None
try:
request_args, request_kwargs = await self._resolve_deployment_responses(
request_args, request_kwargs
)
ref, replica_id = await self.schedule_and_send_request(
replica_result, replica_id = await self.schedule_and_send_request(
PendingRequest(
args=list(request_args),
kwargs=request_kwargs,
Expand All @@ -617,19 +633,26 @@ async def assign_request(

# Keep track of requests that have been sent out to replicas
if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
_request_context = ray.serve.context._serve_request_context.get()
request_id: str = _request_context.request_id
self._metrics_manager.inc_num_running_requests_for_replica(
replica_id
)
callback = partial(self._process_finished_request, replica_id)
ref.add_callback(callback)
callback = partial(
self._process_finished_request,
replica_id,
request_id,
response_id,
)
replica_result.add_done_callback(callback)

return ref
return replica_result
except asyncio.CancelledError:
# NOTE(edoakes): this is not strictly necessary because
# there are currently no `await` statements between
# getting the ref and returning, but I'm adding it defensively.
if ref is not None:
ref.cancel()
if replica_result is not None:
replica_result.cancel()

raise

Expand Down
43 changes: 42 additions & 1 deletion python/ray/serve/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
can use this state to access metadata or the Serve controller.
"""

import asyncio
import contextvars
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Callable, Dict, Optional

import ray
from ray.exceptions import RayActorError
Expand Down Expand Up @@ -205,3 +207,42 @@ def _set_request_context(
or current_request_context.multiplexed_model_id,
)
)


# `_requests_pending_assignment` is a map from request ID to a
# dictionary of asyncio tasks.
# The request ID points to an ongoing request that is executing on the
# current replica, and the asyncio tasks are ongoing tasks started on
# the router to assign child requests to downstream replicas.

# A dictionary is used over a set to track the asyncio tasks for more
# efficient addition and deletion time complexity. A uniquely generated
# `response_id` is used to identify each task.

_requests_pending_assignment: Dict[str, Dict[str, asyncio.Task]] = defaultdict(dict)


# Note that the functions below that manipulate
# `_requests_pending_assignment` are NOT thread-safe. They are only
# expected to be called from the same thread/asyncio event-loop.


def _get_requests_pending_assignment(parent_request_id: str) -> Dict[str, asyncio.Task]:
    """Return the assignment tasks tracked for ``parent_request_id``.

    Args:
        parent_request_id: Key identifying the parent request in
            ``_requests_pending_assignment``.

    Returns:
        The live inner mapping of ``response_id`` to task when the parent
        request has tracked tasks (not a copy), otherwise a new empty dict.
    """
    # `.get` does not trigger the defaultdict factory, so looking up an
    # unknown request ID never inserts an empty entry into the registry.
    return _requests_pending_assignment.get(parent_request_id, {})


def _add_request_pending_assignment(
    parent_request_id: str, response_id: str, task: asyncio.Task
) -> None:
    """Record ``task`` under ``parent_request_id`` keyed by ``response_id``.

    Args:
        parent_request_id: Key identifying the parent request.
        response_id: Unique key for this task within the parent's entry.
            NOTE(review): the router hunk on this page passes the raw
            ``uuid.uuid4()`` object rather than a ``str`` -- confirm which
            type is intended.
        task: The asyncio task to track.
    """
    # NOTE: `parent_request_id` is the `internal_request_id` corresponding
    # to an ongoing Serve request, so it is always non-empty.
    _requests_pending_assignment[parent_request_id][response_id] = task


def _remove_request_pending_assignment(parent_request_id: str, response_id: str):
    """Stop tracking the task stored under ``response_id`` for a parent request.

    Unknown parent or response IDs are ignored. When the parent's last
    tracked task is removed, the parent's entry is deleted as well so the
    registry does not accumulate empty dicts.

    Args:
        parent_request_id: Key identifying the parent request.
        response_id: Key of the task to remove within the parent's entry.
    """
    # Use `.get` rather than indexing: indexing the defaultdict would create
    # an empty entry for an unknown parent only to delete it again below.
    tasks = _requests_pending_assignment.get(parent_request_id)
    if tasks is None:
        return

    tasks.pop(response_id, None)

    if not tasks:
        del _requests_pending_assignment[parent_request_id]
6 changes: 0 additions & 6 deletions python/ray/serve/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,12 +630,6 @@ async def __call__(self, limit: int) -> AsyncIterator[int]:
`DeploymentHandle` call.
"""

def __init__(
self,
object_ref_future: concurrent.futures.Future,
):
super().__init__(object_ref_future)

def __await__(self):
raise TypeError(
"`DeploymentResponseGenerator` cannot be awaited directly. Use `async for` "
Expand Down
101 changes: 100 additions & 1 deletion python/ray/serve/tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
async_wait_for_condition,
wait_for_condition,
)
from ray.serve._private.test_utils import send_signal_on_cancellation
from ray.serve._private.test_utils import send_signal_on_cancellation, tlog


@pytest.mark.parametrize("use_fastapi", [False, True])
Expand Down Expand Up @@ -365,5 +365,104 @@ async def get_out_of_band_response(self):
assert h.get_out_of_band_response.remote().result() == "ok"


def test_recursive_cancellation_during_execution(serve_instance):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order for these tests to be meaningful, don't we need to turn off the core cancellation support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So at first I took the approach of implementing full recursive cancellation at the Serve layer, but that leads to more of a performance drop. I think it's unnecessary to cover the part that core already covers (recursive task cancellation). We just need to take care of canceling the asyncio request assignment task before a request has been assigned to a replica.

So in this test file, `test_recursive_cancellation_during_execution` already passes, and `test_recursive_cancellation_during_assignment` is the one that doesn't pass without this PR.

inner_signal_actor = SignalActor.remote()
outer_signal_actor = SignalActor.remote()

@serve.deployment
async def inner():
await send_signal_on_cancellation(inner_signal_actor)

@serve.deployment
class Ingress:
def __init__(self, handle):
self._handle = handle

async def __call__(self):
_ = self._handle.remote()
await send_signal_on_cancellation(outer_signal_actor)

h = serve.run(Ingress.bind(inner.bind()))

resp = h.remote()
with pytest.raises(TimeoutError):
resp.result(timeout_s=0.5)

resp.cancel()
ray.get(inner_signal_actor.wait.remote(), timeout=10)
ray.get(outer_signal_actor.wait.remote(), timeout=10)


def test_recursive_cancellation_during_assignment(serve_instance):
    """Cancel a parent request whose child call is still pending assignment.

    `Counter` accepts one ongoing request at a time, so the child call made
    by the second `Ingress` request stays queued. After cancelling that
    second request, `Counter.__call__` must never run for it, and the
    pending-assignment cache on the `Ingress` replica must hold no tasks.
    """
    signal = SignalActor.remote()

    @serve.deployment(max_ongoing_requests=1)
    class Counter:
        def __init__(self):
            self._count = 0

        async def __call__(self):
            self._count += 1
            await signal.wait.remote()

        def get_count(self):
            return self._count

    @serve.deployment
    class Ingress:
        def __init__(self, handle):
            self._handle = handle

        async def __call__(self):
            self._handle.remote()
            await signal.wait.remote()
            return "hi"

        async def get_count(self):
            return await self._handle.get_count.remote()

        async def check_requests_pending_assignment_cache(self):
            # Return only the keys so the result contains no asyncio tasks.
            requests_pending_assignment = ray.serve.context._requests_pending_assignment
            return {k: list(v) for k, v in requests_pending_assignment.items()}

    h = serve.run(Ingress.bind(Counter.bind()))

    # Send two requests to Ingress. The second should be queued and
    # pending assignment at Ingress because max ongoing requests for
    # Counter is only 1.
    tlog("Sending two requests to Ingress.")
    resp1 = h.remote()
    with pytest.raises(TimeoutError):
        resp1.result(timeout_s=0.5)
    resp2 = h.remote()
    with pytest.raises(TimeoutError):
        resp2.result(timeout_s=0.5)

    # Cancel second request, which should be pending assignment.
    tlog("Canceling second request.")
    resp2.cancel()

    # Release signal so that the first request can complete, and any new
    # requests to Counter can be let through
    tlog("Releasing signal.")
    ray.get(signal.send.remote())
    assert resp1.result() == "hi"

    # The second request, even though it was pending assignment to a
    # Counter replica, should have been properly canceled. Confirm this
    # by making sure that no more calls to __call__ were made
    for _ in range(10):
        assert h.get_count.remote().result() == 1

    tlog("Confirmed second request was properly canceled.")

    # Check that cache was cleared so there are no memory leaks
    requests_pending_assignment = (
        h.check_requests_pending_assignment_cache.remote().result()
    )
    for k, v in requests_pending_assignment.items():
        assert not v, f"Request {k} has in flight requests in cache: {v}"


# Allow running this test module directly; the process exit code is
# whatever pytest returns.
if __name__ == "__main__":
    sys.exit(pytest.main(["-v", "-s", __file__]))
2 changes: 1 addition & 1 deletion python/ray/serve/tests/unit/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __next__(self):
async def __anext__(self):
raise NotImplementedError

def add_callback(self, callback: Callable):
def add_done_callback(self, callback: Callable):
pass

def cancel(self):
Expand Down