|
9 | 9 | from collections import defaultdict |
10 | 10 | from collections.abc import MutableMapping |
11 | 11 | from contextlib import asynccontextmanager, contextmanager |
| 12 | +from dataclasses import replace |
12 | 13 | from functools import lru_cache, partial |
13 | 14 | from typing import ( |
14 | 15 | Any, |
@@ -1263,25 +1264,108 @@ async def broadcast( |
1263 | 1264 | *request_args, |
1264 | 1265 | **request_kwargs, |
1265 | 1266 | ) -> List[ReplicaResult]: |
1266 | | - """Send a request to all current replicas and return all results. |
| 1267 | + """Send a request to all running replicas in parallel. |
1267 | 1268 |
|
1268 | | - This is a fan-out operation: the same request is dispatched to every |
1269 | | - replica and the list of ReplicaResults is returned to the caller. |
| 1269 | + Bypasses the normal load-balancing path and sends the request |
| 1270 | + directly to every replica. Waits for the request router to be |
| 1271 | + initialized so the replica set is populated. |
1270 | 1272 | """ |
| 1273 | + # Propagate tracing context, matching assign_request behavior. |
| 1274 | + if is_span_recording(): |
| 1275 | + propagate_context = create_propagated_context() |
| 1276 | + request_meta.tracing_context = propagate_context |
| 1277 | + else: |
| 1278 | + request_meta.tracing_context = None |
| 1279 | + |
| 1280 | + if not self._deployment_available: |
| 1281 | + raise DeploymentUnavailableError(self.deployment_id) |
| 1282 | + |
1271 | 1283 | await self._request_router_initialized.wait() |
1272 | 1284 |
|
| 1285 | + if not self._deployment_available: |
| 1286 | + raise DeploymentUnavailableError(self.deployment_id) |
| 1287 | + |
| 1288 | + replicas: List[RunningReplica] = list( |
| 1289 | + self.request_router.curr_replicas.values() |
| 1290 | + ) |
| 1291 | + if not replicas: |
| 1292 | + raise DeploymentUnavailableError(self.deployment_id) |
| 1293 | + |
| 1294 | + # Resolve arguments (e.g. DeploymentResponse objects) before sending. |
1273 | 1295 | pr = PendingRequest( |
1274 | 1296 | args=list(request_args), |
1275 | | - kwargs=request_kwargs, |
| 1297 | + kwargs=dict(request_kwargs), |
1276 | 1298 | metadata=request_meta, |
1277 | 1299 | ) |
1278 | | - if not pr.resolved: |
1279 | | - await self._resolve_request_arguments(pr) |
| 1300 | + await self._resolve_request_arguments(pr) |
| 1301 | + |
| 1302 | + results: List[ReplicaResult] = [] |
| 1303 | + for replica in replicas: |
| 1304 | + replica_pr = PendingRequest( |
| 1305 | + args=list(pr.args), |
| 1306 | + kwargs=dict(pr.kwargs), |
| 1307 | + metadata=replace( |
| 1308 | + request_meta, |
| 1309 | + internal_request_id=generate_request_id(), |
| 1310 | + ), |
| 1311 | + ) |
| 1312 | + replica_pr.resolved = True |
| 1313 | + try: |
| 1314 | + result = replica.try_send_request(replica_pr, with_rejection=False) |
| 1315 | + except ActorDiedError: |
| 1316 | + # Replica has died but controller hasn't notified the router yet. |
| 1317 | + # Skip this replica and continue broadcasting to healthy replicas. |
| 1318 | + self.request_router.on_replica_actor_died(replica.replica_id) |
| 1319 | + logger.warning( |
| 1320 | + f"{replica.replica_id} will not be considered for future " |
| 1321 | + "requests because it has died." |
| 1322 | + ) |
| 1323 | + continue |
| 1324 | + except ActorUnavailableError: |
| 1325 | + # Replica is temporarily unavailable. Invalidate the cache entry |
| 1326 | + # and continue broadcasting to other replicas. |
| 1327 | + self.request_router.on_replica_actor_unavailable(replica.replica_id) |
| 1328 | + logger.warning(f"{replica.replica_id} is temporarily unavailable.") |
| 1329 | + continue |
| 1330 | + |
| 1331 | + # Proactively update the queue length cache. |
| 1332 | + self.request_router.on_send_request(replica.replica_id) |
| 1333 | + |
| 1334 | + # Track running requests and register callback for completion |
| 1335 | + # handling, matching the pattern in _route_and_send_request_once. |
| 1336 | + if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE: |
| 1337 | + self._metrics_manager.inc_num_running_requests_for_replica( |
| 1338 | + replica.replica_id |
| 1339 | + ) |
| 1340 | + # NOTE: add_done_callback fires from a C++ worker thread (for |
| 1341 | + # actor ObjectRefs) or a gRPC callback thread. |
| 1342 | + # _process_finished_request and decrement_queue_len_cache both |
| 1343 | + # access shared router state that is not thread-safe, so we |
| 1344 | + # schedule them on the router's event loop. |
| 1345 | + callback = partial( |
| 1346 | + self._process_finished_request, |
| 1347 | + replica.replica_id, |
| 1348 | + replica_pr.metadata.internal_request_id, |
| 1349 | + replica.actor_id, |
| 1350 | + ) |
| 1351 | + result.add_done_callback( |
| 1352 | + lambda _, cb=callback: self._event_loop.call_soon_threadsafe(cb, _) |
| 1353 | + ) |
| 1354 | + result.add_done_callback( |
| 1355 | + lambda _, rid=replica.replica_id: ( |
| 1356 | + self._event_loop.call_soon_threadsafe( |
| 1357 | + self.request_router.decrement_queue_len_cache, |
| 1358 | + rid, |
| 1359 | + ) |
| 1360 | + ) |
| 1361 | + ) |
| 1362 | + |
| 1363 | + results.append(result) |
| 1364 | + |
| 1365 | + if not results: |
| 1366 | + raise DeploymentUnavailableError(self.deployment_id) |
1280 | 1367 |
|
1281 | | - return [ |
1282 | | - replica.try_send_request(pr, with_rejection=False) |
1283 | | - for replica in self.request_router.curr_replicas.values() |
1284 | | - ] |
| 1368 | + return results |
1285 | 1369 |
|
1286 | 1370 | async def shutdown(self): |
1287 | 1371 | await self._metrics_manager.shutdown() |
|
0 commit comments