diff --git a/ci/lint/pydoclint-baseline.txt b/ci/lint/pydoclint-baseline.txt index 02e90b334242..fc84272857bb 100644 --- a/ci/lint/pydoclint-baseline.txt +++ b/ci/lint/pydoclint-baseline.txt @@ -1453,10 +1453,6 @@ python/ray/serve/_private/proxy_response_generator.py python/ray/serve/_private/proxy_state.py DOC201: Method `ProxyStateManager.get_targets` does not have a return section in docstring -------------------- -python/ray/serve/_private/router.py - DOC101: Method `SingletonThreadRouter.assign_request`: Docstring contains fewer arguments than in function signature. - DOC103: Method `SingletonThreadRouter.assign_request`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**request_kwargs: , *request_args: , request_meta: RequestMetadata]. --------------------- python/ray/serve/_private/storage/kv_store.py DOC201: Method `RayInternalKVStore.put` does not have a return section in docstring DOC201: Method `RayInternalKVStore.delete` does not have a return section in docstring diff --git a/doc/source/serve/api/index.md b/doc/source/serve/api/index.md index 19a589e38c70..7be0c11517cb 100644 --- a/doc/source/serve/api/index.md +++ b/doc/source/serve/api/index.md @@ -171,6 +171,7 @@ See the [model composition guide](serve-model-composition) for how to update cod serve.exceptions.RequestCancelledError serve.exceptions.gRPCStatusError serve.exceptions.DeploymentUnavailableError + serve.exceptions.ReplicaUnavailableError ``` @@ -524,4 +525,4 @@ Content-Type: application/json serve.llm.LLMServer serve.llm.LLMRouter -``` \ No newline at end of file +``` diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index c095a3ee44e5..29cd797791a3 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -820,6 +820,9 @@ class RequestMetadata: request_serialization: str = "cloudpickle" response_serialization: str = "cloudpickle" + # Token for a replica-side slot reserved by choose_replica(). + _reserved_slot_token: Optional[str] = None + @property def is_http_request(self) -> bool: return self._request_protocol == RequestProtocol.HTTP diff --git a/python/ray/serve/_private/local_testing_mode.py b/python/ray/serve/_private/local_testing_mode.py index e1fea75b6141..8bd847ff27d2 100644 --- a/python/ray/serve/_private/local_testing_mode.py +++ b/python/ray/serve/_private/local_testing_mode.py @@ -4,8 +4,19 @@ import logging import queue import time +from contextlib import asynccontextmanager from functools import wraps -from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Coroutine, + Dict, + List, + Optional, + Tuple, + Union, +) import ray from ray import cloudpickle @@ -16,6 +27,7 @@ ) from ray.serve._private.replica import UserCallableWrapper from ray.serve._private.replica_result import ReplicaResult +from ray.serve._private.request_router.replica_wrapper import ReplicaSelection from ray.serve._private.router import Router from ray.serve._private.utils import GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR from ray.serve.deployment import Deployment @@ -341,6 +353,39 @@ def generator_result_callback(item: Any): ) return noop_future + @asynccontextmanager + async def choose_replica( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> AsyncIterator[ReplicaSelection]: + """Choose replica is not supported in local testing mode. + + This is a stub implementation to satisfy the Router ABC interface. + """ + raise NotImplementedError( + "choose_replica is not supported in local testing mode. " + "Use assign_request instead." + ) + yield # Make this a generator for asynccontextmanager + + def dispatch( + self, + selection: ReplicaSelection, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> concurrent.futures.Future[ReplicaResult]: + """Dispatch is not supported in local testing mode. + + This is a stub implementation to satisfy the Router ABC interface. + """ + raise NotImplementedError( + "dispatch is not supported in local testing mode. " + "Use assign_request instead." + ) + async def broadcast( self, request_meta: RequestMetadata, diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index c71b2694ae11..dea36d8f319d 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -1182,13 +1182,36 @@ def __init__( self._direct_ingress_grpc_server_task: Optional[asyncio.Task] = None self._num_queued_requests = 0 + self._reserved_slots: Set[str] = set() @property def max_ongoing_requests(self) -> int: return self._deployment_config.max_ongoing_requests def get_num_ongoing_requests(self) -> int: - return self._metrics_manager.get_num_ongoing_requests() + return self._metrics_manager.get_num_ongoing_requests() + len( + self._reserved_slots + ) + + async def reserve_slot( + self, request_metadata: RequestMetadata, slot_token: str + ) -> Tuple[bool, int]: + """Reserve replica capacity for a future dispatch call.""" + if not self._can_accept_request(request_metadata): + return False, self.get_num_ongoing_requests() + + await self._semaphore.acquire() + self._reserved_slots.add(slot_token) + return True, self.get_num_ongoing_requests() + + def release_slot(self, slot_token: str) -> Tuple[bool, int]: + """Release replica capacity reserved by choose_replica().""" + if slot_token not in self._reserved_slots: + return False, self.get_num_ongoing_requests() + + self._reserved_slots.remove(slot_token) + self._semaphore.release() + return True, self.get_num_ongoing_requests() def get_metadata(self) -> ReplicaMetadata: current_rank = ray.serve.context._get_internal_replica_context().rank @@ -1865,12 +1888,25 @@ def _on_request_failed(self, request_metadata: RequestMetadata, e: Exception): @asynccontextmanager async def _start_request(self, request_metadata: RequestMetadata): - async with self._semaphore: + reserved_slot_token = request_metadata._reserved_slot_token + if reserved_slot_token: + if reserved_slot_token not in self._reserved_slots: + raise RuntimeError( + "Request tried to consume an unknown reserved slot " + f"{reserved_slot_token}." + ) + self._reserved_slots.remove(reserved_slot_token) + else: + await self._semaphore.acquire() + + try: try: self._metrics_manager.inc_num_ongoing_requests(request_metadata) yield finally: self._metrics_manager.dec_num_ongoing_requests(request_metadata) + finally: + self._semaphore.release() async def _drain_ongoing_requests(self): """Wait for any ongoing requests to finish. @@ -2759,6 +2795,16 @@ def get_num_ongoing_requests(self) -> int: """ return self._replica_impl.get_num_ongoing_requests() + async def reserve_slot( + self, request_metadata: RequestMetadata, slot_token: str + ) -> Tuple[bool, int]: + """Reserve capacity for a future choose_replica/dispatch request.""" + return await self._replica_impl.reserve_slot(request_metadata, slot_token) + + def release_slot(self, slot_token: str) -> Tuple[bool, int]: + """Release capacity reserved by choose_replica().""" + return self._replica_impl.release_slot(slot_token) + async def is_allocated(self) -> str: """poke the replica to check whether it's alive. diff --git a/python/ray/serve/_private/request_router/__init__.py b/python/ray/serve/_private/request_router/__init__.py index b4b761681077..cbbccfe9dc4b 100644 --- a/python/ray/serve/_private/request_router/__init__.py +++ b/python/ray/serve/_private/request_router/__init__.py @@ -3,6 +3,7 @@ PowerOfTwoChoicesRequestRouter, ) from ray.serve._private.request_router.replica_wrapper import ( # noqa: F401 + ReplicaSelection, RunningReplica, ) from ray.serve._private.request_router.request_router import ( # noqa: F401 diff --git a/python/ray/serve/_private/request_router/replica_wrapper.py b/python/ray/serve/_private/request_router/replica_wrapper.py index 576e95a73b72..fc529e66bf22 100644 --- a/python/ray/serve/_private/request_router/replica_wrapper.py +++ b/python/ray/serve/_private/request_router/replica_wrapper.py @@ -1,7 +1,8 @@ import asyncio -import logging import pickle +import uuid from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any, Dict, Optional, Set, Tuple import grpc @@ -9,13 +10,13 @@ import ray from ray.actor import ActorHandle from ray.serve._private.common import ( + DeploymentID, ReplicaID, + ReplicaQueueLengthInfo, + RequestMetadata, RunningReplicaInfo, ) -from ray.serve._private.constants import ( - RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH, - SERVE_LOGGER_NAME, -) +from ray.serve._private.constants import RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH from ray.serve._private.replica_result import ( ActorReplicaResult, ReplicaResult, @@ -35,8 +36,6 @@ _is_tracing_enabled, ) -logger = logging.getLogger(SERVE_LOGGER_NAME) - class ReplicaWrapper(ABC): """This is used to abstract away details of the transport layer @@ -201,6 +200,10 @@ def __init__(self, replica_info: RunningReplicaInfo): self._actor_replica_wrapper = ActorReplicaWrapper(self._actor_handle) self._grpc_replica_wrapper = None + # Active local slot reservation tokens for Java replicas. Python replicas + # reserve capacity on the actor-side semaphore. + self._reserved_slots: Set[str] = set() + def update_replica_info(self, replica_info: RunningReplicaInfo) -> None: """Update mutable fields from a new RunningReplicaInfo. @@ -324,3 +327,129 @@ def try_send_request( return wrapper.send_request_java(pr) return wrapper.send_request_python(pr, with_rejection=with_rejection) + + async def reserve_slot( + self, request_metadata: RequestMetadata + ) -> Tuple[str, ReplicaQueueLengthInfo]: + """Reserve a slot on this replica for an upcoming request. + + Returns a unique token that can be used to release the slot later. + This is used in the choose_replica/dispatch pattern to track + reservations that haven't been dispatched yet. + """ + if self._replica_info.is_cross_language: + slot_token = str(uuid.uuid4()) + self._reserved_slots.add(slot_token) + return slot_token, ReplicaQueueLengthInfo( + accepted=True, + num_ongoing_requests=len(self._reserved_slots), + ) + + slot_token = str(uuid.uuid4()) + obj_ref = self._actor_handle.reserve_slot.remote(request_metadata, slot_token) + try: + accepted, num_ongoing_requests = await obj_ref + except asyncio.CancelledError: + ray.cancel(obj_ref) + self._actor_handle.release_slot.remote(slot_token) + raise + + return slot_token, ReplicaQueueLengthInfo( + accepted=accepted, + num_ongoing_requests=num_ongoing_requests, + ) + + async def release_slot(self, slot_token: str) -> ReplicaQueueLengthInfo: + """Release a previously reserved slot. + + This should be called if a request is not dispatched after + reserving a slot (e.g., due to an error or cancellation). + """ + if self._replica_info.is_cross_language: + self._reserved_slots.discard(slot_token) + return ReplicaQueueLengthInfo( + accepted=True, + num_ongoing_requests=len(self._reserved_slots), + ) + + _, num_ongoing_requests = await self._actor_handle.release_slot.remote( + slot_token + ) + return ReplicaQueueLengthInfo( + accepted=True, + num_ongoing_requests=num_ongoing_requests, + ) + + +@dataclass +class ReplicaSelection: + """Represents a selected replica, holding information for dispatch or coordination. + + This class is returned by the choose_replica() context manager. + The slot reservation lifecycle is managed by the context manager. + """ + + # Public, user-accessible fields + replica_id: str + """Unique identifier for the selected replica.""" + + node_ip: str + """IP address of the node running this replica.""" + + port: Optional[int] + """Port number for direct communication (if configured).""" + + node_id: str + """Ray node ID where the replica is running.""" + + availability_zone: Optional[str] + """Cloud availability zone of the replica's node.""" + + # Internal fields (not part of public API) + _replica: RunningReplica + _deployment_id: Optional[DeploymentID] + _request_metadata: RequestMetadata + _method_name: str + _slot_token: str # Token for reserved slot + _dispatched: bool = field( + default=False, init=False + ) # Tracks if dispatch was called + + @property + def address(self) -> str: + """Returns the replica address in host:port format.""" + if self.port: + return f"{self.node_ip}:{self.port}" + return self.node_ip + + def to_dict(self) -> Dict[str, Any]: + """Serialize public fields to a dictionary.""" + return { + "replica_id": self.replica_id, + "node_ip": self.node_ip, + "port": self.port, + "node_id": self.node_id, + "availability_zone": self.availability_zone, + } + + def _mark_dispatched(self) -> None: + """Internal: Mark this selection as dispatched (slot consumed). + + Raises: + RuntimeError: If the selection has already been dispatched. + """ + if self._dispatched: + raise RuntimeError( + f"ReplicaSelection for {self.replica_id} has already been dispatched. " + "Each selection can only be dispatched once." + ) + self._dispatched = True + + async def _release_slot( + self, *, force: bool = False + ) -> Optional[ReplicaQueueLengthInfo]: + """Internal: Release the reserved slot.""" + if self._dispatched and not force: + return None + + return await self._replica.release_slot(self._slot_token) diff --git a/python/ray/serve/_private/request_router/request_router.py b/python/ray/serve/_private/request_router/request_router.py index dd9b5c4f0021..c39a6c61e925 100644 --- a/python/ray/serve/_private/request_router/request_router.py +++ b/python/ray/serve/_private/request_router/request_router.py @@ -850,6 +850,25 @@ def on_send_request(self, replica_id: ReplicaID): self._replica_queue_len_cache.update(replica_id, new_queue_len) self._update_router_queue_len_gauge(replica_id, new_queue_len) + def on_replica_result_finished(self, replica_id: ReplicaID): + """Decrement queue length cache when a request finishes or is cancelled. + + This is used when a reserved slot is released without being dispatched + (e.g., in choose_replica context manager cleanup). + + We cannot rely on on_new_queue_len_info() to correct the cache in this + path. The queue length cache is incremented optimistically when a slot is + reserved, before dispatch happens. If dispatch is never called or fails + before the request reaches the replica, no queue_len_info response is + produced, so the cache would otherwise remain inflated. + """ + if self._use_replica_queue_len_cache: + num_ongoing_requests = self._replica_queue_len_cache.get(replica_id) or 0 + if num_ongoing_requests > 0: + new_queue_len = num_ongoing_requests - 1 + self._replica_queue_len_cache.update(replica_id, new_queue_len) + self._update_router_queue_len_gauge(replica_id, new_queue_len) + def decrement_queue_len_cache(self, replica_id: ReplicaID): """Decrement the queue length cache for a replica. diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index c4bad02b0f18..d4463534214c 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -1,6 +1,7 @@ import asyncio import concurrent.futures import logging +import sys import threading import time import weakref @@ -8,11 +9,12 @@ from asyncio import AbstractEventLoop, ensure_future, futures from collections import defaultdict from collections.abc import MutableMapping -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from dataclasses import replace from functools import lru_cache, partial from typing import ( Any, + AsyncIterator, Callable, Coroutine, DefaultDict, @@ -58,7 +60,10 @@ from ray.serve._private.request_router.pow_2_router import ( PowerOfTwoChoicesRequestRouter, ) -from ray.serve._private.request_router.replica_wrapper import RunningReplica +from ray.serve._private.request_router.replica_wrapper import ( + ReplicaSelection, + RunningReplica, +) from ray.serve._private.tracing_utils import ( create_propagated_context, is_span_recording, @@ -81,6 +86,7 @@ BackPressureError, DeploymentUnavailableError, RayServeException, + ReplicaUnavailableError, ) from ray.types import ObjectRef from ray.util import metrics @@ -154,6 +160,25 @@ def __init__( # so non-atomic read and write operations need to be guarded by # this thread-safe lock. self._queries_lock = threading.Lock() + + # Track reserved slots for choose_replica operations + self._num_reserved_slots = 0 + self._reserved_slots_gauge = metrics.Gauge( + "serve_reserved_slots_active", + description=( + "The current number of reserved slots for choose_replica operations." + ), + tag_keys=("deployment", "application", "handle", "actor_id"), + ) + self._reserved_slots_gauge.set_default_tags( + { + "deployment": deployment_id.name, + "application": deployment_id.app_name, + "handle": self._handle_id, + "actor_id": self._self_actor_id, + } + ) + self._reserved_slots_gauge.set(0) # Regularly aggregate and push autoscaling metrics to controller self.metrics_pusher = MetricsPusher() self.metrics_store = InMemoryMetricsStore() @@ -356,6 +381,14 @@ def dec_num_queued_requests(self): if not self._cached_metrics_enabled: self.num_queued_requests_gauge.set(self.num_queued_requests) + def inc_reserved_slots(self): + self._num_reserved_slots += 1 + self._reserved_slots_gauge.set(self._num_reserved_slots) + + def dec_reserved_slots(self): + self._num_reserved_slots -= 1 + self._reserved_slots_gauge.set(self._num_reserved_slots) + def inc_num_running_requests_for_replica(self, replica_id: ReplicaID): with self._queries_lock: self.num_requests_sent_to_replicas[replica_id] += 1 @@ -532,6 +565,26 @@ def assign_request( ) -> concurrent.futures.Future[ReplicaResult]: pass + @abstractmethod + @asynccontextmanager + async def choose_replica( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> AsyncIterator[ReplicaSelection]: + pass + + @abstractmethod + def dispatch( + self, + selection: ReplicaSelection, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> concurrent.futures.Future[ReplicaResult]: + pass + @abstractmethod async def broadcast( self, @@ -955,6 +1008,7 @@ async def _route_and_send_request_once( result: Optional[ReplicaResult] = None replica: Optional[RunningReplica] = None callback_registered = False + queue_len_incremented = False try: # Resolve request arguments BEFORE incrementing queued requests. # This ensures that queue metrics reflect actual pending work, @@ -982,6 +1036,7 @@ async def _route_and_send_request_once( result = replica.try_send_request(pr, with_rejection=with_rejection) # Proactively update the queue length cache. self.request_router.on_send_request(replica.replica_id) + queue_len_incremented = True # Keep track of requests that have been sent out to replicas if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: @@ -1076,6 +1131,11 @@ async def _route_and_send_request_once( self.request_router.on_request_completed( replica.replica_id, pr.metadata.internal_request_id ) + # Only decrement if on_send_request was called (i.e., try_send_request + # succeeded and incremented the cache). If try_send_request raised error, + # on_send_request was never called so there is no +1 to undo. + if queue_len_incremented: + self.request_router.on_replica_result_finished(replica.replica_id) return None @@ -1196,6 +1256,216 @@ async def assign_request( if exc: set_span_exception(exc, escaped=True) + @asynccontextmanager + async def choose_replica( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> AsyncIterator[ReplicaSelection]: + """Execute routing and reserve a slot, with automatic cleanup. + + This method: + 1. Checks deployment availability + 2. Checks backpressure (max_queued_requests) + 3. Increments serve_num_router_requests metric + 4. Selects a replica and reserves a slot + 5. Increments serve_reserved_slots_active metric + 6. Yields the ReplicaSelection + 7. On exit, releases the slot if not dispatched + """ + if not self._deployment_available: + raise DeploymentUnavailableError(self.deployment_id) + + # Wait for the router to be initialized before sending the request. + await self._request_router_initialized.wait() + + with self._metrics_manager.wrap_request_assignment(request_meta): + pr = PendingRequest( + args=list(request_args), + kwargs=request_kwargs, + metadata=request_meta, + ) + + # Resolve request arguments BEFORE incrementing queued requests. + # This ensures that queue metrics reflect actual pending work, + # not time spent waiting for upstream DeploymentResponse arguments. + # See: https://github.com/ray-project/ray/issues/60624 + if not pr.resolved: + await self._resolve_request_arguments(pr) + + is_retry = False + while True: + num_curr_replicas = len(self.request_router.curr_replicas) + with self._metrics_manager.wrap_queued_request( + is_retry=is_retry, num_curr_replicas=num_curr_replicas + ): + replica = await self.request_router._choose_replica_for_request( + pr, is_retry=is_retry + ) + + # Reserve capacity on the replica actor. This must happen on + # the replica, not just in the router cache, so dispatch can + # send without the rejection protocol. + try: + slot_token, queue_info = await replica.reserve_slot( + request_meta + ) + except ActorDiedError as e: + self._handle_actor_died_error( + replica.replica_id, replica.actor_id, e + ) + is_retry = True + continue + except ActorUnavailableError: + self.request_router.on_replica_actor_unavailable( + replica.replica_id + ) + logger.warning( + f"{replica.replica_id} is temporarily unavailable." + ) + is_retry = True + continue + + self.request_router.on_new_queue_len_info( + replica.replica_id, queue_info + ) + if queue_info.accepted: + break + + is_retry = True + + # Increment reserved slots metric (after queue metric is decremented) + self._metrics_manager.inc_reserved_slots() + + selection = ReplicaSelection( + replica_id=replica.replica_id.unique_id, + node_ip=replica._replica_info.node_ip, + port=replica._replica_info.port, + node_id=replica.node_id, + availability_zone=replica.availability_zone, + _replica=replica, + _deployment_id=None, # Injected by DeploymentHandle for dispatch-time validation. + _request_metadata=request_meta, + _method_name=request_meta.call_method, + _slot_token=slot_token, + ) + + try: + yield selection + finally: + queue_info = await selection._release_slot() + if queue_info is not None: + self.request_router.on_new_queue_len_info( + replica.replica_id, queue_info + ) + + # Decrement reserved slots metric + self._metrics_manager.dec_reserved_slots() + + async def dispatch( + self, + selection: ReplicaSelection, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> ReplicaResult: + """Dispatch to a specific replica, consuming the reserved slot. + + Args: + selection: The replica selection from choose_replica(). + request_meta: Request metadata. + *request_args: Request positional arguments. + **request_kwargs: Request keyword arguments. + + Returns: + ReplicaResult for the dispatched request. + + Raises: + RuntimeError: If the selection has already been dispatched. + ReplicaUnavailableError: If the replica is no longer available. + """ + selection._mark_dispatched() + return await self._dispatch_to_marked_selection( + selection, request_meta, *request_args, **request_kwargs + ) + + async def _dispatch_to_marked_selection( + self, + selection: ReplicaSelection, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> ReplicaResult: + """Dispatch a selection that has already been consumed by dispatch().""" + # Verify replica is still available + replica = selection._replica + pr = PendingRequest( + args=list(request_args), + kwargs=request_kwargs, + metadata=replace(request_meta, _reserved_slot_token=selection._slot_token), + ) + + # Send the request without rejection since we already reserved a slot + # The slot reservation guarantees that the replica will accept this request + try: + if replica.replica_id not in self.request_router.curr_replicas: + raise ReplicaUnavailableError( + f"Replica {selection.replica_id} is no longer available" + ) + if not pr.resolved: + await self._resolve_request_arguments(pr) + result = replica.try_send_request(pr, with_rejection=False) + except BaseException: + queue_info = await selection._release_slot(force=True) + if queue_info is not None: + self.request_router.on_new_queue_len_info( + replica.replica_id, queue_info + ) + raise + + # Keep track of requests that have been sent out to replicas + if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: + self._metrics_manager.inc_num_running_requests_for_replica( + replica.replica_id + ) + + # Always register callback to notify router when request completes + # (needed for token release in queue-based routing, metrics tracking, etc.) + callback = partial( + self._process_finished_request, + replica.replica_id, + pr.metadata.internal_request_id, + replica.actor_id, + ) + result.add_done_callback( + lambda _, cb=callback: self._event_loop.call_soon_threadsafe(cb, _) + ) + result.add_done_callback( + lambda _: self._event_loop.call_soon_threadsafe( + self.request_router.decrement_queue_len_cache, + replica.replica_id, + ) + ) + result.add_done_callback( + lambda _: self._event_loop.call_soon_threadsafe( + lambda: self._event_loop.create_task( + self._release_slot_if_still_reserved(selection) + ) + ) + ) + + return result + + async def _release_slot_if_still_reserved( + self, selection: ReplicaSelection + ) -> None: + """Best-effort cleanup if a dispatched request was cancelled before starting.""" + try: + await selection._release_slot(force=True) + except Exception: + logger.debug("Failed to release reserved replica slot.", exc_info=True) + async def broadcast( self, request_meta: RequestMetadata, @@ -1390,10 +1660,94 @@ def assign_request( asyncio event loop thread. It returns a `concurrent.futures.Future` that can be awaited or queried from the calling thread. + Args: + request_meta: Metadata for the request. + *request_args: Positional arguments for the request. + **request_kwargs: Keyword arguments for the request. + Returns: A concurrent.futures.Future resolving to the ReplicaResult representing the assigned request. """ + return self._wrap_asyncio_call_in_future( + self._asyncio_router.assign_request( + request_meta, *request_args, **request_kwargs + ) + ) + + @asynccontextmanager + async def choose_replica( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> AsyncIterator[ReplicaSelection]: + """Bridge async context manager to router event loop. + + This ensures choose_replica runs on the singleton router loop, + maintaining thread safety for all state modifications. + """ + # Enter context on router loop + async def enter_context(): + cm = self._asyncio_router.choose_replica( + request_meta, *request_args, **request_kwargs + ) + selection = await cm.__aenter__() + return selection, cm + + future = asyncio.run_coroutine_threadsafe(enter_context(), self._asyncio_loop) + selection, context_manager = await asyncio.wrap_future(future) + + try: + yield selection + finally: + # Exit context on router loop + async def exit_context(exc_type, exc_val, exc_tb): + return await context_manager.__aexit__(exc_type, exc_val, exc_tb) + + exc_info = sys.exc_info() + future = asyncio.run_coroutine_threadsafe( + exit_context(*exc_info), self._asyncio_loop + ) + await asyncio.wrap_future(future) + + def dispatch( + self, + selection: ReplicaSelection, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> concurrent.futures.Future[ReplicaResult]: + """Dispatch request to a previously selected replica.""" + try: + selection._mark_dispatched() + except Exception as exc: + future = concurrent.futures.Future() + future.set_exception(exc) + return future + + return self._wrap_asyncio_call_in_future( + self._asyncio_router._dispatch_to_marked_selection( + selection, request_meta, *request_args, **request_kwargs + ) + ) + + def _wrap_asyncio_call_in_future( + self, + coro: Coroutine, + ) -> concurrent.futures.Future[ReplicaResult]: + """Wrap an async call in a concurrent.futures.Future for cross-thread execution. + + This is a helper method to execute AsyncioRouter's async methods on the dedicated asyncio event loop thread. + + Args: + coro: The coroutine to execute (e.g., _asyncio_router.assign_request(...)) + + Returns: + A concurrent.futures.Future that resolves to the ReplicaResult. + """ + # Extract operation name from coroutine for logging + operation_name = coro.__name__ if hasattr(coro, "__name__") else "operation" def asyncio_future_callback( asyncio_future: asyncio.Future, concurrent_future: concurrent.futures.Future @@ -1406,7 +1760,6 @@ def asyncio_future_callback( asyncio_future didn't see the cancellation event in time. Think of it like a second line of defense for cancellation of replica results. """ - # Check if the cancellation originated from the concurrent.futures.Future if ( concurrent_future.cancelled() and not asyncio_future.cancelled() @@ -1414,27 +1767,21 @@ def asyncio_future_callback( ): result: ReplicaResult = asyncio_future.result() logger.info( - "Asyncio task completed despite cancellation attempt. " - "Attempting to cancel the request that was assigned to a replica." + f"Asyncio task completed despite cancellation attempt during {operation_name}. " + "Attempting to cancel the request." ) result.cancel() concurrent_future = concurrent.futures.Future() def create_task_and_setup(): - task = self._asyncio_loop.create_task( - self._asyncio_router.assign_request( - request_meta, *request_args, **request_kwargs - ) - ) + task = self._asyncio_loop.create_task(coro) - # Set up your cancellation callback task.add_done_callback( lambda _: asyncio_future_callback(_, concurrent_future) ) try: - # chain the two futures to handle direction channel of cancellation futures._chain_future( ensure_future(task, loop=self._asyncio_loop), concurrent_future ) @@ -1445,7 +1792,6 @@ def create_task_and_setup(): concurrent_future.set_exception(exc) raise - # Schedule on the event loop thread self._asyncio_loop.call_soon_threadsafe(create_task_and_setup) return concurrent_future @@ -1578,6 +1924,43 @@ def assign_request( ), ) + @asynccontextmanager + async def choose_replica( + self, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> AsyncIterator[ReplicaSelection]: + """Delegate to AsyncioRouter's choose_replica.""" + async with self._asyncio_router.choose_replica( + request_meta, *request_args, **request_kwargs + ) as selection: + yield selection + + def dispatch( + self, + selection: ReplicaSelection, + request_meta: RequestMetadata, + *request_args, + **request_kwargs, + ) -> asyncio.Future[ReplicaResult]: + """Dispatch request to a previously selected replica. + + Returns an asyncio.Future wrapping the async dispatch call. + """ + try: + selection._mark_dispatched() + except Exception as exc: + future = self._asyncio_loop.create_future() + future.set_exception(exc) + return future + + return self._asyncio_loop.create_task( + self._asyncio_router._dispatch_to_marked_selection( + selection, request_meta, *request_args, **request_kwargs + ) + ) + async def broadcast( self, request_meta: RequestMetadata, diff --git a/python/ray/serve/exceptions.py b/python/ray/serve/exceptions.py index 6e3d201655e4..51671c7fbd14 100644 --- a/python/ray/serve/exceptions.py +++ b/python/ray/serve/exceptions.py @@ -101,3 +101,10 @@ def __init__(self, deployment_id: DeploymentID): @property def message(self) -> str: return f"{self._deployment_id} is unavailable because it failed to deploy." + + +@PublicAPI(stability="alpha") +class ReplicaUnavailableError(RayServeException): + """Raised when the selected replica is no longer available.""" + + pass diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 6b274ffdea47..4c299710eae5 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -3,6 +3,7 @@ import logging import time import warnings +from contextlib import asynccontextmanager from typing import ( Any, AsyncIterator, @@ -19,6 +20,8 @@ cast, ) +from typing_extensions import AsyncContextManager + import ray from ray import serve from ray._raylet import ObjectRefGenerator # type: ignore[attr-defined] @@ -40,6 +43,7 @@ InitHandleOptionsBase, ) from ray.serve._private.replica_result import ReplicaResult +from ray.serve._private.request_router.replica_wrapper import ReplicaSelection from ray.serve._private.router import Router from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import ( @@ -182,6 +186,24 @@ def _is_router_running_in_separate_loop(self) -> bool: return False return self.init_options._run_router_in_separate_loop + def _init_router(self) -> Router: + if not self.is_initialized: + self._init() + + if self._router is None: + raise RuntimeError("Router is not initialized") + + return self._router + + def _init_router_and_get_metadata(self) -> Tuple[Router, RequestMetadata]: + router = self._init_router() + + metadata = serve._private.default_impl.get_request_metadata( + self.init_options, self.handle_options + ) + + return router, metadata + def _options( self, _prefer_local_routing=DEFAULT.VALUE, **kwargs ) -> "DeploymentHandle[T]": @@ -216,23 +238,58 @@ def _remote( args: Tuple[Any], kwargs: Dict[str, Any], ) -> Tuple[concurrent.futures.Future, RequestMetadata]: - if not self.is_initialized: - self._init() + router, metadata = self._init_router_and_get_metadata() - metadata = serve._private.default_impl.get_request_metadata( - self.init_options, self.handle_options + self.request_counter.inc( + tags={ + "route": metadata.route, + "application": metadata.app_name, + } ) + return router.assign_request(metadata, *args, **kwargs), metadata + @asynccontextmanager + async def _choose_replica( + self, + args: Tuple[Any], + kwargs: Dict[str, Any], + ) -> AsyncIterator[ReplicaSelection]: + """Execute the request router to select a replica without dispatching.""" + router, metadata = self._init_router_and_get_metadata() self.request_counter.inc( tags={ "route": metadata.route, "application": metadata.app_name, } ) - if self._router is None: - raise RuntimeError("Router is not initialized") - return self._router.assign_request(metadata, *args, **kwargs), metadata + # Call the router's choose_replica and inject the deployment handle + async with router.choose_replica(metadata, *args, **kwargs) as selection: + # Record the owning deployment for dispatch-time validation. + selection._deployment_id = self.deployment_id + yield selection + + def _dispatch( + self, + selection: ReplicaSelection, + args: Tuple[Any], + kwargs: Dict[str, Any], + ) -> Tuple[concurrent.futures.Future, RequestMetadata]: + """Dispatch a request to a previously selected replica.""" + # Validate that the selection belongs to the same deployment + if ( + selection._deployment_id is not None + and selection._deployment_id != self.deployment_id + ): + raise ValueError( + f"Cannot dispatch a selection created for a different deployment. " + f"This handle is for {self.deployment_id}, but the selection was created " + f"for {selection._deployment_id}." + ) + + metadata = selection._request_metadata + router = self._init_router() + return router.dispatch(selection, metadata, *args, **kwargs), metadata def options( self, @@ -1145,6 +1202,73 @@ def remote( _is_router_running_in_separate_loop=self._is_router_running_in_separate_loop(), ) + def choose_replica( + self, + *args: Any, + **kwargs: Any, + ) -> AsyncContextManager[ReplicaSelection]: + """Execute the request router to select a replica without dispatching. + + This method runs the full routing logic (load balancing, locality awareness, + queue length probing, etc.) and returns an async context manager that yields + a ReplicaSelection. A request slot is reserved on the selected replica, + guaranteeing that dispatch will succeed. + + The context manager ensures proper cleanup: + - If dispatch() is called, the slot is consumed normally. + - If the context exits without dispatch (e.g., exception, early return), the slot is released. + + Args: + *args: Arguments that may influence routing decisions + **kwargs: Keyword arguments that may influence routing decisions. + + Returns: + AsyncContextManager[ReplicaSelection] - must be used with async with. + """ + return self._choose_replica(args, kwargs) + + def dispatch( + self, + selection: ReplicaSelection, + *args: Any, + **kwargs: Any, + ) -> Union[DeploymentResponse[Any], DeploymentResponseGenerator[Any]]: + """Dispatch a request to a previously selected replica. + + By default, the result is a `DeploymentResponse` that can be awaited to fetch + the result of the call. Like `.remote()`, `DeploymentResponse` objects can be + passed as arguments for deployment composition. + + If `handle.options(stream=True)` is set and a generator method is called, this + returns a `DeploymentResponseGenerator` instead. + If the selected replica becomes unavailable before dispatch executes, + ``ReplicaUnavailableError`` is propagated from the router dispatch path. + + Args: + selection: A ReplicaSelection from choose_replica() context manager. + *args: The request arguments to send to the replica. + **kwargs: The request keyword arguments to send to the replica. + + Returns: + DeploymentResponse or DeploymentResponseGenerator (if streaming). + + Raises: + ValueError: If selection was created by a different DeploymentHandle. + """ + future, request_metadata = self._dispatch(selection, args, kwargs) + if self.handle_options.stream: + return DeploymentResponseGenerator( + future, + request_metadata, + _is_router_running_in_separate_loop=self._is_router_running_in_separate_loop(), + ) + else: + return DeploymentResponse( + future, + request_metadata, + _is_router_running_in_separate_loop=self._is_router_running_in_separate_loop(), + ) + def broadcast( self, method_name: str, diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index fb1bd2cac9be..4ed3d70cf8ed 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -6,12 +6,17 @@ import ray from ray import serve from ray._common.test_utils import SignalActor, wait_for_condition -from ray.serve._private.common import OBJ_REF_NOT_SUPPORTED_ERROR +from ray.serve._private.common import ( + OBJ_REF_NOT_SUPPORTED_ERROR, + DeploymentID, + RequestMetadata, +) from ray.serve._private.replica_result import ( ActorReplicaResult, ReplicaResult, gRPCReplicaResult, ) +from ray.serve._private.request_router.replica_wrapper import ReplicaSelection from ray.serve.handle import DeploymentHandle from ray.serve.tests.conftest import * # noqa from ray.serve.tests.conftest import _shared_serve_instance # noqa @@ -123,6 +128,29 @@ def __call__(self, inp1: str, inp2: str): ) +def test_dispatch_rejects_selection_from_different_deployment(): + handle = DeploymentHandle("deployment-a", "app") + selection = ReplicaSelection( + replica_id="replica-1", + node_ip="127.0.0.1", + port=None, + node_id="node-1", + availability_zone=None, + _replica=object(), + _deployment_id=DeploymentID(name="deployment-b", app_name="app"), + _request_metadata=RequestMetadata( + request_id="request-id", + internal_request_id="internal-request-id", + call_method="__call__", + ), + _method_name="__call__", + _slot_token="slot-1", + ) + + with pytest.raises(ValueError, match="different deployment"): + handle.dispatch(selection) + + @pytest.mark.asyncio @pytest.mark.timeout(30) async def test_non_grpc_exception_no_self_cause(serve_instance): diff --git a/python/ray/serve/tests/test_handle_1.py b/python/ray/serve/tests/test_handle_1.py index 20cbc7442976..e78b14599aea 100644 --- a/python/ray/serve/tests/test_handle_1.py +++ b/python/ray/serve/tests/test_handle_1.py @@ -2,6 +2,7 @@ import concurrent.futures import sys import threading +from contextlib import AsyncExitStack from typing import Any import pytest @@ -13,6 +14,7 @@ RAY_SERVE_FORCE_LOCAL_TESTING_MODE, SERVE_DEFAULT_APP_NAME, ) +from ray.serve.api import get_replica_context from ray.serve.exceptions import RayServeException from ray.serve.handle import DeploymentHandle @@ -339,5 +341,238 @@ async def __call__(self): assert h.remote().result(timeout_s=10) == ("((r1))", "((r1))") +@pytest.mark.asyncio +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support choose_replica/dispatch", +) +async def test_choose_replica_and_dispatch_single(serve_instance): + """Test choose_replica + dispatch for simple single selection pattern.""" + + @serve.deployment(num_replicas=2) + class Backend: + def process(self, msg: str): + replica_id = get_replica_context().replica_id.unique_id + return {"actual_replica_id": replica_id, "response": msg} + + @serve.deployment + class SimpleProxy: + def __init__(self, backend: DeploymentHandle): + self.backend = backend + + async def handle_request(self, request: str): + # Context manager ensures slot is released if dispatch fails or is skipped + async with self.backend.process.choose_replica(request) as selection: + assert selection.replica_id is not None + assert selection.node_ip is not None + + # Dispatch to the selected replica + response = await self.backend.process.dispatch(selection, request) + + # Return both the selection and the response for verification + return {"selected_replica_id": selection.replica_id, **response} + + h = serve.run(SimpleProxy.bind(Backend.bind())) + result = await h.handle_request.remote("test_message") + + # Verify the result contains the message + assert result["response"] == "test_message" + + # Verify that dispatch sent the request to the replica we selected + assert result["actual_replica_id"] == result["selected_replica_id"], ( + f"dispatch sent request to wrong replica: " + f"selected {result['selected_replica_id']}, but got response from {result['actual_replica_id']}" + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support choose_replica/dispatch", +) +async def test_choose_replica_early_return_releases_slot(serve_instance): + """Early return from choose_replica should release the reserved slot.""" + + @serve.deployment(num_replicas=1, max_ongoing_requests=1) + class Backend: + def process(self, msg: str): + return msg + + @serve.deployment + class Proxy: + def __init__(self, backend: DeploymentHandle): + self.backend = backend + + async def handle_request(self, request: str): + async with self.backend.process.choose_replica(request): + pass + + # Ensure the second choose_replica request work + async with self.backend.process.choose_replica(request) as selection: + response = await asyncio.wait_for( + self.backend.process.dispatch(selection, request), timeout=2 + ) + return response + + h = serve.run(Proxy.bind(Backend.bind())) + assert await h.handle_request.remote("test_message") == "test_message" + + +@pytest.mark.asyncio +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support choose_replica/dispatch", +) +async def test_choose_replica_exception_releases_slot(serve_instance): + """Exception in choose_replica context should release the reserved slot.""" + + @serve.deployment(num_replicas=1, max_ongoing_requests=1) + class Backend: + def process(self, msg: str): + return msg + + @serve.deployment + class Proxy: + def __init__(self, backend: DeploymentHandle): + self.backend = backend + + async def handle_request(self, request: str): + try: + async with self.backend.process.choose_replica(request): + raise RuntimeError("test exception") + except RuntimeError: + pass + + # Ensure the second choose_replica request work + async with self.backend.process.choose_replica(request) as selection: + response = await asyncio.wait_for( + self.backend.process.dispatch(selection, request), timeout=2 + ) + return response + + h = serve.run(Proxy.bind(Backend.bind())) + assert await h.handle_request.remote("test_message") == "test_message" + + +@pytest.mark.asyncio +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support choose_replica/dispatch", +) +async def test_choose_replica_and_dispatch_streaming(serve_instance): + """Test choose_replica + dispatch with handle.options(stream=True).""" + + @serve.deployment(num_replicas=2) + class Backend: + async def stream(self, msg: str): + replica_id = get_replica_context().replica_id.unique_id + for i in range(3): + yield {"actual_replica_id": replica_id, "chunk": f"{msg}-{i}"} + + @serve.deployment + class StreamingProxy: + def __init__(self, backend: DeploymentHandle): + self.backend_stream = backend.stream.options(stream=True) + + async def handle_request(self, request: str): + async with self.backend_stream.choose_replica(request) as selection: + gen = self.backend_stream.dispatch(selection, request) + chunks = [item async for item in gen] + return {"selected_replica_id": selection.replica_id, "chunks": chunks} + + h = serve.run(StreamingProxy.bind(Backend.bind())) + result = await h.handle_request.remote("stream_test") + + assert [item["chunk"] for item in result["chunks"]] == [ + "stream_test-0", + "stream_test-1", + "stream_test-2", + ] + assert all( + item["actual_replica_id"] == result["selected_replica_id"] + for item in result["chunks"] + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + RAY_SERVE_FORCE_LOCAL_TESTING_MODE, + reason="local_testing_mode doesn't support choose_replica/dispatch", +) +async def test_choose_replica_and_dispatch_parallel(serve_instance): + """Test parallel selection pattern (e.g., PD proxy) using AsyncExitStack.""" + + @serve.deployment(num_replicas=2) + class PrefillServer: + def chat(self, msg: str): + replica_id = get_replica_context().replica_id.unique_id + return {"actual_replica_id": replica_id, "response": msg} + + @serve.deployment(num_replicas=2) + class DecodeServer: + def chat(self, msg: str): + replica_id = get_replica_context().replica_id.unique_id + return {"actual_replica_id": replica_id, "response": msg} + + @serve.deployment + class PDProxy: + def __init__( + self, + prefill_server: DeploymentHandle, + decode_server: DeploymentHandle, + ): + self.prefill = prefill_server + self.decode = decode_server + + async def handle_request(self, request: str): + # Use AsyncExitStack to manage multiple context managers in parallel + async with AsyncExitStack() as stack: + # Select and RESERVE replicas from BOTH deployments in parallel + p_selection, d_selection = await asyncio.gather( + stack.enter_async_context(self.prefill.chat.choose_replica()), + stack.enter_async_context(self.decode.chat.choose_replica()), + ) + + p_msg = f"prefill:{request}" + d_msg = f"decode:{request}" + + # Dispatch to both selected replicas + p_result, d_result = await asyncio.gather( + self.prefill.chat.dispatch(p_selection, p_msg), + self.decode.chat.dispatch(d_selection, d_msg), + ) + return { + "prefill": { + "selected_replica_id": p_selection.replica_id, + **p_result, + }, + "decode": { + "selected_replica_id": d_selection.replica_id, + **d_result, + }, + } + + h = serve.run(PDProxy.bind(PrefillServer.bind(), DecodeServer.bind())) + result = await h.handle_request.remote("test_parallel") + + assert result["prefill"]["response"] == "prefill:test_parallel" + assert result["decode"]["response"] == "decode:test_parallel" + + # Verify that dispatch sent the request to the replica we selected + assert ( + result["prefill"]["actual_replica_id"] + == result["prefill"]["selected_replica_id"] + ), ( + f"dispatch sent request to wrong replica for prefill: " + f"selected {result['prefill']['selected_replica_id']}, but got response from {result['prefill']['actual_replica_id']}" + ) + assert ( + result["decode"]["actual_replica_id"] == result["decode"]["selected_replica_id"] + ), ( + f"dispatch sent request to wrong replica for decode: " + f"selected {result['decode']['selected_replica_id']}, but got response from {result['decode']['actual_replica_id']}" + ) + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/unit/test_pow_2_request_router.py b/python/ray/serve/tests/unit/test_pow_2_request_router.py index c3855dc6d19a..0cfdcbbf2605 100644 --- a/python/ray/serve/tests/unit/test_pow_2_request_router.py +++ b/python/ray/serve/tests/unit/test_pow_2_request_router.py @@ -63,6 +63,7 @@ def __init__( self._availability_zone = availability_zone self._queue_len = 0 self._max_ongoing_requests = max_ongoing_requests + self._reserved_slots: Set[str] = set() self._has_queue_len_response = asyncio.Event() self._reset_after_response = reset_after_response self._model_ids = model_ids or set() diff --git a/python/ray/serve/tests/unit/test_router.py b/python/ray/serve/tests/unit/test_router.py index e624bfa27d6d..4c4040f2a0cd 100644 --- a/python/ray/serve/tests/unit/test_router.py +++ b/python/ray/serve/tests/unit/test_router.py @@ -4,6 +4,7 @@ import sys import threading from collections import defaultdict +from dataclasses import replace from typing import Callable, Dict, List, Optional, Set, Tuple from unittest.mock import Mock, patch @@ -26,6 +27,7 @@ RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE, RAY_SERVE_METRICS_EXPORT_INTERVAL_MS, ) +from ray.serve._private.replica import Replica as ServeReplica from ray.serve._private.replica_result import ReplicaResult from ray.serve._private.request_router import ( PendingRequest, @@ -36,13 +38,22 @@ from ray.serve._private.router import ( QUEUED_REQUESTS_KEY, AsyncioRouter, + CurrentLoopRouter, RouterMetricsManager, SingletonThreadRouter, ) from ray.serve._private.test_utils import FakeCounter, FakeGauge, MockTimer -from ray.serve._private.utils import decompress_metric_report, get_random_string +from ray.serve._private.utils import ( + Semaphore, + decompress_metric_report, + get_random_string, +) from ray.serve.config import AutoscalingConfig, RequestRouterConfig -from ray.serve.exceptions import BackPressureError, DeploymentUnavailableError +from ray.serve.exceptions import ( + BackPressureError, + DeploymentUnavailableError, + ReplicaUnavailableError, +) class FakeReplicaResult(ReplicaResult): @@ -103,18 +114,43 @@ def __init__( queue_len_info: Optional[ReplicaQueueLengthInfo] = None, is_cross_language: bool = False, error: Optional[Exception] = None, + node_id: str = "fake-node-id", + availability_zone: Optional[str] = None, + node_ip: str = "127.0.0.1", + port: Optional[int] = 8000, actor_id: Optional[ray.ActorID] = None, ): self._replica_id = replica_id self._is_cross_language = is_cross_language self._queue_len_info = queue_len_info self._error = error + self._reserved_slots: Set[str] = set() + self._slot_counter = 0 + self._requests_sent = [] # Track all requests sent to this replica + self._reject_reservation = False + self._reservation_queue_len = 1 + + # Create a minimal _replica_info object to satisfy router.py requirements + self._replica_info = Mock() + self._replica_info.node_id = node_id + self._replica_info.availability_zone = availability_zone + self._replica_info.node_ip = node_ip + self._replica_info.port = port + self._replica_info.replica_id = replica_id self._actor_id = actor_id @property def replica_id(self) -> ReplicaID: return self._replica_id + @property + def node_id(self) -> str: + return self._replica_info.node_id + + @property + def availability_zone(self) -> Optional[str]: + return self._replica_info.availability_zone + @property def actor_id(self) -> Optional[ray.ActorID]: return self._actor_id @@ -126,6 +162,44 @@ def is_cross_language(self) -> bool: def get_queue_len(self, *, deadline_s: float) -> int: raise NotImplementedError + async def reserve_slot( + self, request_metadata: RequestMetadata + ) -> Tuple[str, ReplicaQueueLengthInfo]: + """Reserve a slot and return a token.""" + if self._reject_reservation: + return "", ReplicaQueueLengthInfo( + accepted=False, + num_ongoing_requests=self._reservation_queue_len, + ) + + self._slot_counter += 1 + token = f"slot-token-{self._slot_counter}" + self._reserved_slots.add(token) + return token, ReplicaQueueLengthInfo( + accepted=True, + num_ongoing_requests=len(self._reserved_slots), + ) + + async def release_slot(self, slot_token: str) -> ReplicaQueueLengthInfo: + """Release a reserved slot.""" + self._reserved_slots.discard(slot_token) + return ReplicaQueueLengthInfo( + accepted=True, + num_ongoing_requests=len(self._reserved_slots), + ) + + def send_request_with_slot( + self, pr: PendingRequest, slot_token: str + ) -> FakeReplicaResult: + """Send request using a reserved slot.""" + assert slot_token in self._reserved_slots, f"Invalid slot token: {slot_token}" + + # Create result same way as try_send_request + if pr.metadata.is_streaming: + return FakeReplicaResult(self._replica_id, is_generator_object=True) + else: + return FakeReplicaResult(self._replica_id, is_generator_object=False) + def try_send_request( self, pr: PendingRequest, with_rejection: bool ) -> FakeReplicaResult: @@ -134,6 +208,17 @@ def try_send_request( if self._error: raise self._error + # Track the request + if pr.metadata._reserved_slot_token: + self._reserved_slots.discard(pr.metadata._reserved_slot_token) + + self._requests_sent.append( + { + "request_id": pr.metadata.request_id, + "with_rejection": with_rejection, + } + ) + if with_rejection: assert ( not self.is_cross_language @@ -209,6 +294,15 @@ def on_send_request(self, replica_id: ReplicaID): num_ongoing_requests = self._replica_queue_len_cache.get(replica_id) or 0 self._replica_queue_len_cache.update(replica_id, num_ongoing_requests + 1) + def on_replica_result_finished(self, replica_id: ReplicaID): + """Decrement queue length cache when a request finishes or is cancelled.""" + if self._use_queue_len_cache: + num_ongoing_requests = self._replica_queue_len_cache.get(replica_id) or 0 + if num_ongoing_requests > 0: + self._replica_queue_len_cache.update( + replica_id, num_ongoing_requests - 1 + ) + def on_replica_actor_unavailable(self, replica_id: ReplicaID): self._replica_queue_len_cache.invalidate_key(replica_id) @@ -305,6 +399,52 @@ def dummy_request_metadata(is_streaming: bool = False) -> RequestMetadata: ) +class FakeReplicaMetricsManager: + def __init__(self): + self.num_ongoing_requests = 0 + + def get_num_ongoing_requests(self) -> int: + return self.num_ongoing_requests + + def inc_num_ongoing_requests(self, request_metadata: RequestMetadata): + self.num_ongoing_requests += 1 + + def dec_num_ongoing_requests(self, request_metadata: RequestMetadata): + self.num_ongoing_requests -= 1 + + +@pytest.mark.asyncio +class TestReplicaSlotReservation: + async def test_reserved_slot_counts_against_replica_capacity(self): + replica = ServeReplica.__new__(ServeReplica) + replica._deployment_config = Mock(max_ongoing_requests=1) + replica._metrics_manager = FakeReplicaMetricsManager() + replica._reserved_slots = set() + replica._semaphore = Semaphore(lambda: replica.max_ongoing_requests) + + request_metadata = dummy_request_metadata() + + slot_token = "slot-token-1" + accepted, queue_len = await replica.reserve_slot(request_metadata, slot_token) + assert accepted + assert slot_token in replica._reserved_slots + assert queue_len == 1 + assert replica.get_num_ongoing_requests() == 1 + + accepted, queue_len = await replica.reserve_slot( + request_metadata, "slot-token-2" + ) + assert not accepted + assert queue_len == 1 + + dispatch_metadata = replace(request_metadata, _reserved_slot_token=slot_token) + async with replica._start_request(dispatch_metadata): + assert slot_token not in replica._reserved_slots + assert replica.get_num_ongoing_requests() == 1 + + assert replica.get_num_ongoing_requests() == 0 + + @pytest.mark.asyncio class TestBroadcast: async def test_unavailable_fails_without_waiting_for_router_init(self): @@ -1073,6 +1213,519 @@ async def test_on_request_routed( assert fake_request_router.on_request_routed_called is True +@pytest.mark.asyncio +class TestChooseReplica: + """Tests for choose_replica() and dispatch() flow.""" + + async def test_choose_replica_raises_deployment_unavailable( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """choose_replica fails when deployment is marked unavailable.""" + router, _ = setup_router + router._deployment_available = False + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + with pytest.raises(DeploymentUnavailableError): + async with router.choose_replica(request_metadata): + pass + + async def test_choose_replica_raises_backpressure( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """choose_replica raises BackPressureError when queue limit is reached.""" + router, _ = setup_router + router.update_deployment_config(DeploymentConfig(max_queued_requests=1)) + # Bump num_queued_requests to 1 to simulate a full queue + router._metrics_manager.inc_num_queued_requests() + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + with pytest.raises(BackPressureError): + async with router.choose_replica(request_metadata): + pass + + async def test_basic_choose_and_dispatch( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test basic choose_replica() and dispatch() workflow.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + call_method="test_method", + ) + + # Choose replica + async with router.choose_replica(request_metadata) as selection: + # Verify selection contains correct information + assert selection.replica_id == r1_id.unique_id + assert selection._replica == replica + assert selection._method_name == "test_method" + assert selection._slot_token in replica._reserved_slots + + # Dispatch request + replica_result = await router.dispatch(selection, request_metadata) + assert replica_result._replica_id == r1_id + assert not replica_result._is_generator_object + + # After context exit, slot should be released + assert selection._slot_token not in replica._reserved_slots + + async def test_choose_without_dispatch_releases_slot( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test that exiting context without dispatch releases the slot.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + slot_token = None + async with router.choose_replica(request_metadata) as selection: + slot_token = selection._slot_token + assert slot_token in replica._reserved_slots + # Exit without calling dispatch + + # After context exit, slot should be released + assert slot_token not in replica._reserved_slots + + async def test_choose_with_exception_releases_slot( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test that exception in context releases the slot.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + slot_token = None + with pytest.raises(RuntimeError): + async with router.choose_replica(request_metadata) as selection: + slot_token = selection._slot_token + assert slot_token in replica._reserved_slots + raise RuntimeError("Test exception") + + # After exception, slot should be released + assert slot_token not in replica._reserved_slots + + @pytest.mark.parametrize("is_streaming", [False, True]) + async def test_choose_and_dispatch_streaming( + self, + setup_router: Tuple[AsyncioRouter, FakeRequestRouter], + is_streaming: bool, + ): + """Test choose_replica() and dispatch() with streaming requests.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + is_streaming=is_streaming, + ) + + async with router.choose_replica(request_metadata) as selection: + replica_result = await router.dispatch(selection, request_metadata) + assert replica_result._replica_id == r1_id + if is_streaming: + assert replica_result._is_generator_object + else: + assert not replica_result._is_generator_object + + async def test_slot_reservation_mock_interaction( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test that reserve_slot and send_request_with_slot are called correctly.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + # Track initial state + initial_slot_count = replica._slot_counter + + async with router.choose_replica(request_metadata) as selection: + # Verify reserve_slot was called + assert replica._slot_counter == initial_slot_count + 1 + assert len(replica._reserved_slots) == 1 + + # Dispatch and verify send_request_with_slot works + replica_result = await router.dispatch(selection, request_metadata) + assert replica_result._replica_id == r1_id + + async def test_multiple_sequential_selections( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test multiple sequential choose_replica() and dispatch() calls.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + for i in range(3): + request_metadata = RequestMetadata( + request_id=f"test-request-{i}", + internal_request_id=f"test-internal-request-{i}", + ) + + async with router.choose_replica(request_metadata) as selection: + replica_result = await router.dispatch(selection, request_metadata) + assert replica_result._replica_id == r1_id + + # All slots should have been created and used + assert replica._slot_counter == 3 + + @pytest.mark.parametrize( + "setup_router", + [{"enable_queue_len_cache": True}], + indirect=True, + ) + async def test_cache_updated_on_choose_replica( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test that queue length cache is updated when choosing a replica.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + # Initially cache should be empty + assert fake_request_router.replica_queue_len_cache.get(r1_id) is None + + # Choose replica should update cache to 1 + async with router.choose_replica(request_metadata) as selection: + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + + # Dispatch should NOT increment again (already counted in choose) + await router.dispatch(selection, request_metadata) + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + + # After dispatch, cache remains incremented while the request is in flight. + # It is decremented when request completion is observed. + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + + @pytest.mark.parametrize( + "setup_router", + [{"enable_queue_len_cache": True}], + indirect=True, + ) + async def test_current_loop_dispatch_marks_selection_before_task_runs( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Wrapper dispatch should consume the selection before its task runs.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + current_loop_router = CurrentLoopRouter.__new__(CurrentLoopRouter) + current_loop_router._asyncio_loop = asyncio.get_running_loop() + current_loop_router._asyncio_router = router + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + async with router.choose_replica(request_metadata) as selection: + dispatch_task = current_loop_router.dispatch(selection, request_metadata) + assert selection._dispatched + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + assert len(replica._requests_sent) == 0 + + replica_result = await dispatch_task + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + + replica_result.fire_done_callbacks() + await asyncio.sleep(0) + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 0 + + @pytest.mark.parametrize( + "setup_router", + [{"enable_queue_len_cache": True}], + indirect=True, + ) + async def test_current_loop_dispatch_failure_releases_cache( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """If wrapper dispatch fails after consuming a selection, it owns cleanup.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + current_loop_router = CurrentLoopRouter.__new__(CurrentLoopRouter) + current_loop_router._asyncio_loop = asyncio.get_running_loop() + current_loop_router._asyncio_router = router + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + async with router.choose_replica(request_metadata) as selection: + fake_request_router._replica_to_return = None + dispatch_task = current_loop_router.dispatch(selection, request_metadata) + assert selection._dispatched + + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + + with pytest.raises(ReplicaUnavailableError): + await dispatch_task + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 0 + + @pytest.mark.parametrize( + "setup_router", + [{"enable_queue_len_cache": True}], + indirect=True, + ) + async def test_cache_decremented_on_choose_without_dispatch( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test that cache is decremented when choose exits without dispatch.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + # Choose replica without dispatch + async with router.choose_replica(request_metadata): + # Cache should be 1 (reservation) + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + # Exit without calling dispatch + + # After context exit without dispatch, cache should be decremented to 0 + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 0 + + @pytest.mark.parametrize( + "setup_router", + [{"enable_queue_len_cache": True}], + indirect=True, + ) + async def test_choose_replica_retries_when_reservation_rejected( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """choose_replica should only yield after replica-side capacity is reserved.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + r2_id = ReplicaID( + unique_id="test-replica-2", deployment_id=DeploymentID(name="test") + ) + r1 = FakeReplica(r1_id) + r2 = FakeReplica(r2_id) + r1._reject_reservation = True + fake_request_router.set_replica_to_return(r1) + fake_request_router.set_replica_to_return_on_retry(r2) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + async with router.choose_replica(request_metadata) as selection: + assert selection._replica == r2 + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + assert fake_request_router.replica_queue_len_cache.get(r2_id) == 1 + + assert fake_request_router.replica_queue_len_cache.get(r2_id) == 0 + + @pytest.mark.parametrize( + "setup_router", + [{"enable_queue_len_cache": True}], + indirect=True, + ) + async def test_concurrent_choose_replica_updates_cache( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test that concurrent choose_replica calls correctly update cache.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + # Start 3 concurrent choose_replica operations + async def choose_and_hold(request_id: str): + metadata = RequestMetadata( + request_id=request_id, + internal_request_id=request_id, + ) + async with router.choose_replica(metadata) as selection: + # Hold the selection + await asyncio.sleep(0.01) + return selection + + # Create tasks + tasks = [asyncio.create_task(choose_and_hold(f"request-{i}")) for i in range(3)] + + # Wait a bit for all to enter context + await asyncio.sleep(0.005) + + # Cache should reflect all 3 reservations + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 3 + + # Let them all exit + await asyncio.gather(*tasks) + + # After all exit without dispatch, cache should be 0 + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 0 + + async def test_dispatch_replica_unavailable( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test dispatch() raises error when replica becomes unavailable.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + async with router.choose_replica(request_metadata) as selection: + # Simulate replica becoming unavailable + fake_request_router._replica_to_return = None + + # dispatch should raise ReplicaUnavailableError + with pytest.raises(ReplicaUnavailableError): + await router.dispatch(selection, request_metadata) + + async def test_dispatch_uses_reserved_slot( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """Test that dispatch() sends request using reserved slot without rejection.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + async with router.choose_replica(request_metadata) as selection: + # Verify no requests sent yet + assert len(replica._requests_sent) == 0 + + # Dispatch the request + replica_result = await router.dispatch(selection, request_metadata) + assert replica_result._replica_id == r1_id + + # Verify request was sent with with_rejection=False + assert len(replica._requests_sent) == 1 + assert replica._requests_sent[0]["request_id"] == "test-request-1" + assert replica._requests_sent[0]["with_rejection"] is False + + async def test_multiple_dispatch_calls_fail( + self, setup_router: Tuple[AsyncioRouter, FakeRequestRouter] + ): + """A ReplicaSelection can only be dispatched once.""" + router, fake_request_router = setup_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + async with router.choose_replica(request_metadata) as selection: + await router.dispatch(selection, request_metadata) + + with pytest.raises(RuntimeError, match="already been dispatched"): + await router.dispatch(selection, request_metadata) + + assert len(replica._requests_sent) == 1 + + def running_replica_info(replica_id: ReplicaID) -> RunningReplicaInfo: return RunningReplicaInfo( replica_id=replica_id, @@ -1484,6 +2137,52 @@ def test_request_assignment( assert isinstance(future, concurrent.futures.Future) assert future.result()._replica_id == r1_id + @pytest.mark.asyncio + @pytest.mark.parametrize( + "setup_router", + [{"enable_queue_len_cache": True}], + indirect=True, + ) + async def test_dispatch_marks_selection_before_scheduled_coroutine_runs( + self, + setup_router: Tuple[AsyncioRouter, FakeRequestRouter], + setup_singleton_thread_router: SingletonThreadRouter, + monkeypatch, + ): + _, fake_request_router = setup_router + thread_router = setup_singleton_thread_router + + r1_id = ReplicaID( + unique_id="test-replica-1", deployment_id=DeploymentID(name="test") + ) + replica = FakeReplica(r1_id) + fake_request_router.set_replica_to_return(replica) + + pending_coros = [] + + def delay_asyncio_call(coro): + pending_coros.append(coro) + return concurrent.futures.Future() + + monkeypatch.setattr( + thread_router, "_wrap_asyncio_call_in_future", delay_asyncio_call + ) + + request_metadata = RequestMetadata( + request_id="test-request-1", + internal_request_id="test-internal-request-1", + ) + + async with thread_router.choose_replica(request_metadata) as selection: + dispatch_future = thread_router.dispatch(selection, request_metadata) + assert selection._dispatched + + assert fake_request_router.replica_queue_len_cache.get(r1_id) == 1 + assert len(pending_coros) == 1 + + dispatch_future.cancel() + pending_coros[0].close() + @pytest.mark.asyncio async def test_cancellation_propagation( self,