[serve] raise RequestCancelledError when request is cancelled during assignment (ray-project#48496)

## Why are these changes needed?

Currently `ray.serve.exceptions.RequestCancelledError` is raised only
when a request is cancelled during execution. We should also raise
`RequestCancelledError` when a request is cancelled during assignment,
i.e. while it is still waiting to be assigned to a replica.
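
For context, a minimal caller-side sketch of the intended behavior (illustrative only, not code from this PR; the `Blocker` deployment, its `max_ongoing_requests=1` setting, and the sleep are assumptions used to keep a second request stuck in assignment):

```python
# Hypothetical sketch of the caller-side behavior targeted by this change.
import asyncio

from ray import serve
from ray.serve.exceptions import RequestCancelledError


@serve.deployment(max_ongoing_requests=1)  # single slot, so a second request queues
class Blocker:
    async def __call__(self) -> str:
        await asyncio.sleep(60)  # keep the only slot busy
        return "done"


handle = serve.run(Blocker.bind())

first = handle.remote()   # occupies the replica's only slot
second = handle.remote()  # stays in the assignment phase (no replica yet)
second.cancel()

try:
    second.result()
except RequestCancelledError:
    # Before this change, cancellation during assignment surfaced as
    # concurrent.futures.CancelledError instead.
    print("second request was cancelled during assignment")
```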

---------

Signed-off-by: Cindy Zhang <[email protected]>
Authored by zcin on Nov 4, 2024 · commit d5d03e6 (1 parent: 2cdf967)
Showing 2 changed files with 46 additions and 30 deletions.
python/ray/serve/handle.py (44 additions & 26 deletions)
@@ -30,7 +30,7 @@
     inside_ray_client_context,
     is_running_in_asyncio_loop,
 )
-from ray.serve.exceptions import RayServeException
+from ray.serve.exceptions import RayServeException, RequestCancelledError
 from ray.util import metrics
 from ray.util.annotations import DeveloperAPI, PublicAPI
 
@@ -193,29 +193,16 @@ def _options(self, _prefer_local_routing=DEFAULT.VALUE, **kwargs):
         )
 
     def _remote(
-        self, args: Tuple[Any], kwargs: Dict[str, Any]
+        self,
+        request_metadata: RequestMetadata,
+        args: Tuple[Any],
+        kwargs: Dict[str, Any],
     ) -> concurrent.futures.Future:
         self._record_telemetry_if_needed()
-        _request_context = ray.serve.context._serve_request_context.get()
-        request_metadata = RequestMetadata(
-            request_id=_request_context.request_id
-            if _request_context.request_id
-            else generate_request_id(),
-            internal_request_id=_request_context._internal_request_id
-            if _request_context._internal_request_id
-            else generate_request_id(),
-            call_method=self.handle_options.method_name,
-            route=_request_context.route,
-            app_name=self.app_name,
-            multiplexed_model_id=self.handle_options.multiplexed_model_id,
-            is_streaming=self.handle_options.stream,
-            _request_protocol=self.handle_options._request_protocol,
-            grpc_context=_request_context.grpc_context,
-        )
         self.request_counter.inc(
             tags={
-                "route": _request_context.route,
-                "application": _request_context.app_name,
+                "route": request_metadata.route,
+                "application": request_metadata.app_name,
             }
         )
 
@@ -249,10 +236,19 @@ def __reduce__(self):
 
 
 class _DeploymentResponseBase:
-    def __init__(self, replica_result_future: concurrent.futures.Future[ReplicaResult]):
+    def __init__(
+        self,
+        replica_result_future: concurrent.futures.Future[ReplicaResult],
+        request_metadata: RequestMetadata,
+    ):
         self._cancelled = False
         self._replica_result_future = replica_result_future
         self._replica_result: Optional[ReplicaResult] = None
+        self._request_metadata: RequestMetadata = request_metadata
+
+    @property
+    def request_id(self) -> str:
+        return self._request_metadata.request_id
 
     def _fetch_future_result_sync(
         self, _timeout_s: Optional[float] = None
@@ -269,6 +265,8 @@ def _fetch_future_result_sync(
                 )
             except concurrent.futures.TimeoutError:
                 raise TimeoutError("Timed out resolving to ObjectRef.") from None
+            except concurrent.futures.CancelledError:
+                raise RequestCancelledError(self.request_id) from None
 
         return self._replica_result
 
@@ -281,9 +279,12 @@ async def _fetch_future_result_async(self) -> ReplicaResult:
         if self._replica_result is None:
             # Use `asyncio.wrap_future` so `self._replica_result_future` can be awaited
             # safely from any asyncio loop.
-            self._replica_result = await asyncio.wrap_future(
-                self._replica_result_future
-            )
+            try:
+                self._replica_result = await asyncio.wrap_future(
+                    self._replica_result_future
+                )
+            except asyncio.CancelledError:
+                raise RequestCancelledError(self.request_id) from None
 
         return self._replica_result
 
@@ -726,10 +727,27 @@ def remote(
             **kwargs: Keyword arguments to be serialized and passed to the
                 remote method call.
         """
-        future = self._remote(args, kwargs)
+        _request_context = ray.serve.context._serve_request_context.get()
+        request_metadata = RequestMetadata(
+            request_id=_request_context.request_id
+            if _request_context.request_id
+            else generate_request_id(),
+            internal_request_id=_request_context._internal_request_id
+            if _request_context._internal_request_id
+            else generate_request_id(),
+            call_method=self.handle_options.method_name,
+            route=_request_context.route,
+            app_name=self.app_name,
+            multiplexed_model_id=self.handle_options.multiplexed_model_id,
+            is_streaming=self.handle_options.stream,
+            _request_protocol=self.handle_options._request_protocol,
+            grpc_context=_request_context.grpc_context,
+        )
+
+        future = self._remote(request_metadata, args, kwargs)
         if self.handle_options.stream:
            response_cls = DeploymentResponseGenerator
         else:
            response_cls = DeploymentResponse
 
-        return response_cls(future)
+        return response_cls(future, request_metadata)
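
Illustrative usage sketch (not part of the diff): with the changes above, awaiting a response that was cancelled before being assigned raises `RequestCancelledError` instead of `asyncio.CancelledError`, and the new `request_id` property exposes the ID from the `RequestMetadata` now attached to every response. The helper function and its call pattern here are assumptions:

```python
# Hypothetical helper; `handle` is a ray.serve DeploymentHandle obtained elsewhere.
from ray.serve.exceptions import RequestCancelledError
from ray.serve.handle import DeploymentHandle


async def call_then_cancel(handle: DeploymentHandle) -> None:
    response = handle.remote()
    response.cancel()
    try:
        await response
    except RequestCancelledError:
        # request_id comes from the RequestMetadata that remote() now
        # constructs and passes through to the response object.
        print(f"cancelled request {response.request_id}")
```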
python/ray/serve/tests/test_handle_cancellation.py (2 additions & 4 deletions)
@@ -1,5 +1,3 @@
-import asyncio
-import concurrent.futures
 import sys
 
 import pytest
@@ -65,7 +63,7 @@ async def __call__(self, *args):
     # Make a second request, cancel it, and verify that it is cancelled.
     second_response = h.remote()
     second_response.cancel()
-    with pytest.raises(concurrent.futures.CancelledError):
+    with pytest.raises(RequestCancelledError):
         second_response.result()
 
     # Now signal the initial request to finish and check that the second request
@@ -141,7 +139,7 @@ async def one_waiter():
        # Make a second request, cancel it, and verify that it is cancelled.
        second_response = self._h.remote()
        second_response.cancel()
-       with pytest.raises(asyncio.CancelledError):
+       with pytest.raises(RequestCancelledError):
            await second_response
 
        # Now signal the initial request to finish and check that the second request
