Skip to content

Commit

Permalink
[serve] move handle options to its own file (ray-project#48454)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Move handle options to `handle_options.py`.


Signed-off-by: Cindy Zhang <[email protected]>
  • Loading branch information
zcin authored Nov 1, 2024
1 parent 199b582 commit 056d596
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 87 deletions.
15 changes: 6 additions & 9 deletions python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Tuple
from typing import Callable, Optional, Tuple

import ray
from ray._raylet import GcsClient
Expand All @@ -17,6 +17,7 @@
DeploymentScheduler,
)
from ray.serve._private.grpc_util import gRPCServer
from ray.serve._private.handle_options import DynamicHandleOptions, InitHandleOptions
from ray.serve._private.replica_scheduler import (
ActorReplicaWrapper,
PowerOfTwoChoicesReplicaScheduler,
Expand Down Expand Up @@ -55,15 +56,11 @@ def create_deployment_scheduler(


def create_dynamic_handle_options(**kwargs):
from ray.serve.handle import _DynamicHandleOptions

return _DynamicHandleOptions(**kwargs)
return DynamicHandleOptions(**kwargs)


def create_init_handle_options(**kwargs):
from ray.serve.handle import _InitHandleOptions

return _InitHandleOptions.create(**kwargs)
return InitHandleOptions.create(**kwargs)


def _get_node_id_and_az() -> Tuple[str, Optional[str]]:
Expand All @@ -81,13 +78,13 @@ def _get_node_id_and_az() -> Tuple[str, Optional[str]]:


# Interface definition for create_router.
CreateRouterCallable = Callable[[str, DeploymentID, Any], Router]
CreateRouterCallable = Callable[[str, DeploymentID, InitHandleOptions], Router]


def create_router(
handle_id: str,
deployment_id: DeploymentID,
handle_options: Any,
handle_options: InitHandleOptions,
) -> Router:
# NOTE(edoakes): this is lazy due to a nasty circular import that should be fixed.
from ray.serve.context import _get_global_client
Expand Down
71 changes: 71 additions & 0 deletions python/ray/serve/_private/handle_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields

import ray
from ray.serve._private.common import DeploymentHandleSource, RequestProtocol
from ray.serve._private.utils import DEFAULT


@dataclass(frozen=True)
class InitHandleOptionsBase(ABC):
"""Init options for each ServeHandle instance.
These fields can be set by calling `.init()` on a handle before
sending the first request.
"""

_prefer_local_routing: bool = False
_source: DeploymentHandleSource = DeploymentHandleSource.UNKNOWN

@classmethod
@abstractmethod
def create(cls, **kwargs) -> "InitHandleOptionsBase":
raise NotImplementedError


@dataclass(frozen=True)
class InitHandleOptions(InitHandleOptionsBase):
@classmethod
def create(cls, **kwargs) -> "InitHandleOptions":
for k in list(kwargs.keys()):
if kwargs[k] == DEFAULT.VALUE:
# Use default value
del kwargs[k]

# Detect replica source for handles
if (
"_source" not in kwargs
and ray.serve.context._get_internal_replica_context() is not None
):
kwargs["_source"] = DeploymentHandleSource.REPLICA

return cls(**kwargs)


@dataclass(frozen=True)
class DynamicHandleOptionsBase(ABC):
"""Dynamic options for each ServeHandle instance.
These fields can be changed by calling `.options()` on a handle.
"""

method_name: str = "__call__"
multiplexed_model_id: str = ""
stream: bool = False
_request_protocol: str = RequestProtocol.UNDEFINED

def copy_and_update(self, **kwargs) -> "DynamicHandleOptionsBase":
new_kwargs = {}

for f in fields(self):
if f.name not in kwargs or kwargs[f.name] == DEFAULT.VALUE:
new_kwargs[f.name] = getattr(self, f.name)
else:
new_kwargs[f.name] = kwargs[f.name]

return DynamicHandleOptions(**new_kwargs)


@dataclass(frozen=True)
class DynamicHandleOptions(DynamicHandleOptionsBase):
pass
79 changes: 8 additions & 71 deletions python/ray/serve/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,22 @@
import logging
import time
import warnings
from abc import ABC
from dataclasses import dataclass, fields
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Tuple, Union

import ray
from ray._raylet import ObjectRefGenerator
from ray.serve._private.common import (
DeploymentHandleSource,
DeploymentID,
RequestMetadata,
RequestProtocol,
)
from ray.serve._private.common import DeploymentID, RequestMetadata, RequestProtocol
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.default_impl import (
CreateRouterCallable,
create_dynamic_handle_options,
create_init_handle_options,
create_router,
)
from ray.serve._private.handle_options import (
DynamicHandleOptionsBase,
InitHandleOptionsBase,
)
from ray.serve._private.replica_result import ReplicaResult
from ray.serve._private.router import Router
from ray.serve._private.usage import ServeUsageTag
Expand All @@ -40,83 +37,23 @@
logger = logging.getLogger(SERVE_LOGGER_NAME)


@dataclass(frozen=True)
class _InitHandleOptionsBase:
"""Init options for each ServeHandle instance.
These fields can be set by calling `.init()` on a handle before
sending the first request.
"""

_prefer_local_routing: bool = False
_source: DeploymentHandleSource = DeploymentHandleSource.UNKNOWN


@dataclass(frozen=True)
class _InitHandleOptions(_InitHandleOptionsBase):
@classmethod
def create(cls, **kwargs) -> "_InitHandleOptions":
for k in list(kwargs.keys()):
if kwargs[k] == DEFAULT.VALUE:
# Use default value
del kwargs[k]

# Detect replica source for handles
if (
"_source" not in kwargs
and ray.serve.context._get_internal_replica_context() is not None
):
kwargs["_source"] = DeploymentHandleSource.REPLICA

return cls(**kwargs)


@dataclass(frozen=True)
class _DynamicHandleOptionsBase(ABC):
"""Dynamic options for each ServeHandle instance.
These fields can be changed by calling `.options()` on a handle.
"""

method_name: str = "__call__"
multiplexed_model_id: str = ""
stream: bool = False
_request_protocol: str = RequestProtocol.UNDEFINED

def copy_and_update(self, **kwargs) -> "_DynamicHandleOptionsBase":
new_kwargs = {}

for f in fields(self):
if f.name not in kwargs or kwargs[f.name] == DEFAULT.VALUE:
new_kwargs[f.name] = getattr(self, f.name)
else:
new_kwargs[f.name] = kwargs[f.name]

return _DynamicHandleOptions(**new_kwargs)


@dataclass(frozen=True)
class _DynamicHandleOptions(_DynamicHandleOptionsBase):
pass


class _DeploymentHandleBase:
def __init__(
self,
deployment_name: str,
app_name: str,
*,
handle_options: Optional[_DynamicHandleOptionsBase] = None,
handle_options: Optional[DynamicHandleOptionsBase] = None,
_router: Optional[Router] = None,
_create_router: Optional[CreateRouterCallable] = None,
_request_counter: Optional[metrics.Counter] = None,
_recorded_telemetry: bool = False,
):
self.deployment_id = DeploymentID(name=deployment_name, app_name=app_name)
self.handle_options: _DynamicHandleOptionsBase = (
self.handle_options: DynamicHandleOptionsBase = (
handle_options or create_dynamic_handle_options()
)
self.init_options: Optional[_InitHandleOptionsBase] = None
self.init_options: Optional[InitHandleOptionsBase] = None

self.handle_id = get_random_string()
self.request_counter = _request_counter or self._create_request_counter(
Expand Down
14 changes: 7 additions & 7 deletions python/ray/serve/tests/unit/test_handle_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import pytest

from ray.serve._private.common import DeploymentHandleSource, RequestProtocol
from ray.serve._private.handle_options import DynamicHandleOptions, InitHandleOptions
from ray.serve._private.utils import DEFAULT
from ray.serve.handle import _DynamicHandleOptions, _InitHandleOptions


def test_dynamic_handle_options():
default_options = _DynamicHandleOptions()
default_options = DynamicHandleOptions()
assert default_options.method_name == "__call__"
assert default_options.multiplexed_model_id == ""
assert default_options.stream is False
Expand Down Expand Up @@ -62,25 +62,25 @@ def test_dynamic_handle_options():


def test_init_handle_options():
default_options = _InitHandleOptions.create()
default_options = InitHandleOptions.create()
assert default_options._prefer_local_routing is False
assert default_options._source == DeploymentHandleSource.UNKNOWN

default1 = _InitHandleOptions.create(_prefer_local_routing=DEFAULT.VALUE)
default1 = InitHandleOptions.create(_prefer_local_routing=DEFAULT.VALUE)
assert default1._prefer_local_routing is False
assert default1._source == DeploymentHandleSource.UNKNOWN

default2 = _InitHandleOptions.create(_source=DEFAULT.VALUE)
default2 = InitHandleOptions.create(_source=DEFAULT.VALUE)
assert default2._prefer_local_routing is False
assert default2._source == DeploymentHandleSource.UNKNOWN

prefer_local = _InitHandleOptions.create(
prefer_local = InitHandleOptions.create(
_prefer_local_routing=True, _source=DEFAULT.VALUE
)
assert prefer_local._prefer_local_routing is True
assert prefer_local._source == DeploymentHandleSource.UNKNOWN

proxy_options = _InitHandleOptions.create(_source=DeploymentHandleSource.PROXY)
proxy_options = InitHandleOptions.create(_source=DeploymentHandleSource.PROXY)
assert proxy_options._prefer_local_routing is False
assert proxy_options._source == DeploymentHandleSource.PROXY

Expand Down

0 comments on commit 056d596

Please sign in to comment.