diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 1c0110fbbae2..8c9bc12bd984 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -558,6 +558,8 @@ class RequestMetadata: # internal_request_id is always generated by the proxy and is used for tracking # request objects. We can assume this is always unique between requests. internal_request_id: str + + # Method of the user callable to execute. call_method: str = "__call__" # HTTP route path of the request. diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index 08af6d52cbab..8fae61c95329 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -350,3 +350,9 @@ RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY = ( os.environ.get("RAY_SERVE_USE_COMPACT_SCHEDULING_STRATEGY", "0") == "1" ) + +# Feature flag to always override local_testing_mode to True in serve.run. +# This is used for internal testing to avoid passing the flag to every invocation. +RAY_SERVE_FORCE_LOCAL_TESTING_MODE = ( + os.environ.get("RAY_SERVE_FORCE_LOCAL_TESTING_MODE", "0") == "1" +) diff --git a/python/ray/serve/_private/local_testing_mode.py b/python/ray/serve/_private/local_testing_mode.py new file mode 100644 index 000000000000..6ccc16cd3628 --- /dev/null +++ b/python/ray/serve/_private/local_testing_mode.py @@ -0,0 +1,313 @@ +import asyncio +import concurrent.futures +import inspect +import logging +import queue +import time +from functools import wraps +from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union + +import ray +from ray import cloudpickle +from ray.serve._private.common import DeploymentID, RequestMetadata +from ray.serve._private.constants import SERVE_LOGGER_NAME +from ray.serve._private.replica import UserCallableWrapper +from ray.serve._private.replica_result import ReplicaResult +from ray.serve._private.router import Router +from ray.serve._private.utils import GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR +from ray.serve.deployment import Deployment +from ray.serve.exceptions import RequestCancelledError +from ray.serve.handle import ( + DeploymentHandle, + DeploymentResponse, + DeploymentResponseGenerator, +) + +logger = logging.getLogger(SERVE_LOGGER_NAME) + + +def _validate_deployment_options( + deployment: Deployment, + deployment_id: DeploymentID, +): + if "num_gpus" in deployment.ray_actor_options: + logger.warning( + f"Deployment {deployment_id} has num_gpus configured. " + "CUDA_VISIBLE_DEVICES is not managed automatically in local testing mode. " + ) + + if "runtime_env" in deployment.ray_actor_options: + logger.warning( + f"Deployment {deployment_id} has runtime_env configured. " + "runtime_envs are ignored in local testing mode." + ) + + +def make_local_deployment_handle( + deployment: Deployment, + app_name: str, +) -> DeploymentHandle: + """Constructs an in-process DeploymentHandle. + + This is used in the application build process for local testing mode, + where all deployments of an app run in the local process which enables + faster dev iterations and use of tooling like PDB. + + The user callable will be run on an asyncio loop in a separate thread + (sharing the same code that's used in the replica). + + The constructor for the user callable is run eagerly in this function to + ensure that any exceptions are raised during `serve.run`. 
+ """ + deployment_id = DeploymentID(deployment.name, app_name) + _validate_deployment_options(deployment, deployment_id) + user_callable_wrapper = UserCallableWrapper( + deployment.func_or_class, + deployment.init_args, + deployment.init_kwargs, + deployment_id=deployment_id, + ) + try: + logger.info(f"Initializing local replica class for {deployment_id}.") + user_callable_wrapper.initialize_callable().result() + except Exception: + logger.exception(f"Failed to initialize deployment {deployment_id}.") + raise + + def _create_local_router( + handle_id: str, deployment_id: DeploymentID, handle_options: Any + ) -> Router: + return LocalRouter( + user_callable_wrapper, + deployment_id=deployment_id, + handle_options=handle_options, + ) + + return DeploymentHandle( + deployment.name, + app_name, + _create_router=_create_local_router, + ) + + +class LocalReplicaResult(ReplicaResult): + """ReplicaResult used by in-process Deployment Handles.""" + + OBJ_REF_NOT_SUPPORTED_ERROR = RuntimeError( + "Converting DeploymentResponses to ObjectRefs is not supported " + "in local testing mode." + ) + + def __init__( + self, + future: concurrent.futures.Future, + *, + request_id: str, + is_streaming: bool = False, + generator_result_queue: Optional[queue.Queue] = None, + ): + self._future = future + self._lazy_asyncio_future = None + self._request_id = request_id + self._is_streaming = is_streaming + + # For streaming requests, results must be written to this queue. + # The queue will be consumed until the future is completed. + self._generator_result_queue = generator_result_queue + if self._is_streaming: + assert ( + self._generator_result_queue is not None + ), "generator_result_queue must be provided for streaming results." + + @property + def _asyncio_future(self) -> asyncio.Future: + if self._lazy_asyncio_future is None: + self._lazy_asyncio_future = asyncio.wrap_future(self._future) + + return self._lazy_asyncio_future + + def _process_response(f: Union[Callable, Coroutine]): + @wraps(f) + def wrapper(self, *args, **kwargs): + try: + return f(self, *args, **kwargs) + except (asyncio.CancelledError, concurrent.futures.CancelledError): + raise RequestCancelledError(self._request_id) + + @wraps(f) + async def async_wrapper(self, *args, **kwargs): + try: + return await f(self, *args, **kwargs) + except (asyncio.CancelledError, concurrent.futures.CancelledError): + raise RequestCancelledError(self._request_id) + + if inspect.iscoroutinefunction(f): + return async_wrapper + else: + return wrapper + + @_process_response + def get(self, timeout_s: Optional[float]): + assert ( + not self._is_streaming + ), "get() can only be called on a non-streaming result." + + try: + return self._future.result(timeout=timeout_s) + except concurrent.futures.TimeoutError: + raise TimeoutError("Timed out waiting for result.") + + @_process_response + async def get_async(self): + assert ( + not self._is_streaming + ), "get_async() can only be called on a non-streaming result." + + return await self._asyncio_future + + @_process_response + def __next__(self): + assert self._is_streaming, "next() can only be called on a streaming result." 
+ + while True: + if self._future.done() and self._generator_result_queue.empty(): + if self._future.exception(): + raise self._future.exception() + else: + raise StopIteration + + try: + return self._generator_result_queue.get(timeout=0.01) + except queue.Empty: + pass + + @_process_response + async def __anext__(self): + assert self._is_streaming, "anext() can only be called on a streaming result." + + # This callback does not pull from the queue, only checks that a result is + # available, else there is a race condition where the future finishes and the + # queue is empty, but this function hasn't returned the result yet. + def _wait_for_result(): + while True: + if self._future.done() or not self._generator_result_queue.empty(): + return + time.sleep(0.01) + + wait_for_result_task = asyncio.get_running_loop().create_task( + asyncio.to_thread(_wait_for_result), + ) + done, _ = await asyncio.wait( + [self._asyncio_future, wait_for_result_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if not self._generator_result_queue.empty(): + return self._generator_result_queue.get() + + if self._asyncio_future.exception(): + raise self._asyncio_future.exception() + + raise StopAsyncIteration + + def add_done_callback(self, callback: Callable): + self._future.add_done_callback(callback) + + def cancel(self): + self._future.cancel() + + def to_object_ref(self, timeout_s: Optional[float]) -> ray.ObjectRef: + raise self.OBJ_REF_NOT_SUPPORTED_ERROR + + async def to_object_ref_async(self) -> ray.ObjectRef: + raise self.OBJ_REF_NOT_SUPPORTED_ERROR + + def to_object_ref_gen(self) -> ray.ObjectRefGenerator: + raise self.OBJ_REF_NOT_SUPPORTED_ERROR + + +class LocalRouter(Router): + def __init__( + self, + user_callable_wrapper: UserCallableWrapper, + deployment_id: DeploymentID, + handle_options: Any, + ): + self._deployment_id = deployment_id + self._user_callable_wrapper = user_callable_wrapper + assert ( + self._user_callable_wrapper._callable is not None + ), "User callable must already be initialized." + + def running_replicas_populated(self) -> bool: + return True + + def _resolve_deployment_responses( + self, request_args: Tuple[Any], request_kwargs: Dict[str, Any] + ) -> Tuple[Tuple[Any], Dict[str, Any]]: + """Replace DeploymentResponse objects with their results. + + NOTE(edoakes): this currently calls the blocking `.result()` method + on the responses to resolve them to their values. This is a divergence + from the remote codepath where they're resolved concurrently. + """ + + def _new_arg(arg: Any) -> Any: + if isinstance(arg, DeploymentResponse): + new_arg = arg.result(_skip_asyncio_check=True) + elif isinstance(arg, DeploymentResponseGenerator): + raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR + else: + new_arg = arg + + return new_arg + + # Serialize and deserialize the arguments to mimic remote call behavior. 
+ return cloudpickle.loads( + cloudpickle.dumps( + ( + tuple(_new_arg(arg) for arg in request_args), + {k: _new_arg(v) for k, v in request_kwargs.items()}, + ) + ) + ) + + def assign_request( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> concurrent.futures.Future[LocalReplicaResult]: + request_args, request_kwargs = self._resolve_deployment_responses( + request_args, request_kwargs + ) + + if request_meta.is_streaming: + generator_result_queue = queue.Queue() + + def generator_result_callback(item: Any): + generator_result_queue.put_nowait(item) + + else: + generator_result_queue = None + generator_result_callback = None + + # Conform to the router interface of returning a future to the ReplicaResult. + noop_future = concurrent.futures.Future() + noop_future.set_result( + LocalReplicaResult( + self._user_callable_wrapper.call_user_method( + request_meta, + request_args, + request_kwargs, + generator_result_callback=generator_result_callback, + ), + request_id=request_meta.request_id, + is_streaming=request_meta.is_streaming, + generator_result_queue=generator_result_queue, + ) + ) + return noop_future + + def shutdown(self): + pass diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index c390c05bc055..b90c837b6cc0 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -13,7 +13,7 @@ from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Union import starlette.responses -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Message import ray from ray import cloudpickle @@ -464,7 +464,7 @@ async def _call_user_generator( # `asyncio.Event`s are not thread safe, so `call_soon_threadsafe` must be # used to interact with the result queue from the user callable thread. - async def _enqueue_thread_safe(item: Any): + def _enqueue_thread_safe(item: Any): self._event_loop.call_soon_threadsafe(result_queue.put_nowait, item) call_user_method_future = asyncio.wrap_future( @@ -1064,10 +1064,14 @@ def _prepare_args_for_http_request( receive_task = self._user_code_event_loop.create_task( receive.fetch_until_disconnect() ) + + async def _send(message: Message): + return generator_result_callback(message) + asgi_args = ASGIArgs( scope=scope, receive=receive, - send=generator_result_callback, + send=_send, ) if is_asgi_app: request_args = asgi_args.to_args_tuple() @@ -1128,12 +1132,12 @@ async def _handle_user_method_result( for r in result: if request_metadata.is_grpc_request: r = (request_metadata.grpc_context, r.SerializeToString()) - await generator_result_callback(r) + generator_result_callback(r) elif result_is_async_gen: async for r in result: if request_metadata.is_grpc_request: r = (request_metadata.grpc_context, r.SerializeToString()) - await generator_result_callback(r) + generator_result_callback(r) elif request_metadata.is_http_request and not is_asgi_app: # For the FastAPI codepath, the response has already been sent over # ASGI, but for the vanilla deployment codepath we need to send it. 
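Illustrative sketch (not part of the patch): how the in-process handle path defined in local_testing_mode.py above is expected to be exercised through the private `_local_testing_mode` flag that the api.py changes below add to `serve.run`. The `Greeter` deployment is hypothetical; the call pattern mirrors the new tests/unit/test_local_testing_mode.py added later in this diff.

from ray import serve
from ray.serve.handle import DeploymentHandle, DeploymentResponse


@serve.deployment
class Greeter:
    def __call__(self, name: str) -> str:
        return f"Hello {name}!"


# With _local_testing_mode=True, serve.run builds the app using
# make_local_deployment_handle: no controller, proxy, or replica actors are
# started, and the user callable runs on a thread in the local process.
handle: DeploymentHandle = serve.run(Greeter.bind(), _local_testing_mode=True)

# The returned handle is backed by LocalRouter/LocalReplicaResult.
response: DeploymentResponse = handle.remote("Serve")
assert response.result() == "Hello Serve!"
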
diff --git a/python/ray/serve/_private/test_utils.py b/python/ray/serve/_private/test_utils.py index cf3e717ba5e3..64427528aaba 100644 --- a/python/ray/serve/_private/test_utils.py +++ b/python/ray/serve/_private/test_utils.py @@ -3,6 +3,7 @@ import os import threading import time +from contextlib import asynccontextmanager from copy import copy, deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -514,14 +515,23 @@ def ping_fruit_stand(channel, app_name): assert response.costs == 32 +@asynccontextmanager async def send_signal_on_cancellation(signal_actor: ActorHandle): + cancelled = False try: - await asyncio.sleep(100000) + yield + await asyncio.sleep(100) except asyncio.CancelledError: + cancelled = True # Clear the context var to avoid Ray recursively cancelling this method call. ray._raylet.async_task_id.set(None) await signal_actor.send.remote() + if not cancelled: + raise RuntimeError( + "CancelledError wasn't raised during `send_signal_on_cancellation` block" + ) + class FakeGrpcContext: def __init__(self): diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index 3b9a5f4f6644..b57b8518b22f 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -39,6 +39,11 @@ np = None MESSAGE_PACK_OFFSET = 9 +GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR = RuntimeError( + "Streaming deployment handle results cannot be passed to " + "downstream handle calls. If you have a use case requiring " + "this feature, please file a feature request on GitHub." +) # Use a global singleton enum to emulate default options. We cannot use None @@ -598,12 +603,6 @@ async def resolve_request_args( """ from ray.serve.handle import DeploymentResponse, DeploymentResponseGenerator - generator_not_supported_message = ( - "Streaming deployment handle results cannot be passed to " - "downstream handle calls. If you have a use case requiring " - "this feature, please file a feature request on GitHub." 
- ) - new_args = [None for _ in range(len(request_args))] new_kwargs = {} @@ -611,7 +610,7 @@ async def resolve_request_args( response_indices = [] for i, obj in enumerate(request_args): if isinstance(obj, DeploymentResponseGenerator): - raise RuntimeError(generator_not_supported_message) + raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR elif isinstance(obj, DeploymentResponse): # Launch async task to convert DeploymentResponse to an object ref, and # keep track of the argument index in the original `request_args` @@ -624,7 +623,7 @@ async def resolve_request_args( response_keys = [] for k, obj in request_kwargs.items(): if isinstance(obj, DeploymentResponseGenerator): - raise RuntimeError(generator_not_supported_message) + raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR elif isinstance(obj, DeploymentResponse): # Launch async task to convert DeploymentResponse to an object ref, and # keep track of the corresponding key in the original `request_kwargs` diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index f45ffd9bdbc2..182795889d47 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -16,11 +16,17 @@ ReplicaConfig, handle_num_replicas_auto, ) -from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME, SERVE_LOGGER_NAME +from ray.serve._private.constants import ( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + SERVE_DEFAULT_APP_NAME, + SERVE_LOGGER_NAME, +) from ray.serve._private.http_util import ( ASGIAppReplicaWrapper, make_fastapi_class_based_view, ) +from ray.serve._private.local_testing_mode import make_local_deployment_handle +from ray.serve._private.logging_utils import configure_component_logger from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import ( DEFAULT, @@ -424,10 +430,12 @@ def decorator(_func_or_class): @PublicAPI(stability="stable") def _run( target: Application, + *, _blocking: bool = True, name: str = SERVE_DEFAULT_APP_NAME, route_prefix: Optional[str] = "/", logging_config: Optional[Union[Dict, LoggingConfig]] = None, + _local_testing_mode: bool = False, ) -> DeploymentHandle: """Run an application and return a handle to its ingress deployment. @@ -437,26 +445,46 @@ def _run( if len(name) == 0: raise RayServeException("Application name must a non-empty string.") - validate_route_prefix(route_prefix) - - client = _private_api.serve_start( - http_options={"location": "EveryNode"}, - ) - - # Record after Ray has been started. - ServeUsageTag.API_VERSION.record("v2") - if not isinstance(target, Application): raise TypeError( "`serve.run` expects an `Application` returned by `Deployment.bind()`." ) - return client.deploy_application( - build_app(target, name=name), - blocking=_blocking, - route_prefix=route_prefix, - logging_config=logging_config, - ) + if RAY_SERVE_FORCE_LOCAL_TESTING_MODE: + if not _local_testing_mode: + logger.info("Overriding local_testing_mode=True from environment variable.") + + _local_testing_mode = True + + validate_route_prefix(route_prefix) + + if _local_testing_mode: + configure_component_logger( + component_name="local_test", + component_id="-", + logging_config=logging_config or LoggingConfig(), + stream_handler_only=True, + ) + built_app = build_app( + target, + name=name, + make_deployment_handle=make_local_deployment_handle, + ) + handle = built_app.deployment_handles[built_app.ingress_deployment_name] + else: + client = _private_api.serve_start( + http_options={"location": "EveryNode"}, + ) + # Record after Ray has been started. 
+ ServeUsageTag.API_VERSION.record("v2") + handle = client.deploy_application( + build_app(target, name=name), + blocking=_blocking, + route_prefix=route_prefix, + logging_config=logging_config, + ) + + return handle @PublicAPI(stability="stable") @@ -466,6 +494,7 @@ def run( name: str = SERVE_DEFAULT_APP_NAME, route_prefix: Optional[str] = "/", logging_config: Optional[Union[Dict, LoggingConfig]] = None, + _local_testing_mode: bool = False, ) -> DeploymentHandle: """Run an application and return a handle to its ingress deployment. @@ -498,6 +527,7 @@ def run( name=name, route_prefix=route_prefix, logging_config=logging_config, + _local_testing_mode=_local_testing_mode, ) logger.info(f"Deployed app '{name}' successfully.") diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 7cfb24bc67f9..e2d5366324b7 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -412,7 +412,12 @@ def __reduce__(self): "only pass `DeploymentResponse` objects as top level arguments." ) - def result(self, *, timeout_s: Optional[float] = None) -> Any: + def result( + self, + *, + timeout_s: Optional[float] = None, + _skip_asyncio_check: bool = False, + ) -> Any: """Fetch the result of the handle call synchronously. This should *not* be used from within a deployment as it runs in an asyncio @@ -422,7 +427,7 @@ def result(self, *, timeout_s: Optional[float] = None) -> Any: a `TimeoutError` is raised. """ - if is_running_in_asyncio_loop(): + if not _skip_asyncio_check and is_running_in_asyncio_loop(): raise RuntimeError( "Sync methods should not be called from within an `asyncio` event " "loop. Use `await response` instead." diff --git a/python/ray/serve/tests/BUILD b/python/ray/serve/tests/BUILD index 4abc0bdb6c1d..0af13939ec41 100644 --- a/python/ray/serve/tests/BUILD +++ b/python/ray/serve/tests/BUILD @@ -444,3 +444,26 @@ py_test_module_list( ], ) + +# Test handle API with local testing mode. 
+py_test_module_list( + size = "small", + env = {"RAY_SERVE_FORCE_LOCAL_TESTING_MODE": "1"}, + files = [ + "test_handle_1.py", + "test_handle_2.py", + "test_handle_cancellation.py", + "test_handle_streaming.py", + ], + name_suffix = "_with_local_testing_mode", + tags = [ + "exclusive", + "team:serve", + ], + deps = [ + ":common", + ":conftest", + "//python/ray/serve:serve_lib", + ], +) + diff --git a/python/ray/serve/tests/test_actor_replica_wrapper.py b/python/ray/serve/tests/test_actor_replica_wrapper.py index faa829c11a7e..0605be102b5e 100644 --- a/python/ray/serve/tests/test_actor_replica_wrapper.py +++ b/python/ray/serve/tests/test_actor_replica_wrapper.py @@ -66,9 +66,8 @@ async def handle_request_with_rejection( cancelled_signal_actor = kwargs.pop("cancelled_signal_actor", None) if cancelled_signal_actor is not None: executing_signal_actor = kwargs.pop("executing_signal_actor") - await executing_signal_actor.send.remote() - await send_signal_on_cancellation(cancelled_signal_actor) - return + async with send_signal_on_cancellation(cancelled_signal_actor): + await executing_signal_actor.send.remote() yield pickle.dumps(self._replica_queue_length_info) if not self._replica_queue_length_info.accepted: diff --git a/python/ray/serve/tests/test_grpc.py b/python/ray/serve/tests/test_grpc.py index 7ce7a7e986e2..b9ce56461563 100644 --- a/python/ray/serve/tests/test_grpc.py +++ b/python/ray/serve/tests/test_grpc.py @@ -495,8 +495,8 @@ async def test_grpc_proxy_cancellation(ray_instance, ray_shutdown, streaming: bo @serve.deployment class Downstream: async def wait_for_singal(self): - await running_signal_actor.send.remote() - await send_signal_on_cancellation(cancelled_signal_actor) + async with send_signal_on_cancellation(cancelled_signal_actor): + await running_signal_actor.send.remote() async def __call__(self, *args): await self.wait_for_singal() diff --git a/python/ray/serve/tests/test_handle_1.py b/python/ray/serve/tests/test_handle_1.py index ba59dfb88f3b..922e67c816e7 100644 --- a/python/ray/serve/tests/test_handle_1.py +++ b/python/ray/serve/tests/test_handle_1.py @@ -8,11 +8,18 @@ import ray from ray import serve from ray.serve._private.common import DeploymentHandleSource, RequestProtocol -from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME +from ray.serve._private.constants import ( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + SERVE_DEFAULT_APP_NAME, +) from ray.serve.exceptions import RayServeException from ray.serve.handle import DeploymentHandle +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't set handle source", +) def test_replica_handle_source(serve_instance): @serve.deployment def f(): @@ -31,6 +38,10 @@ def check(self): assert h.check.remote().result() +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode work with tasks & actors", +) def test_handle_serializable(serve_instance): @serve.deployment def f(): @@ -55,6 +66,10 @@ async def __call__(self): assert app_handle.remote().result() == "hello" +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support get_app_handle/get_deployment_handle", +) def test_get_and_call_handle_in_thread(serve_instance): @serve.deployment def f(): @@ -110,6 +125,10 @@ def __call__(self): assert handle2.request_counter.info == counter_info +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support get_app_handle/get_deployment_handle", +) def 
test_repeated_get_handle_cached(serve_instance): @serve.deployment def f(_): @@ -164,6 +183,10 @@ def _get_asyncio_loop_running_in_thread() -> asyncio.AbstractEventLoop: return loop +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support get_app_handle/get_deployment_handle", +) @pytest.mark.asyncio async def test_call_handle_across_asyncio_loops(serve_instance): @serve.deployment diff --git a/python/ray/serve/tests/test_handle_2.py b/python/ray/serve/tests/test_handle_2.py index 739c93e67a10..cc58f970f5b7 100644 --- a/python/ray/serve/tests/test_handle_2.py +++ b/python/ray/serve/tests/test_handle_2.py @@ -8,7 +8,10 @@ from ray import serve from ray._private.test_utils import SignalActor, async_wait_for_condition from ray._private.utils import get_or_create_event_loop -from ray.serve._private.constants import RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS +from ray.serve._private.constants import ( + RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS, + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, +) from ray.serve.exceptions import RayServeException from ray.serve.handle import ( DeploymentHandle, @@ -68,6 +71,10 @@ async def __call__(self): assert ref.result() == "hi" +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support getting dynamic handles", +) def test_get_app_and_deployment_handle(serve_instance): """Test the `get_app_handle` and `get_deployment_handle` APIs.""" @@ -132,6 +139,10 @@ async def __call__(self): assert handle.remote().result() == "driver|downstream1|downstream2|hi" +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode only supports single apps", +) @pytest.mark.parametrize("arg_type", ["args", "kwargs"]) def test_compose_apps(serve_instance, arg_type): """Test composing deployment handle refs outside of a deployment.""" @@ -211,19 +222,32 @@ def test_nested_deployment_response_error(serve_instance): handle call errors, and with an informative error message.""" @serve.deployment - class Deployment: - def __call__(self, inp: str): - return inp + class Downstream: + def __call__(self, *args): + pass - handle1 = serve.run(Deployment.bind(), name="app1", route_prefix="/app1") - handle2 = serve.run(Deployment.bind(), name="app2", route_prefix="/app2") + @serve.deployment + class Upstream: + def __init__(self, h1: DeploymentHandle, h2: DeploymentHandle): + self._h1 = h1 + self._h2 = h2 - with pytest.raises( - RayServeException, match="`DeploymentResponse` is not serializable" - ): - handle1.remote([handle2.remote("hi")]).result() + async def __call__(self): + with pytest.raises( + RayServeException, match="`DeploymentResponse` is not serializable" + ): + await self._h2.remote([self._h2.remote()]) + h = serve.run( + Upstream.bind(Downstream.bind(), Downstream.bind()), + ) + h.remote().result() + +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support _to_object_ref", +) def test_convert_to_object_ref(serve_instance): """Test converting deployment handle refs to Ray object refs.""" @@ -276,6 +300,10 @@ async def __call__(self): assert list(gen) == list(range(10)) +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support _to_object_ref", +) def test_convert_to_object_ref_gen(serve_instance): """Test converting generators to obj ref gens inside and outside a deployment.""" @@ -380,9 +408,13 @@ async def __call__(self): r.result() == "OK" +@pytest.mark.skipif( + 
RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't respect max_ongoing_requests", +) @pytest.mark.skipif( not RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS, - reason="Strict enforcement must be enabled.", + reason="Strict enforcement must be enabled", ) @pytest.mark.asyncio async def test_max_ongoing_requests_enforced(serve_instance): diff --git a/python/ray/serve/tests/test_handle_cancellation.py b/python/ray/serve/tests/test_handle_cancellation.py index 0317ccb88127..3b4ec2bc6160 100644 --- a/python/ray/serve/tests/test_handle_cancellation.py +++ b/python/ray/serve/tests/test_handle_cancellation.py @@ -9,6 +9,7 @@ async_wait_for_condition, wait_for_condition, ) +from ray.serve._private.constants import RAY_SERVE_FORCE_LOCAL_TESTING_MODE from ray.serve._private.test_utils import send_signal_on_cancellation, tlog from ray.serve.exceptions import RequestCancelledError @@ -21,8 +22,8 @@ def test_cancel_sync_handle_call_during_execution(serve_instance): @serve.deployment class Ingress: async def __call__(self, *args): - await running_signal_actor.send.remote() - await send_signal_on_cancellation(cancelled_signal_actor) + async with send_signal_on_cancellation(cancelled_signal_actor): + await running_signal_actor.send.remote() h = serve.run(Ingress.bind()) @@ -38,6 +39,10 @@ async def __call__(self, *args): r.result() +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't have assignment/execution split", +) def test_cancel_sync_handle_call_during_assignment(serve_instance): """Test cancelling handle request during assignment (sync context).""" signal_actor = SignalActor.remote() @@ -82,8 +87,8 @@ def test_cancel_async_handle_call_during_execution(serve_instance): @serve.deployment class Downstream: async def __call__(self, *args): - await running_signal_actor.send.remote() - await send_signal_on_cancellation(cancelled_signal_actor) + async with send_signal_on_cancellation(cancelled_signal_actor): + await running_signal_actor.send.remote() @serve.deployment class Ingress: @@ -106,6 +111,10 @@ async def __call__(self, *args): h.remote().result() # Would raise if test failed. 
+@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't have assignment/execution split", +) def test_cancel_async_handle_call_during_assignment(serve_instance): """Test cancelling handle request during assignment (async context).""" signal_actor = SignalActor.remote() @@ -161,7 +170,8 @@ def test_cancel_generator_sync(serve_instance): class Ingress: async def __call__(self, *args): yield "hi" - await send_signal_on_cancellation(signal_actor) + async with send_signal_on_cancellation(signal_actor): + pass h = serve.run(Ingress.bind()).options(stream=True) @@ -187,7 +197,8 @@ def test_cancel_generator_async(serve_instance): class Downstream: async def __call__(self, *args): yield "hi" - await send_signal_on_cancellation(signal_actor) + async with send_signal_on_cancellation(signal_actor): + pass @serve.deployment class Ingress: @@ -279,13 +290,18 @@ async def get_out_of_band_response(self): assert h.get_out_of_band_response.remote().result() == "ok" +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't implement recursive cancellation", +) def test_recursive_cancellation_during_execution(serve_instance): inner_signal_actor = SignalActor.remote() outer_signal_actor = SignalActor.remote() @serve.deployment async def inner(): - await send_signal_on_cancellation(inner_signal_actor) + async with send_signal_on_cancellation(inner_signal_actor): + pass @serve.deployment class Ingress: @@ -294,7 +310,8 @@ def __init__(self, handle): async def __call__(self): _ = self._handle.remote() - await send_signal_on_cancellation(outer_signal_actor) + async with send_signal_on_cancellation(outer_signal_actor): + pass h = serve.run(Ingress.bind(inner.bind())) @@ -307,6 +324,10 @@ async def __call__(self): ray.get(outer_signal_actor.wait.remote(), timeout=10) +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't implement recursive cancellation", +) def test_recursive_cancellation_during_assignment(serve_instance): signal = SignalActor.remote() diff --git a/python/ray/serve/tests/test_handle_streaming.py b/python/ray/serve/tests/test_handle_streaming.py index 782460026b97..1a7e6ed5991a 100644 --- a/python/ray/serve/tests/test_handle_streaming.py +++ b/python/ray/serve/tests/test_handle_streaming.py @@ -269,7 +269,9 @@ async def __call__(self): with pytest.raises(StopAsyncIteration): assert await gen2.__anext__() - h = serve.run(Delegate.bind(deployment.bind(), deployment.bind())) + h = serve.run( + Delegate.bind(deployment.bind(), deployment.bind()), + ) h.remote().result() diff --git a/python/ray/serve/tests/test_http_cancellation.py b/python/ray/serve/tests/test_http_cancellation.py index 8bf4441fbda3..b70a5c57c9d7 100644 --- a/python/ray/serve/tests/test_http_cancellation.py +++ b/python/ray/serve/tests/test_http_cancellation.py @@ -21,7 +21,8 @@ def test_cancel_on_http_client_disconnect_during_execution( @serve.deployment async def inner(): - await send_signal_on_cancellation(inner_signal_actor) + async with send_signal_on_cancellation(inner_signal_actor): + pass if use_fastapi: app = FastAPI() @@ -35,7 +36,8 @@ def __init__(self, handle): @app.get("/") async def wait_for_cancellation(self): _ = self._handle.remote() - await send_signal_on_cancellation(outer_signal_actor) + async with send_signal_on_cancellation(outer_signal_actor): + pass else: @@ -46,7 +48,8 @@ def __init__(self, handle): async def __call__(self, request: Request): _ = self._handle.remote() - await 
send_signal_on_cancellation(outer_signal_actor) + async with send_signal_on_cancellation(outer_signal_actor): + pass serve.run(Ingress.bind(inner.bind())) diff --git a/python/ray/serve/tests/test_request_timeout.py b/python/ray/serve/tests/test_request_timeout.py index 6a0575326988..a62934983790 100644 --- a/python/ray/serve/tests/test_request_timeout.py +++ b/python/ray/serve/tests/test_request_timeout.py @@ -296,7 +296,8 @@ def test_cancel_on_http_timeout_during_execution( @serve.deployment async def inner(): - await send_signal_on_cancellation(inner_signal_actor) + async with send_signal_on_cancellation(inner_signal_actor): + pass if use_fastapi: app = FastAPI() @@ -310,7 +311,8 @@ def __init__(self, handle): @app.get("/") async def wait_for_cancellation(self): _ = self._handle.remote() - await send_signal_on_cancellation(outer_signal_actor) + async with send_signal_on_cancellation(outer_signal_actor): + pass else: @@ -321,7 +323,8 @@ def __init__(self, handle): async def __call__(self, request: Request): _ = self._handle.remote() - await send_signal_on_cancellation(outer_signal_actor) + async with send_signal_on_cancellation(outer_signal_actor): + pass serve.run(Ingress.bind(inner.bind())) diff --git a/python/ray/serve/tests/unit/test_local_testing_mode.py b/python/ray/serve/tests/unit/test_local_testing_mode.py new file mode 100644 index 000000000000..d8d47d322f68 --- /dev/null +++ b/python/ray/serve/tests/unit/test_local_testing_mode.py @@ -0,0 +1,91 @@ +import sys + +import pytest + +from ray import serve +from ray.serve.handle import DeploymentHandle + + +def test_basic_composition(): + @serve.deployment + class Inner: + def __init__(self, my_name: str): + self._my_name = my_name + + def __call__(self): + return self._my_name + + @serve.deployment + class Outer: + def __init__(self, my_name: str, inner_handle: DeploymentHandle): + assert isinstance(inner_handle, DeploymentHandle) + + self._my_name = my_name + self._inner_handle = inner_handle + + async def __call__(self, name: str): + inner_name = await self._inner_handle.remote() + return f"Hello {name} from {self._my_name} and {inner_name}!" + + h = serve.run(Outer.bind("Theodore", Inner.bind("Kevin")), _local_testing_mode=True) + assert isinstance(h, DeploymentHandle) + assert h.remote("Edith").result() == "Hello Edith from Theodore and Kevin!" + + +@pytest.mark.parametrize("deployment", ["Inner", "Outer"]) +def test_exception_raised_in_constructor(deployment: str): + @serve.deployment + class Inner: + def __init__(self, should_raise: bool): + if should_raise: + raise RuntimeError("Exception in Inner constructor.") + + @serve.deployment + class Outer: + def __init__(self, h: DeploymentHandle, should_raise: bool): + if should_raise: + raise RuntimeError("Exception in Outer constructor.") + + with pytest.raises(RuntimeError, match=f"Exception in {deployment} constructor."): + serve.run( + Outer.bind(Inner.bind(deployment == "Inner"), deployment == "Outer"), + _local_testing_mode=True, + ) + + +def test_to_object_ref_error_message(): + @serve.deployment + class Inner: + pass + + @serve.deployment + class Outer: + def __init__(self, h: DeploymentHandle): + self._h = h + + async def __call__(self): + with pytest.raises( + RuntimeError, + match=( + "Converting DeploymentResponses to ObjectRefs " + "is not supported in local testing mode." 
+ ), + ): + await self._h.remote()._to_object_ref() + + h = serve.run(Outer.bind(Inner.bind()), _local_testing_mode=True) + with pytest.raises( + RuntimeError, + match=( + "Converting DeploymentResponses to ObjectRefs " + "is not supported in local testing mode." + ), + ): + h.remote()._to_object_ref_sync() + + # Test the inner handle case (this would raise if it failed). + h.remote().result() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_user_callable_wrapper.py b/python/ray/serve/tests/unit/test_user_callable_wrapper.py index d43747845f55..b03c9ca7e39e 100644 --- a/python/ray/serve/tests/unit/test_user_callable_wrapper.py +++ b/python/ray/serve/tests/unit/test_user_callable_wrapper.py @@ -4,7 +4,7 @@ import sys import threading from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Generator, Optional +from typing import AsyncGenerator, Callable, Generator, Optional import pytest from fastapi import FastAPI @@ -221,9 +221,6 @@ def test_basic_class_callable_generators(): result_list = [] - async def append_to_list(item: Any): - result_list.append(item) - # Call sync generator without is_streaming. request_metadata = _make_request_metadata( call_method="call_generator", is_streaming=False @@ -232,7 +229,10 @@ async def append_to_list(item: Any): TypeError, match="Method 'call_generator' returned a generator." ): user_callable_wrapper.call_user_method( - request_metadata, (10,), dict(), generator_result_callback=append_to_list + request_metadata, + (10,), + dict(), + generator_result_callback=result_list.append, ).result() # Call sync generator. @@ -240,7 +240,7 @@ async def append_to_list(item: Any): call_method="call_generator", is_streaming=True ) user_callable_wrapper.call_user_method( - request_metadata, (10,), dict(), generator_result_callback=append_to_list + request_metadata, (10,), dict(), generator_result_callback=result_list.append ).result() assert result_list == list(range(10)) result_list.clear() @@ -251,7 +251,7 @@ async def append_to_list(item: Any): request_metadata, (10,), {"raise_exception": True}, - generator_result_callback=append_to_list, + generator_result_callback=result_list.append, ).result() assert result_list == [0] result_list.clear() @@ -264,7 +264,10 @@ async def append_to_list(item: Any): TypeError, match="Method 'call_async_generator' returned a generator." ): user_callable_wrapper.call_user_method( - request_metadata, (10,), dict(), generator_result_callback=append_to_list + request_metadata, + (10,), + dict(), + generator_result_callback=result_list.append, ).result() # Call async generator. @@ -272,7 +275,7 @@ async def append_to_list(item: Any): call_method="call_async_generator", is_streaming=True ) user_callable_wrapper.call_user_method( - request_metadata, (10,), dict(), generator_result_callback=append_to_list + request_metadata, (10,), dict(), generator_result_callback=result_list.append ).result() assert result_list == list(range(10)) result_list.clear() @@ -283,7 +286,7 @@ async def append_to_list(item: Any): request_metadata, (10,), {"raise_exception": True}, - generator_result_callback=append_to_list, + generator_result_callback=result_list.append, ).result() assert result_list == [0] @@ -329,16 +332,16 @@ def test_basic_function_callable_generators(fn: Callable): result_list = [] - async def append_to_list(item: Any): - result_list.append(item) - # Call generator function without is_streaming. 
request_metadata = _make_request_metadata(is_streaming=False) with pytest.raises( TypeError, match=f"Method '{fn.__name__}' returned a generator." ): user_callable_wrapper.call_user_method( - request_metadata, (10,), dict(), generator_result_callback=append_to_list + request_metadata, + (10,), + dict(), + generator_result_callback=result_list.append, ).result() # Call generator function. @@ -346,7 +349,7 @@ async def append_to_list(item: Any): call_method="call_generator", is_streaming=True ) user_callable_wrapper.call_user_method( - request_metadata, (10,), dict(), generator_result_callback=append_to_list + request_metadata, (10,), dict(), generator_result_callback=result_list.append ).result() assert result_list == list(range(10)) result_list.clear() @@ -357,7 +360,7 @@ async def append_to_list(item: Any): request_metadata, (10,), {"raise_exception": True}, - generator_result_callback=append_to_list, + generator_result_callback=result_list.append, ).result() assert result_list == [0] @@ -525,9 +528,6 @@ def test_grpc_streaming_request(): result_list = [] - async def append_to_list(item: Any): - result_list.append(item) - request_metadata = _make_request_metadata( call_method="stream", is_grpc_request=True, is_streaming=True ) @@ -535,7 +535,7 @@ async def append_to_list(item: Any): request_metadata, (grpc_request,), dict(), - generator_result_callback=append_to_list, + generator_result_callback=result_list.append, ).result() assert len(result_list) == 10 @@ -614,15 +614,12 @@ async def receive_asgi_messages(_: str): result_list = [] - async def append_to_list(item: Any): - result_list.append(item) - request_metadata = _make_request_metadata(is_http_request=True, is_streaming=True) user_callable_wrapper.call_user_method( request_metadata, (http_request,), dict(), - generator_result_callback=append_to_list, + generator_result_callback=result_list.append, ).result() assert result_list[0]["type"] == "http.response.start"
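Illustrative sketch (not part of the patch): the usage pattern the test migrations above follow with the reworked asynccontextmanager version of `send_signal_on_cancellation` from test_utils.py. The deployment and driver below are hypothetical; the signal-actor flow mirrors test_handle_cancellation.py.

import ray
from ray import serve
from ray._private.test_utils import SignalActor
from ray.serve._private.test_utils import send_signal_on_cancellation
from ray.serve.exceptions import RequestCancelledError

running_signal_actor = SignalActor.remote()
cancelled_signal_actor = SignalActor.remote()


@serve.deployment
class WaitForCancellation:
    async def __call__(self, *args):
        # The block body runs first; the context manager then waits for the
        # request to be cancelled. On CancelledError it sends on
        # `cancelled_signal_actor`, and it raises RuntimeError if the block
        # exits without being cancelled.
        async with send_signal_on_cancellation(cancelled_signal_actor):
            await running_signal_actor.send.remote()


handle = serve.run(WaitForCancellation.bind())
response = handle.remote()
ray.get(running_signal_actor.wait.remote(), timeout=10)  # Request is executing.
response.cancel()
ray.get(cancelled_signal_actor.wait.remote(), timeout=10)  # Cancellation observed.
try:
    response.result()
except RequestCancelledError:
    pass  # Expected: the in-flight request was cancelled.
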