[serve] Initial version of local_testing_mode (ray-project#48477)
Adds initial support for a (currently private) `_local_testing_mode`
flag to `serve.run`. This flag is intended to let users write unit
tests for their application and model composition logic.

User code for each deployment runs in a background thread using the
same `UserCallableWrapper` that runs inside replica actors. New
`Router` and `ReplicaResult` implementations interact with the user
code so that the existing `DeploymentHandle` code works unchanged.
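
As a rough sketch of the intended unit-testing workflow (hedged: the
flag is private and its exact signature may change; the deployment
below is illustrative, not part of this diff):

    from ray import serve

    @serve.deployment
    class Doubler:
        def __call__(self, x: int) -> int:
            return 2 * x

    def test_doubler():
        handle = serve.run(Doubler.bind(), _local_testing_mode=True)
        assert handle.remote(2).result() == 4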

Before merging:

- [x] Figure out why `RAY_SERVE_FORCE_LOCAL_TESTING_MODE=1 pytest -vs tests/test_handle_cancellation.py` hangs locally

Follow-ups to this PR:

- Fix the blocking `.result()` call when resolving args for composition
- Make `get_replica_context()` and other auxiliary APIs work
- Support the FastAPI `TestClient`

---------

Signed-off-by: Edward Oakes <[email protected]>
edoakes authored Nov 5, 2024
1 parent 130cb3d commit 81cf6d8
Showing 19 changed files with 651 additions and 88 deletions.
2 changes: 2 additions & 0 deletions python/ray/serve/_private/common.py
@@ -558,6 +558,8 @@ class RequestMetadata:
    # internal_request_id is always generated by the proxy and is used for tracking
    # request objects. We can assume this is always unique between requests.
    internal_request_id: str

+   # Method of the user callable to execute.
+   call_method: str = "__call__"

    # HTTP route path of the request.
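
For context, `call_method` is what lets a `DeploymentHandle` invoke a method
other than `__call__` on a deployment; a minimal sketch of that existing
pattern (deployment and method names are illustrative):

    @serve.deployment
    class Model:
        def healthcheck(self) -> str:
            return "ok"

    # Given `handle`, a DeploymentHandle to this deployment, attribute
    # access sets call_method="healthcheck" in the request metadata.
    handle.healthcheck.remote()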
6 changes: 6 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -350,3 +350,9 @@
RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY = (
    os.environ.get("RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY", "0") == "1"
)
+
+# Feature flag to always override local_testing_mode to True in serve.run.
+# This is used for internal testing to avoid passing the flag to every invocation.
+RAY_SERVE_FORCE_LOCAL_TESTING_MODE = (
+    os.environ.get("RAY_SERVE_FORCE_LOCAL_TESTING_MODE", "0") == "1"
+)
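
A hedged sketch of how this flag presumably combines with the private
`serve.run` argument (an assumption: the actual wiring lives in a file not
shown in this excerpt, and `_local_testing_mode` here stands in for that
parameter):

    # If the env var is set, local testing mode is used regardless of the
    # value passed to serve.run (assumption based on the comment above).
    local_testing_mode = _local_testing_mode or RAY_SERVE_FORCE_LOCAL_TESTING_MODE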
313 changes: 313 additions & 0 deletions python/ray/serve/_private/local_testing_mode.py
@@ -0,0 +1,313 @@
import asyncio
import concurrent.futures
import inspect
import logging
import queue
import time
from functools import wraps
from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union

import ray
from ray import cloudpickle
from ray.serve._private.common import DeploymentID, RequestMetadata
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.replica import UserCallableWrapper
from ray.serve._private.replica_result import ReplicaResult
from ray.serve._private.router import Router
from ray.serve._private.utils import GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR
from ray.serve.deployment import Deployment
from ray.serve.exceptions import RequestCancelledError
from ray.serve.handle import (
    DeploymentHandle,
    DeploymentResponse,
    DeploymentResponseGenerator,
)

logger = logging.getLogger(SERVE_LOGGER_NAME)


def _validate_deployment_options(
    deployment: Deployment,
    deployment_id: DeploymentID,
):
    if "num_gpus" in deployment.ray_actor_options:
        logger.warning(
            f"Deployment {deployment_id} has num_gpus configured. "
            "CUDA_VISIBLE_DEVICES is not managed automatically in local testing mode."
        )

    if "runtime_env" in deployment.ray_actor_options:
        logger.warning(
            f"Deployment {deployment_id} has runtime_env configured. "
            "runtime_envs are ignored in local testing mode."
        )


def make_local_deployment_handle(
    deployment: Deployment,
    app_name: str,
) -> DeploymentHandle:
    """Constructs an in-process DeploymentHandle.

    This is used in the application build process for local testing mode,
    where all deployments of an app run in the local process which enables
    faster dev iterations and use of tooling like PDB.

    The user callable will be run on an asyncio loop in a separate thread
    (sharing the same code that's used in the replica).

    The constructor for the user callable is run eagerly in this function to
    ensure that any exceptions are raised during `serve.run`.
    """
    deployment_id = DeploymentID(deployment.name, app_name)
    _validate_deployment_options(deployment, deployment_id)
    user_callable_wrapper = UserCallableWrapper(
        deployment.func_or_class,
        deployment.init_args,
        deployment.init_kwargs,
        deployment_id=deployment_id,
    )
    try:
        logger.info(f"Initializing local replica class for {deployment_id}.")
        user_callable_wrapper.initialize_callable().result()
    except Exception:
        logger.exception(f"Failed to initialize deployment {deployment_id}.")
        raise

    def _create_local_router(
        handle_id: str, deployment_id: DeploymentID, handle_options: Any
    ) -> Router:
        return LocalRouter(
            user_callable_wrapper,
            deployment_id=deployment_id,
            handle_options=handle_options,
        )

    return DeploymentHandle(
        deployment.name,
        app_name,
        _create_router=_create_local_router,
    )


class LocalReplicaResult(ReplicaResult):
    """ReplicaResult used by in-process DeploymentHandles."""

    OBJ_REF_NOT_SUPPORTED_ERROR = RuntimeError(
        "Converting DeploymentResponses to ObjectRefs is not supported "
        "in local testing mode."
    )

    def __init__(
        self,
        future: concurrent.futures.Future,
        *,
        request_id: str,
        is_streaming: bool = False,
        generator_result_queue: Optional[queue.Queue] = None,
    ):
        self._future = future
        self._lazy_asyncio_future = None
        self._request_id = request_id
        self._is_streaming = is_streaming

        # For streaming requests, results must be written to this queue.
        # The queue will be consumed until the future is completed.
        self._generator_result_queue = generator_result_queue
        if self._is_streaming:
            assert (
                self._generator_result_queue is not None
            ), "generator_result_queue must be provided for streaming results."

    @property
    def _asyncio_future(self) -> asyncio.Future:
        if self._lazy_asyncio_future is None:
            self._lazy_asyncio_future = asyncio.wrap_future(self._future)

        return self._lazy_asyncio_future

    def _process_response(f: Union[Callable, Coroutine]):
        @wraps(f)
        def wrapper(self, *args, **kwargs):
            try:
                return f(self, *args, **kwargs)
            except (asyncio.CancelledError, concurrent.futures.CancelledError):
                raise RequestCancelledError(self._request_id)

        @wraps(f)
        async def async_wrapper(self, *args, **kwargs):
            try:
                return await f(self, *args, **kwargs)
            except (asyncio.CancelledError, concurrent.futures.CancelledError):
                raise RequestCancelledError(self._request_id)

        if inspect.iscoroutinefunction(f):
            return async_wrapper
        else:
            return wrapper

    @_process_response
    def get(self, timeout_s: Optional[float]):
        assert (
            not self._is_streaming
        ), "get() can only be called on a non-streaming result."

        try:
            return self._future.result(timeout=timeout_s)
        except concurrent.futures.TimeoutError:
            raise TimeoutError("Timed out waiting for result.")

    @_process_response
    async def get_async(self):
        assert (
            not self._is_streaming
        ), "get_async() can only be called on a non-streaming result."

        return await self._asyncio_future

    @_process_response
    def __next__(self):
        assert self._is_streaming, "next() can only be called on a streaming result."

        while True:
            if self._future.done() and self._generator_result_queue.empty():
                if self._future.exception():
                    raise self._future.exception()
                else:
                    raise StopIteration

            try:
                return self._generator_result_queue.get(timeout=0.01)
            except queue.Empty:
                pass

    @_process_response
    async def __anext__(self):
        assert self._is_streaming, "anext() can only be called on a streaming result."

        # This callback does not pull from the queue, only checks that a result is
        # available, else there is a race condition where the future finishes and the
        # queue is empty, but this function hasn't returned the result yet.
        def _wait_for_result():
            while True:
                if self._future.done() or not self._generator_result_queue.empty():
                    return
                time.sleep(0.01)

        wait_for_result_task = asyncio.get_running_loop().create_task(
            asyncio.to_thread(_wait_for_result),
        )
        done, _ = await asyncio.wait(
            [self._asyncio_future, wait_for_result_task],
            return_when=asyncio.FIRST_COMPLETED,
        )

        if not self._generator_result_queue.empty():
            return self._generator_result_queue.get()

        if self._asyncio_future.exception():
            raise self._asyncio_future.exception()

        raise StopAsyncIteration

    def add_done_callback(self, callback: Callable):
        self._future.add_done_callback(callback)

    def cancel(self):
        self._future.cancel()

    def to_object_ref(self, timeout_s: Optional[float]) -> ray.ObjectRef:
        raise self.OBJ_REF_NOT_SUPPORTED_ERROR

    async def to_object_ref_async(self) -> ray.ObjectRef:
        raise self.OBJ_REF_NOT_SUPPORTED_ERROR

    def to_object_ref_gen(self) -> ray.ObjectRefGenerator:
        raise self.OBJ_REF_NOT_SUPPORTED_ERROR
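

# Hedged usage sketch (not part of this file): the queue-draining
# __next__/__anext__ logic above is what backs iteration over a streaming
# response in local testing mode. Names below are illustrative:
#
#     @serve.deployment
#     class Streamer:
#         def __call__(self):
#             for i in range(3):
#                 yield i
#
#     handle = serve.run(Streamer.bind(), _local_testing_mode=True)
#     for item in handle.options(stream=True).remote():
#         print(item)  # prints 0, 1, 2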


class LocalRouter(Router):
    def __init__(
        self,
        user_callable_wrapper: UserCallableWrapper,
        deployment_id: DeploymentID,
        handle_options: Any,
    ):
        self._deployment_id = deployment_id
        self._user_callable_wrapper = user_callable_wrapper
        assert (
            self._user_callable_wrapper._callable is not None
        ), "User callable must already be initialized."

    def running_replicas_populated(self) -> bool:
        return True

    def _resolve_deployment_responses(
        self, request_args: Tuple[Any], request_kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any], Dict[str, Any]]:
        """Replace DeploymentResponse objects with their results.

        NOTE(edoakes): this currently calls the blocking `.result()` method
        on the responses to resolve them to their values. This is a divergence
        from the remote codepath where they're resolved concurrently.
        """

        def _new_arg(arg: Any) -> Any:
            if isinstance(arg, DeploymentResponse):
                new_arg = arg.result(_skip_asyncio_check=True)
            elif isinstance(arg, DeploymentResponseGenerator):
                raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR
            else:
                new_arg = arg

            return new_arg

        # Serialize and deserialize the arguments to mimic remote call behavior.
        return cloudpickle.loads(
            cloudpickle.dumps(
                (
                    tuple(_new_arg(arg) for arg in request_args),
                    {k: _new_arg(v) for k, v in request_kwargs.items()},
                )
            )
        )

    def assign_request(
        self,
        request_meta: RequestMetadata,
        *request_args,
        **request_kwargs,
    ) -> concurrent.futures.Future[LocalReplicaResult]:
        request_args, request_kwargs = self._resolve_deployment_responses(
            request_args, request_kwargs
        )

        if request_meta.is_streaming:
            generator_result_queue = queue.Queue()

            def generator_result_callback(item: Any):
                generator_result_queue.put_nowait(item)

        else:
            generator_result_queue = None
            generator_result_callback = None

        # Conform to the router interface of returning a future to the ReplicaResult.
        noop_future = concurrent.futures.Future()
        noop_future.set_result(
            LocalReplicaResult(
                self._user_callable_wrapper.call_user_method(
                    request_meta,
                    request_args,
                    request_kwargs,
                    generator_result_callback=generator_result_callback,
                ),
                request_id=request_meta.request_id,
                is_streaming=request_meta.is_streaming,
                generator_result_queue=generator_result_queue,
            )
        )
        return noop_future

    def shutdown(self):
        pass
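
For illustration, a hedged sketch of the composition pattern that
`_resolve_deployment_responses` above handles; in local testing mode the
nested response is resolved with a blocking `.result()` rather than
concurrently (the first follow-up item). Deployment names are illustrative,
not part of this diff:

    @serve.deployment
    class Adder:
        def __call__(self, x: int) -> int:
            return x + 1

    @serve.deployment
    class Doubler:
        def __call__(self, x: int) -> int:
            return x * 2

    @serve.deployment
    class Ingress:
        def __init__(self, adder, doubler):
            self._adder = adder
            self._doubler = doubler

        async def __call__(self, x: int) -> int:
            added = self._adder.remote(x)  # DeploymentResponse, not awaited
            # Passing `added` directly to another handle call is what
            # _resolve_deployment_responses resolves before execution.
            return await self._doubler.remote(added)

    app = Ingress.bind(Adder.bind(), Doubler.bind())
    handle = serve.run(app, _local_testing_mode=True)
    assert handle.remote(1).result() == 4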
14 changes: 9 additions & 5 deletions python/ray/serve/_private/replica.py
@@ -13,7 +13,7 @@
from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Union

import starlette.responses
-from starlette.types import ASGIApp
+from starlette.types import ASGIApp, Message

import ray
from ray import cloudpickle
@@ -464,7 +464,7 @@ async def _call_user_generator(

        # `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be
        # used to interact with the result queue from the user callable thread.
-       async def _enqueue_thread_safe(item: Any):
+       def _enqueue_thread_safe(item: Any):
            self._event_loop.call_soon_threadsafe(result_queue.put_nowait, item)

        call_user_method_future = asyncio.wrap_future(
@@ -1064,10 +1064,14 @@ def _prepare_args_for_http_request(
        receive_task = self._user_code_event_loop.create_task(
            receive.fetch_until_disconnect()
        )
+
+       async def _send(message: Message):
+           return generator_result_callback(message)
+
        asgi_args = ASGIArgs(
            scope=scope,
            receive=receive,
-           send=generator_result_callback,
+           send=_send,
        )
        if is_asgi_app:
            request_args = asgi_args.to_args_tuple()
@@ -1128,12 +1132,12 @@ async def _handle_user_method_result(
            for r in result:
                if request_metadata.is_grpc_request:
                    r = (request_metadata.grpc_context, r.SerializeToString())
-               await generator_result_callback(r)
+               generator_result_callback(r)
        elif result_is_async_gen:
            async for r in result:
                if request_metadata.is_grpc_request:
                    r = (request_metadata.grpc_context, r.SerializeToString())
-               await generator_result_callback(r)
+               generator_result_callback(r)
        elif request_metadata.is_http_request and not is_asgi_app:
            # For the FastAPI codepath, the response has already been sent over
            # ASGI, but for the vanilla deployment codepath we need to send it.