Skip to content

Commit 5e79faa

Browse files
committed
Add AcceleratorConfig to Serve and fix gang scheduling
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent 6a1511d commit 5e79faa

13 files changed

Lines changed: 517 additions & 29 deletions

File tree

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,6 @@ def test_tpu_serve_deployment_explicit_per_chip_bundles(ray_tpu_cluster):
323323
Verifies that a user can explicitly request chip-level bundles (1 TPU per bundle)
324324
for a full multi-host TPU slice via placement_group_config.
325325
"""
326-
from ray.llm._internal.serve.core.configs.accelerators import TPUConfig
327-
328326
llm_config = LLMConfig(
329327
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
330328
accelerator_type="TPU-V6E",

python/ray/serve/_private/common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from dataclasses import asdict, dataclass, field
33
from enum import Enum
4-
from typing import Any, Awaitable, Callable, Dict, List, Optional
4+
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union
55

66
from starlette.types import Scope
77

@@ -18,6 +18,10 @@
1818
from ray.util.annotations import PublicAPI
1919
from ray.util.placement_group import PlacementGroup
2020

21+
if TYPE_CHECKING:
22+
from ray.serve._private.default_impl import _ReplicaPlacementGroup
23+
from ray.serve.config import AcceleratorConfig
24+
2125
REPLICA_ID_FULL_ID_STR_PREFIX = "SERVE_REPLICA::"
2226
GANG_PG_NAME_PREFIX = "SERVE_GANG::"
2327

@@ -897,6 +901,7 @@ class CreatePlacementGroupRequest:
897901
runtime_env: Optional[str] = None
898902
bundle_label_selector: Optional[List[Dict[str, str]]] = None
899903
fallback_strategy: Optional[List[Dict[str, Any]]] = None
904+
accelerator_config: Optional["AcceleratorConfig"] = None
900905

901906

902907
@dataclass
@@ -922,6 +927,7 @@ class GangPlacementGroupRequest:
922927
"""Label selector for per-replica placement group bundles."""
923928

924929
replica_pg_fallback_strategy: Optional[List[Dict[str, Any]]] = None
"""Fallback strategy for per-replica placement group bundles."""

accelerator_config: Optional["AcceleratorConfig"] = None
"""Optional accelerator configuration forwarded to placement group creation for each gang."""
926932

927933

@@ -932,7 +938,7 @@ class GangReservationResult:
932938
success: bool
933939
"""True when all gang PGs were created successfully."""
934940
error_message: Optional[str] = None
935-
gang_pgs: Optional[List[PlacementGroup]] = None
941+
gang_pgs: Optional[List[Union[PlacementGroup, "_ReplicaPlacementGroup"]]] = None
936942
gang_ids: Optional[List[str]] = None
937943
gang_pg_names: Optional[List[str]] = None
938944

python/ray/serve/_private/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from ray.serve._private.utils import DEFAULT, DeploymentOptionUpdateType
3434
from ray.serve.config import (
35+
AcceleratorConfig,
3536
AggregationFunction,
3637
AutoscalingConfig,
3738
DeploymentActorConfig,
@@ -191,6 +192,10 @@ class DeploymentConfig(BaseModel):
191192
update_type=DeploymentOptionUpdateType.NeedsActorReconfigure,
192193
)
193194

195+
accelerator_config: Optional[AcceleratorConfig] = Field(
196+
default=None, update_type=DeploymentOptionUpdateType.HeavyWeight
197+
)
198+
194199
# This flag is used to let replica know they are deployed from
195200
# a different language.
196201
is_cross_language: bool = False
@@ -323,6 +328,8 @@ def needs_pickle(self):
323328

324329
def to_proto(self):
325330
data = self.model_dump()
331+
if data.get("accelerator_config") is not None:
332+
data["accelerator_config"] = cloudpickle.dumps(self.accelerator_config)
326333
if data.get("user_config") is not None:
327334
if self.needs_pickle():
328335
data["user_config"] = cloudpickle.dumps(data["user_config"])
@@ -430,6 +437,11 @@ def from_proto(cls, proto: DeploymentConfigProto):
430437
data["is_cross_language"] if "is_cross_language" in data else False
431438
)
432439
needs_pickle = _needs_pickle(deployment_language, is_cross_language)
440+
if "accelerator_config" in data:
441+
if data["accelerator_config"] != b"":
442+
data["accelerator_config"] = cloudpickle.loads(proto.accelerator_config)
443+
else:
444+
data["accelerator_config"] = None
433445
if "user_config" in data:
434446
if data["user_config"] != b"":
435447
if needs_pickle:

python/ray/serve/_private/default_impl.py

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2-
from typing import Callable, Optional, Tuple
2+
import logging
3+
from dataclasses import dataclass
4+
from typing import Callable, Dict, List, Optional, Tuple, Union
35

46
import ray
57
from ray._common.constants import HEAD_NODE_RESOURCE_NAME
@@ -42,7 +44,11 @@
4244
inside_ray_client_context,
4345
resolve_deployment_response,
4446
)
45-
from ray.util.placement_group import PlacementGroup
47+
from ray.serve.config import TPUAcceleratorConfig
48+
from ray.util.placement_group import PlacementGroup, remove_placement_group
49+
from ray.util.tpu import SlicePlacementGroup, slice_placement_group
50+
51+
logger = logging.getLogger(__name__)
4652

4753
# NOTE: Please read carefully before changing!
4854
#
@@ -51,11 +57,93 @@
5157
# API modified w/o substantial enough justification
5258

5359

60+
@dataclass
class _ReplicaPlacementGroup:
    """Internal Serve handle for a replica's placement group(s).

    Wraps the worker PG and any accelerator-specific cleanup hooks so the
    controller doesn't need to know whether the underlying request was a
    plain CPU/GPU PG or a TPU slice reservation.
    """

    # Worker placement group. Reset to ``None`` by ``shutdown()``, hence Optional.
    placement_group: Optional[PlacementGroup]
    # Slice reservation that produced ``placement_group``; ``None`` for
    # non-accelerator deployments.
    _slice_pg: Optional[SlicePlacementGroup] = None

    def release_reservation_holders(self) -> None:
        """Call after ``placement_group.ready()`` resolves successfully.

        Releases any internal reservation-holder PGs (e.g. TPU head PGs)
        that were only needed to claim resources during scheduling. No-op
        for non-accelerator deployments (and after ``shutdown()``).
        """
        if self._slice_pg is not None:
            self._slice_pg.release_head_pgs()

    def shutdown(self) -> None:
        """Tear down the replica's PG(s). Idempotent.

        Teardown is best-effort on both paths: failures are logged rather
        than raised, and the handles are always cleared, so callers on
        failure-handling paths are never interrupted and a repeated call is
        a no-op.
        """
        slice_pg, worker_pg = self._slice_pg, self.placement_group
        # Drop the references up front so a failure below cannot leave this
        # wrapper half torn down.
        self._slice_pg = None
        self.placement_group = None

        if slice_pg is not None:
            # On the slice path the slice handle's own shutdown() is the only
            # teardown performed; the worker PG is not removed separately.
            try:
                slice_pg.shutdown()
            except Exception:
                logger.exception("Failed to shut down slice placement group.")
        elif worker_pg is not None:
            try:
                remove_placement_group(worker_pg)
            except Exception:
                logger.exception("Failed to remove placement group.")
96+
97+
98+
def _create_replica_placement_group(
    request: CreatePlacementGroupRequest,
) -> _ReplicaPlacementGroup:
    """Create a replica's placement group, dispatching on the accelerator config.

    Args:
        request: Placement group creation request. When its
            ``accelerator_config`` is a ``TPUAcceleratorConfig`` a TPU slice
            reservation is made; otherwise the default creation path is used.

    Returns:
        A ``_ReplicaPlacementGroup`` wrapping the worker placement group and,
        on the TPU path, the slice handle that produced it.
    """
    tpu_config = request.accelerator_config

    if not isinstance(tpu_config, TPUAcceleratorConfig):
        # No TPU config: plain placement group, no slice handle attached.
        return _ReplicaPlacementGroup(
            placement_group=_default_create_placement_group(request)
        )

    # TPU path: only strategy, name and bundle_label_selector are forwarded
    # from the request; the lifetime is fixed to "detached".
    reservation = _default_create_tpu_placement_group(
        tpu_config=tpu_config,
        strategy=request.strategy,
        name=request.name,
        lifetime="detached",
        bundle_label_selector=request.bundle_label_selector,
    )
    return _ReplicaPlacementGroup(
        placement_group=reservation.placement_group,
        _slice_pg=reservation,
    )
119+
120+
121+
def _default_create_tpu_placement_group(
    tpu_config: TPUAcceleratorConfig,
    strategy: str,
    name: str,
    lifetime: Optional[str],
    bundle_label_selector: Optional[List[Dict[str, str]]] = None,
) -> SlicePlacementGroup:
    """Reserve a TPU slice placement group described by ``tpu_config``.

    Args:
        tpu_config: Supplies ``topology``, ``accelerator_version``,
            ``num_slices`` and ``chips_per_vm`` for the reservation.
        strategy: Placement strategy, passed through unchanged.
        name: Placement group name, passed through unchanged.
        lifetime: Placement group lifetime, passed through unchanged.
        bundle_label_selector: Optional per-bundle label selectors, passed
            through unchanged.

    Returns:
        The ``SlicePlacementGroup`` returned by ``slice_placement_group``.
    """
    # Gather the slice shape from the config first, then the generic
    # placement group options, and hand everything over in a single call.
    slice_options = dict(
        topology=tpu_config.topology,
        accelerator_version=tpu_config.accelerator_version,
        num_slices=tpu_config.num_slices,
        chips_per_vm=tpu_config.chips_per_vm,
        strategy=strategy,
        name=name,
        lifetime=lifetime,
        bundle_label_selector=bundle_label_selector,
    )
    return slice_placement_group(**slice_options)
138+
139+
54140
def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCache:
55141
return DefaultClusterNodeInfoCache(gcs_client)
56142

57143

58-
CreatePlacementGroupFn = Callable[[CreatePlacementGroupRequest], PlacementGroup]
144+
CreatePlacementGroupFn = Callable[
145+
[CreatePlacementGroupRequest], Union[PlacementGroup, _ReplicaPlacementGroup]
146+
]
59147

60148

61149
def _default_create_placement_group(
@@ -81,7 +169,7 @@ def create_deployment_scheduler(
81169
cluster_node_info_cache,
82170
head_node_id,
83171
create_placement_group_fn=create_placement_group_fn_override
84-
or _default_create_placement_group,
172+
or _create_replica_placement_group,
85173
)
86174

87175

python/ray/serve/_private/deployment_scheduler.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dataclasses import dataclass
88
from enum import Enum
99
from functools import total_ordering
10-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
10+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
1111

1212
import ray
1313
from ray._raylet import node_labels_match_selector
@@ -27,6 +27,7 @@
2727
RAY_SERVE_USE_PACK_SCHEDULING_STRATEGY,
2828
SERVE_LOGGER_NAME,
2929
)
30+
from ray.serve.config import AcceleratorConfig
3031
from ray.util.placement_group import PlacementGroup
3132
from ray.util.scheduling_strategies import (
3233
LabelMatchExpressionsT,
@@ -35,6 +36,9 @@
3536
PlacementGroupSchedulingStrategy,
3637
)
3738

39+
if TYPE_CHECKING:
40+
from ray.serve._private.default_impl import _ReplicaPlacementGroup
41+
3842
logger = logging.getLogger(SERVE_LOGGER_NAME)
3943

4044

@@ -198,6 +202,7 @@ class ReplicaSchedulingRequest:
198202
placement_group_strategy: Optional[str] = None
199203
placement_group_bundle_label_selector: Optional[List[Dict[str, str]]] = None
200204
placement_group_fallback_strategy: Optional[List[Dict[str, Any]]] = None
205+
accelerator_config: Optional[AcceleratorConfig] = None
201206
max_replicas_per_node: Optional[int] = None
202207
# Gang scheduling fields -- if set, replica should be scheduled on
203208
# the reserved gang placement group at the specified bundle index.
@@ -636,12 +641,21 @@ def _schedule_replica(
636641
replica_id = scheduling_request.replica_id
637642
deployment_id = replica_id.deployment_id
638643
placement_group = None
644+
slice_pg = None
639645

640646
scheduling_strategy = default_scheduling_strategy
641647

642648
if scheduling_request.gang_placement_group is not None:
643649
# Gang scheduling -- use the reserved gang placement group
644-
placement_group = scheduling_request.gang_placement_group
650+
pg_wrapper = scheduling_request.gang_placement_group
651+
placement_group = (
652+
pg_wrapper.placement_group
653+
if hasattr(pg_wrapper, "placement_group")
654+
else pg_wrapper
655+
)
656+
# Preserve the wrapper for cleanup of head PGs
657+
slice_pg = pg_wrapper if hasattr(pg_wrapper, "placement_group") else None
658+
645659
scheduling_strategy = PlacementGroupSchedulingStrategy(
646660
placement_group=placement_group,
647661
placement_group_bundle_index=scheduling_request.gang_pg_index,
@@ -651,21 +665,32 @@ def _schedule_replica(
651665
target_labels = None
652666
target_node_id = None
653667
elif scheduling_request.placement_group_bundles is not None:
668+
slice_pg = None
654669
placement_group_strategy = (
655670
scheduling_request.placement_group_strategy
656671
if scheduling_request.placement_group_strategy
657672
else "PACK"
658673
)
659674
try:
660-
pg = self._create_placement_group_fn(
675+
pg_result = self._create_placement_group_fn(
661676
CreatePlacementGroupRequest(
662677
bundles=scheduling_request.placement_group_bundles,
663678
strategy=placement_group_strategy,
664679
target_node_id=target_node_id,
665680
name=scheduling_request.actor_options["name"],
666681
bundle_label_selector=scheduling_request.placement_group_bundle_label_selector,
667-
)
682+
accelerator_config=scheduling_request.accelerator_config,
683+
),
668684
)
685+
686+
from ray.serve._private.default_impl import _ReplicaPlacementGroup
687+
688+
if isinstance(pg_result, _ReplicaPlacementGroup):
689+
pg = pg_result.placement_group
690+
slice_pg = pg_result
691+
else:
692+
pg = pg_result
693+
slice_pg = None
669694
except Exception:
670695
# We add a defensive exception here, so the controller can
671696
# make progress even if the placement group isn't created.
@@ -720,6 +745,15 @@ def _schedule_replica(
720745
scheduling_request.status = (
721746
ReplicaSchedulingRequestStatus.ACTOR_CREATION_FAILED
722747
)
748+
749+
if slice_pg is not None:
750+
slice_pg.shutdown()
751+
elif (
752+
placement_group is not None
753+
and scheduling_request.placement_group_bundles is not None
754+
):
755+
ray.util.remove_placement_group(placement_group)
756+
723757
return False
724758

725759
del self._pending_replicas[deployment_id][replica_id]
@@ -731,7 +765,11 @@ def _schedule_replica(
731765
placement_group = scheduling_strategy.placement_group
732766

733767
scheduling_request.status = ReplicaSchedulingRequestStatus.SUCCEEDED
734-
scheduling_request.on_scheduled(actor_handle, placement_group=placement_group)
768+
scheduling_request.on_scheduled(
769+
actor_handle,
770+
placement_group=placement_group,
771+
placement_group_manager=slice_pg,
772+
)
735773
return True
736774

737775
@abstractmethod
@@ -816,7 +854,7 @@ def _prepare_gangs_for_deployment(
816854

817855
# Flatten per-replica bundles to form a placement group to atomically reserve resources
818856
# required for each gang
819-
gang_pgs: List[PlacementGroup] = []
857+
gang_pgs: List[Union[PlacementGroup, "_ReplicaPlacementGroup"]] = []
820858
gang_ids: List[str] = []
821859
gang_pg_names: List[str] = []
822860
for gang_index in range(num_gangs):
@@ -867,6 +905,7 @@ def _prepare_gangs_for_deployment(
867905
name=pg_name,
868906
bundle_label_selector=label_selector,
869907
fallback_strategy=fallback_strategy,
908+
accelerator_config=request.accelerator_config,
870909
)
871910
)
872911
gang_pgs.append(pg)

0 commit comments

Comments
 (0)