diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 3f6ab66feffd..c755ad3f3aa8 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -346,6 +346,7 @@ def _wrap_user_method_call(self, request_metadata: RequestMetadata): ray.serve.context._RequestContext( route=request_metadata.route, request_id=request_metadata.request_id, + _internal_request_id=request_metadata.internal_request_id, app_name=self._deployment_id.app_name, multiplexed_model_id=request_metadata.multiplexed_model_id, grpc_context=request_metadata.grpc_context, @@ -359,6 +360,15 @@ def _wrap_user_method_call(self, request_metadata: RequestMetadata): yield except asyncio.CancelledError as e: user_exception = e + + # Recursively cancel child requests + requests_pending_assignment = ( + ray.serve.context._get_requests_pending_assignment( + request_metadata.internal_request_id + ) + ) + for task in requests_pending_assignment.values(): + task.cancel() except Exception as e: user_exception = e logger.error(f"Request failed:\n{e}") diff --git a/python/ray/serve/_private/replica_result.py b/python/ray/serve/_private/replica_result.py index 351862fafd1f..5a20de9fed41 100644 --- a/python/ray/serve/_private/replica_result.py +++ b/python/ray/serve/_private/replica_result.py @@ -25,7 +25,7 @@ async def __anext__(self): raise NotImplementedError @abstractmethod - def add_callback(self, callback: Callable): + def add_done_callback(self, callback: Callable): raise NotImplementedError @abstractmethod @@ -133,7 +133,7 @@ async def __anext__(self): next_obj_ref = await self._obj_ref_gen.__anext__() return await next_obj_ref - def add_callback(self, callback: Callable): + def add_done_callback(self, callback: Callable): if self._obj_ref_gen is not None: self._obj_ref_gen.completed()._on_completed(callback) else: diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 1aae3dbaac28..829777107643 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -2,6 +2,7 @@ import logging import threading import time +import uuid from collections import defaultdict from contextlib import contextmanager from functools import partial @@ -500,7 +501,11 @@ async def _resolve_deployment_responses( return new_args, new_kwargs def _process_finished_request( - self, replica_id: ReplicaID, result: Union[Any, RayError] + self, + replica_id: ReplicaID, + parent_request_id: str, + response_id: str, + result: Union[Any, RayError], ): self._metrics_manager.dec_num_running_requests_for_replica(replica_id) if isinstance(result, ActorDiedError): @@ -594,6 +599,17 @@ async def assign_request( ) -> ReplicaResult: """Assign a request to a replica and return the resulting object_ref.""" + response_id = uuid.uuid4() + assign_request_task = asyncio.current_task() + ray.serve.context._add_request_pending_assignment( + request_meta.internal_request_id, response_id, assign_request_task + ) + assign_request_task.add_done_callback( + lambda _: ray.serve.context._remove_request_pending_assignment( + request_meta.internal_request_id, response_id + ) + ) + with self._metrics_manager.wrap_request_assignment(request_meta): # Optimization: if there are currently zero replicas for a deployment, # push handle metric to controller to allow for fast cold start time. @@ -602,12 +618,12 @@ async def assign_request( ): self._metrics_manager.push_autoscaling_metrics_to_controller() - ref = None + replica_result = None try: request_args, request_kwargs = await self._resolve_deployment_responses( request_args, request_kwargs ) - ref, replica_id = await self.schedule_and_send_request( + replica_result, replica_id = await self.schedule_and_send_request( PendingRequest( args=list(request_args), kwargs=request_kwargs, @@ -617,19 +633,26 @@ async def assign_request( # Keep track of requests that have been sent out to replicas if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: + _request_context = ray.serve.context._serve_request_context.get() + request_id: str = _request_context.request_id self._metrics_manager.inc_num_running_requests_for_replica( replica_id ) - callback = partial(self._process_finished_request, replica_id) - ref.add_callback(callback) + callback = partial( + self._process_finished_request, + replica_id, + request_id, + response_id, + ) + replica_result.add_done_callback(callback) - return ref + return replica_result except asyncio.CancelledError: # NOTE(edoakes): this is not strictly necessary because # there are currently no `await` statements between # getting the ref and returning, but I'm adding it defensively. - if ref is not None: - ref.cancel() + if replica_result is not None: + replica_result.cancel() raise diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index 32b56f8ffce1..e87e052e4200 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -3,10 +3,12 @@ can use this state to access metadata or the Serve controller. """ +import asyncio import contextvars import logging +from collections import defaultdict from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable, Dict, Optional import ray from ray.exceptions import RayActorError @@ -205,3 +207,42 @@ def _set_request_context( or current_request_context.multiplexed_model_id, ) ) + + +# `_requests_pending_assignment` is a map from request ID to a +# dictionary of asyncio tasks. +# The request ID points to an ongoing request that is executing on the +# current replica, and the asyncio tasks are ongoing tasks started on +# the router to assign child requests to downstream replicas. + +# A dictionary is used over a set to track the asyncio tasks for more +# efficient addition and deletion time complexity. A uniquely generated +# `response_id` is used to identify each task. + +_requests_pending_assignment: Dict[str, Dict[str, asyncio.Task]] = defaultdict(dict) + + +# Note that the functions below that manipulate +# `_requests_pending_assignment` are NOT thread-safe. They are only +# expected to be called from the same thread/asyncio event-loop. + + +def _get_requests_pending_assignment(parent_request_id: str) -> Dict[str, asyncio.Task]: + if parent_request_id in _requests_pending_assignment: + return _requests_pending_assignment[parent_request_id] + + return {} + + +def _add_request_pending_assignment(parent_request_id: str, response_id: str, task): + # NOTE: `parent_request_id` is the `internal_request_id` corresponding + # to an ongoing Serve request, so it is always non-empty. + _requests_pending_assignment[parent_request_id][response_id] = task + + +def _remove_request_pending_assignment(parent_request_id: str, response_id: str): + if response_id in _requests_pending_assignment[parent_request_id]: + del _requests_pending_assignment[parent_request_id][response_id] + + if len(_requests_pending_assignment[parent_request_id]) == 0: + del _requests_pending_assignment[parent_request_id] diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 944adbd99a3b..2f43591c9047 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -630,12 +630,6 @@ async def __call__(self, limit: int) -> AsyncIterator[int]: `DeploymentHandle` call. """ - def __init__( - self, - object_ref_future: concurrent.futures.Future, - ): - super().__init__(object_ref_future) - def __await__(self): raise TypeError( "`DeploymentResponseGenerator` cannot be awaited directly. Use `async for` " diff --git a/python/ray/serve/tests/test_cancellation.py b/python/ray/serve/tests/test_cancellation.py index aec89913eda9..e2786c6ee672 100644 --- a/python/ray/serve/tests/test_cancellation.py +++ b/python/ray/serve/tests/test_cancellation.py @@ -14,7 +14,7 @@ async_wait_for_condition, wait_for_condition, ) -from ray.serve._private.test_utils import send_signal_on_cancellation +from ray.serve._private.test_utils import send_signal_on_cancellation, tlog @pytest.mark.parametrize("use_fastapi", [False, True]) @@ -365,5 +365,104 @@ async def get_out_of_band_response(self): assert h.get_out_of_band_response.remote().result() == "ok" +def test_recursive_cancellation_during_execution(serve_instance): + inner_signal_actor = SignalActor.remote() + outer_signal_actor = SignalActor.remote() + + @serve.deployment + async def inner(): + await send_signal_on_cancellation(inner_signal_actor) + + @serve.deployment + class Ingress: + def __init__(self, handle): + self._handle = handle + + async def __call__(self): + _ = self._handle.remote() + await send_signal_on_cancellation(outer_signal_actor) + + h = serve.run(Ingress.bind(inner.bind())) + + resp = h.remote() + with pytest.raises(TimeoutError): + resp.result(timeout_s=0.5) + + resp.cancel() + ray.get(inner_signal_actor.wait.remote(), timeout=10) + ray.get(outer_signal_actor.wait.remote(), timeout=10) + + +def test_recursive_cancellation_during_assignment(serve_instance): + signal = SignalActor.remote() + + @serve.deployment(max_ongoing_requests=1) + class Counter: + def __init__(self): + self._count = 0 + + async def __call__(self): + self._count += 1 + await signal.wait.remote() + + def get_count(self): + return self._count + + @serve.deployment + class Ingress: + def __init__(self, handle): + self._handle = handle + + async def __call__(self): + self._handle.remote() + await signal.wait.remote() + return "hi" + + async def get_count(self): + return await self._handle.get_count.remote() + + async def check_requests_pending_assignment_cache(self): + requests_pending_assignment = ray.serve.context._requests_pending_assignment + return {k: list(v.keys()) for k, v in requests_pending_assignment.items()} + + h = serve.run(Ingress.bind(Counter.bind())) + + # Send two requests to Ingress. The second should be queued and + # pending assignment at Ingress because max ongoing requests for + # Counter is only 1. + tlog("Sending two requests to Ingress.") + resp1 = h.remote() + with pytest.raises(TimeoutError): + resp1.result(timeout_s=0.5) + resp2 = h.remote() + with pytest.raises(TimeoutError): + resp2.result(timeout_s=0.5) + + # Cancel second request, which should be pending assignment. + tlog("Canceling second request.") + resp2.cancel() + + # Release signal so that the first request can complete, and any new + # requests to Counter can be let through + tlog("Releasing signal.") + ray.get(signal.send.remote()) + assert resp1.result() == "hi" + + # The second request, even though it was pending assignment to a + # Counter replica, should have been properly canceled. Confirm this + # by making sure that no more calls to __call__ were made + for _ in range(10): + assert h.get_count.remote().result() == 1 + + tlog("Confirmed second request was properly canceled.") + + # Check that cache was cleared so there are no memory leaks + requests_pending_assignment = ( + h.check_requests_pending_assignment_cache.remote().result() + ) + for k, v in requests_pending_assignment.items(): + assert len(v) == 0, f"Request {k} has in flight requests in cache: {v}" + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_router.py b/python/ray/serve/tests/unit/test_router.py index 3870ac73f046..71a43f8b468b 100644 --- a/python/ray/serve/tests/unit/test_router.py +++ b/python/ray/serve/tests/unit/test_router.py @@ -51,7 +51,7 @@ def __next__(self): async def __anext__(self): raise NotImplementedError - def add_callback(self, callback: Callable): + def add_done_callback(self, callback: Callable): pass def cancel(self):