Skip to content

Commit e737494

Browse files
committed
Add AcceleratorConfig to Serve
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent 393578d commit e737494

12 files changed

Lines changed: 416 additions & 36 deletions

File tree

python/ray/llm/_internal/serve/core/configs/accelerators.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def default_bundles(
199199

200200
num_hosts = max(1, num_devices // chips_per_host)
201201

202-
bundle = {"TPU": float(chips_per_host)}
202+
bundle = {"TPU": chips_per_host}
203203
bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001
204204

205205
return [bundle.copy() for _ in range(num_hosts)]
@@ -293,7 +293,3 @@ def shutdown(self):
293293
logger.warning(f"Failed to shut down TPU slice PG: {e}")
294294
finally:
295295
self._slice_pg_wrapper = None
296-
297-
def __del__(self):
298-
"""Ensure placement groups are cleaned up when this backend is garbage collected."""
299-
self.shutdown()

python/ray/llm/_internal/serve/core/server/llm_server.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
MODEL_RESPONSE_BATCH_TIMEOUT_MS,
2323
RAYLLM_VLLM_ENGINE_CLS_ENV,
2424
)
25+
from ray.llm._internal.serve.core.configs.accelerators import TPUConfig
2526
from ray.llm._internal.serve.core.configs.llm_config import (
2627
DiskMultiplexConfig,
2728
LLMConfig,
@@ -39,6 +40,8 @@
3940
from ray.llm._internal.serve.utils.server_utils import (
4041
get_serve_request_id,
4142
)
43+
from ray.serve.config import AcceleratorConfig, TPUSliceSpec
44+
from ray.util.tpu import get_tpu_version_from_type
4245

4346
if TYPE_CHECKING:
4447
from ray.llm._internal.serve.core.configs.openai_api_models import (
@@ -737,4 +740,23 @@ def get_deployment_options(cls, llm_config: "LLMConfig"):
737740
}
738741
deployment_options["ray_actor_options"] = ray_actor_options
739742

743+
if llm_config.accelerator_config is not None and isinstance(
744+
llm_config.accelerator_config, TPUConfig
745+
):
746+
if not llm_config.accelerator_type:
747+
raise ValueError(
748+
"llm_config.accelerator_type must be specified when "
749+
"accelerator_config is a TPUConfig."
750+
)
751+
version = get_tpu_version_from_type(llm_config.accelerator_type)
752+
753+
deployment_options["accelerator_config"] = AcceleratorConfig(
754+
accelerator_type="tpu",
755+
tpu=TPUSliceSpec(
756+
topology=llm_config.accelerator_config.topology,
757+
accelerator_version=version,
758+
num_slices=1,
759+
),
760+
)
761+
740762
return deployment_options

python/ray/llm/tests/serve/cpu/configs/test_models.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -420,13 +420,13 @@ def test_requires_deferred_placement_group(self):
420420
@pytest.mark.parametrize(
421421
"topology,num_devices,accelerator_type_str,expected_bundles_count,expected_chips_per_host",
422422
[
423-
("1x1", 1, "v6e", 1, 1.0),
424-
("1x1", 1, "v7x", 1, 4.0),
425-
("4x4", 16, "v6e", 4, 4.0),
426-
("2x2x2", 8, "v5p", 2, 4.0),
427-
("2x2", 4, "v5litepod", 1, 4.0),
428-
("2x2x1", 4, "v4", 1, 4.0),
429-
("2x4", 8, "v6e", 1, 8.0),
423+
("1x1", 1, "v6e", 1, 1),
424+
("1x1", 1, "v7x", 1, 4),
425+
("4x4", 16, "v6e", 4, 4),
426+
("2x2x2", 8, "v5p", 2, 4),
427+
("2x2", 4, "v5litepod", 1, 4),
428+
("2x2x1", 4, "v4", 1, 4),
429+
("2x4", 8, "v6e", 1, 8),
430430
],
431431
)
432432
def test_default_bundles_topology(
@@ -457,6 +457,40 @@ def test_default_bundles_topology_missing_accelerator_type_raises(self):
457457
):
458458
tpu_accel.default_bundles(num_devices=16, accelerator_type_str=None)
459459

460+
def test_default_bundles_v6e_4x4(self):
    """Test that v6e 4x4 topology returns per-host bundles."""
    accelerator = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4"))
    host_bundles = accelerator.default_bundles(
        num_devices=16, accelerator_type_str="v6e"
    )

    # 4x4 v6e = 16 chips. 4 chips per host -> 4 hosts.
    assert len(host_bundles) == 4
    for host_bundle in host_bundles:
        # The "TPU" value is compared against an int to match the
        # integer chips-per-host expectations used by the parametrized
        # test_default_bundles_topology cases.
        assert host_bundle["TPU"] == 4
        assert "accelerator_type:v6e" in host_bundle
470+
471+
def test_default_bundles_v5p_2x2x2(self):
    """Test that v5p 2x2x2 topology returns per-host bundles."""
    accelerator = TPUAccelerator(TPUConfig(kind="tpu", topology="2x2x2"))
    host_bundles = accelerator.default_bundles(
        num_devices=8, accelerator_type_str="v5p"
    )

    # 2x2x2 v5p = 8 chips. 4 chips per host -> 2 hosts.
    assert len(host_bundles) == 2
    for host_bundle in host_bundles:
        # The "TPU" value is compared against an int to match the
        # integer chips-per-host expectations used by the parametrized
        # test_default_bundles_topology cases.
        assert host_bundle["TPU"] == 4
        assert "accelerator_type:v5p" in host_bundle
481+
482+
def test_default_bundles_single_host_topology(self):
    """Test that a single-host topology returns a single bundle."""
    accelerator = TPUAccelerator(TPUConfig(kind="tpu", topology="2x2"))
    host_bundles = accelerator.default_bundles(
        num_devices=4, accelerator_type_str="v5litepod"
    )

    # 2x2 v5litepod = 4 chips on 1 host.
    assert len(host_bundles) == 1
    only_bundle = host_bundles[0]
    # The "TPU" value is compared against an int to match the integer
    # chips-per-host expectations used by the parametrized
    # test_default_bundles_topology cases.
    assert only_bundle["TPU"] == 4
    assert "accelerator_type:v5litepod" in only_bundle
493+
460494

461495
if __name__ == "__main__":
462496
sys.exit(pytest.main(["-v", __file__]))

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster):
3737
pg_table = placement_group_table(pg)
3838
assert pg_table["strategy"] == "PACK"
3939

40-
# 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
41-
assert len(pg_table["bundles"]) == 16
40+
# 4x4 v6e = 16 chips. We default to 4 TPU chips per bundle (per-host).
41+
assert len(pg_table["bundles"]) == 4
4242
for bundle in pg_table["bundles"].values():
4343
assert "TPU" in bundle
44-
assert bundle["TPU"] == 1
44+
assert bundle["TPU"] == 4.0
4545

4646
# Let the backend tear down its own resources if it has any
4747
engine_config.accelerator.shutdown()
@@ -62,7 +62,7 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
6262
accelerator_config={"kind": "tpu", "topology": "4x4"},
6363
placement_group_config={
6464
"strategy": "STRICT_SPREAD",
65-
"bundles": [{"TPU": 4}],
65+
"bundles": [{"TPU": 4}] * 4,
6666
},
6767
)
6868

@@ -256,10 +256,10 @@ def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster):
256256
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
257257

258258
assert worker_pg["strategy"] == "PACK"
259-
# 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU.
260-
assert len(worker_pg["bundles"]) == 16
259+
# 4x4 topology = 16 chips. Default is 4 bundles of 4 TPUs (per-host).
260+
assert len(worker_pg["bundles"]) == 4
261261
for bundle in worker_pg["bundles"].values():
262-
assert bundle.get("TPU", 0) == 1
262+
assert bundle.get("TPU", 0) == 4.0
263263

264264
serve.shutdown()
265265

python/ray/serve/_private/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from ray.serve._private.utils import DEFAULT, DeploymentOptionUpdateType
3434
from ray.serve.config import (
35+
AcceleratorConfig,
3536
AggregationFunction,
3637
AutoscalingConfig,
3738
DeploymentActorConfig,
@@ -191,6 +192,10 @@ class DeploymentConfig(BaseModel):
191192
update_type=DeploymentOptionUpdateType.NeedsActorReconfigure,
192193
)
193194

195+
accelerator_config: Optional[AcceleratorConfig] = Field(
196+
default=None, update_type=DeploymentOptionUpdateType.HeavyWeight
197+
)
198+
194199
# This flag is used to let replica know they are deployed from
195200
# a different language.
196201
is_cross_language: bool = False
@@ -323,6 +328,8 @@ def needs_pickle(self):
323328

324329
def to_proto(self):
325330
data = self.model_dump()
331+
if data.get("accelerator_config") is not None:
332+
data["accelerator_config"] = cloudpickle.dumps(data["accelerator_config"])
326333
if data.get("user_config") is not None:
327334
if self.needs_pickle():
328335
data["user_config"] = cloudpickle.dumps(data["user_config"])
@@ -430,6 +437,11 @@ def from_proto(cls, proto: DeploymentConfigProto):
430437
data["is_cross_language"] if "is_cross_language" in data else False
431438
)
432439
needs_pickle = _needs_pickle(deployment_language, is_cross_language)
440+
if "accelerator_config" in data:
441+
if data["accelerator_config"] != b"":
442+
data["accelerator_config"] = cloudpickle.loads(proto.accelerator_config)
443+
else:
444+
data["accelerator_config"] = None
433445
if "user_config" in data:
434446
if data["user_config"] != b"":
435447
if needs_pickle:

python/ray/serve/_private/default_impl.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2-
from typing import Callable, Optional, Tuple
2+
import logging
3+
from dataclasses import dataclass
4+
from typing import Callable, Dict, List, Optional, Tuple
35

46
import ray
57
from ray._common.constants import HEAD_NODE_RESOURCE_NAME
@@ -42,7 +44,11 @@
4244
inside_ray_client_context,
4345
resolve_deployment_response,
4446
)
45-
from ray.util.placement_group import PlacementGroup
47+
from ray.serve.config import AcceleratorConfig, TPUSliceSpec
48+
from ray.util.placement_group import PlacementGroup, remove_placement_group
49+
from ray.util.tpu import SlicePlacementGroup, slice_placement_group
50+
51+
logger = logging.getLogger(__name__)
4652

4753
# NOTE: Please read carefully before changing!
4854
#
@@ -51,6 +57,86 @@
5157
# API modified w/o substantial enough justification
5258

5359

60+
@dataclass
class _ReplicaPlacementGroup:
    """Internal Serve handle for a replica's placement group(s).

    Wraps the worker PG and any accelerator-specific cleanup hooks so the
    controller doesn't need to know whether the underlying request was a
    plain CPU/GPU PG or a TPU slice reservation.
    """

    # Worker placement group. Optional because ``shutdown()`` resets it to
    # ``None`` once the handle has been torn down.
    placement_group: Optional[PlacementGroup]
    # Set only for TPU slice reservations. When present, teardown is
    # delegated to it instead of calling ``remove_placement_group`` directly.
    _slice_pg: Optional[SlicePlacementGroup] = None

    def release_reservation_holders(self) -> None:
        """Call after ``placement_group.ready()`` resolves successfully.

        Releases any internal reservation-holder PGs (e.g. TPU head PGs)
        that were only needed to claim resources during scheduling. No-op
        for non-accelerator deployments.
        """
        if self._slice_pg is not None:
            self._slice_pg.release_head_pgs()

    def shutdown(self) -> None:
        """Tear down the replica's PG(s). Idempotent and best-effort.

        Teardown failures are logged rather than raised, on both the TPU
        slice path and the plain placement-group path. The handle's
        references are always cleared, so a repeated call is a no-op even
        when the first attempt failed.
        """
        slice_pg, pg = self._slice_pg, self.placement_group
        # Clear the references before attempting teardown so the handle is
        # left in the "already shut down" state no matter what happens below.
        self._slice_pg = None
        self.placement_group = None

        if slice_pg is not None:
            try:
                slice_pg.shutdown()
            except Exception:
                logger.exception("Failed to shut down TPU slice placement group.")
        elif pg is not None:
            try:
                remove_placement_group(pg)
            except Exception:
                logger.exception("Failed to remove placement group.")
96+
97+
98+
def _create_replica_placement_group(
    request: CreatePlacementGroupRequest,
    *,
    accelerator_config: Optional[AcceleratorConfig] = None,
) -> _ReplicaPlacementGroup:
    """Create a replica's placement group, dispatching on accelerator type.

    Args:
        request: The placement group request built by the scheduler.
        accelerator_config: Optional accelerator settings. When its
            ``accelerator_type`` is ``"tpu"``, a TPU slice placement group
            is created from ``accelerator_config.tpu``; otherwise the
            default placement group path is used.

    Returns:
        A ``_ReplicaPlacementGroup`` wrapping the worker placement group
        and, on the TPU path, the slice wrapper it came from.

    Note:
        On the TPU path only ``strategy``, ``name`` and
        ``bundle_label_selector`` are forwarded from ``request``, and the
        lifetime is fixed to ``"detached"``.
    """
    wants_tpu_slice = (
        accelerator_config is not None
        and accelerator_config.accelerator_type == "tpu"
    )
    if not wants_tpu_slice:
        default_pg = _default_create_placement_group(request)
        return _ReplicaPlacementGroup(placement_group=default_pg)

    tpu_slice = _default_create_tpu_placement_group(
        tpu_spec=accelerator_config.tpu,
        strategy=request.strategy,
        name=request.name,
        lifetime="detached",
        bundle_label_selector=request.bundle_label_selector,
    )
    return _ReplicaPlacementGroup(
        placement_group=tpu_slice.placement_group,
        _slice_pg=tpu_slice,
    )
119+
120+
121+
def _default_create_tpu_placement_group(
    tpu_spec: TPUSliceSpec,
    strategy: str,
    name: str,
    lifetime: Optional[str],
    bundle_label_selector: Optional[List[Dict[str, str]]] = None,
) -> SlicePlacementGroup:
    """Reserve a TPU slice placement group described by ``tpu_spec``.

    Args:
        tpu_spec: Slice description; its ``topology``,
            ``accelerator_version``, ``num_slices`` and ``chips_per_vm``
            fields are forwarded to ``slice_placement_group``.
        strategy: Placement strategy passed through unchanged.
        name: Placement group name passed through unchanged.
        lifetime: Placement group lifetime passed through unchanged.
        bundle_label_selector: Optional per-bundle label selectors passed
            through unchanged.

    Returns:
        The ``SlicePlacementGroup`` returned by ``slice_placement_group``.
    """
    # Slice geometry comes from the spec; scheduling options come from the
    # caller's arguments.
    slice_kwargs = dict(
        topology=tpu_spec.topology,
        accelerator_version=tpu_spec.accelerator_version,
        num_slices=tpu_spec.num_slices,
        chips_per_vm=tpu_spec.chips_per_vm,
        strategy=strategy,
        name=name,
        lifetime=lifetime,
        bundle_label_selector=bundle_label_selector,
    )
    return slice_placement_group(**slice_kwargs)
138+
139+
54140
def create_cluster_node_info_cache(gcs_client: GcsClient) -> ClusterNodeInfoCache:
55141
return DefaultClusterNodeInfoCache(gcs_client)
56142

@@ -81,7 +167,7 @@ def create_deployment_scheduler(
81167
cluster_node_info_cache,
82168
head_node_id,
83169
create_placement_group_fn=create_placement_group_fn_override
84-
or _default_create_placement_group,
170+
or _create_replica_placement_group,
85171
)
86172

87173

python/ray/serve/_private/deployment_scheduler.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
RAY_SERVE_USE_PACK_SCHEDULING_STRATEGY,
2828
SERVE_LOGGER_NAME,
2929
)
30+
from ray.serve.config import AcceleratorConfig
3031
from ray.util.placement_group import PlacementGroup
3132
from ray.util.scheduling_strategies import (
3233
LabelMatchExpressionsT,
@@ -198,6 +199,7 @@ class ReplicaSchedulingRequest:
198199
placement_group_strategy: Optional[str] = None
199200
placement_group_bundle_label_selector: Optional[List[Dict[str, str]]] = None
200201
placement_group_fallback_strategy: Optional[List[Dict[str, Any]]] = None
202+
accelerator_config: Optional[AcceleratorConfig] = None
201203
max_replicas_per_node: Optional[int] = None
202204
# Gang scheduling fields -- if set, replica should be scheduled on
203205
# the reserved gang placement group at the specified bundle index.
@@ -636,6 +638,7 @@ def _schedule_replica(
636638
replica_id = scheduling_request.replica_id
637639
deployment_id = replica_id.deployment_id
638640
placement_group = None
641+
sp = None
639642

640643
scheduling_strategy = default_scheduling_strategy
641644

@@ -651,21 +654,32 @@ def _schedule_replica(
651654
target_labels = None
652655
target_node_id = None
653656
elif scheduling_request.placement_group_bundles is not None:
657+
sp = None
654658
placement_group_strategy = (
655659
scheduling_request.placement_group_strategy
656660
if scheduling_request.placement_group_strategy
657661
else "PACK"
658662
)
659663
try:
660-
pg = self._create_placement_group_fn(
664+
pg_result = self._create_placement_group_fn(
661665
CreatePlacementGroupRequest(
662666
bundles=scheduling_request.placement_group_bundles,
663667
strategy=placement_group_strategy,
664668
target_node_id=target_node_id,
665669
name=scheduling_request.actor_options["name"],
666670
bundle_label_selector=scheduling_request.placement_group_bundle_label_selector,
667-
)
671+
),
672+
accelerator_config=scheduling_request.accelerator_config,
668673
)
674+
675+
from ray.serve._private.default_impl import _ReplicaPlacementGroup
676+
677+
if isinstance(pg_result, _ReplicaPlacementGroup):
678+
pg = pg_result.placement_group
679+
sp = pg_result
680+
else:
681+
pg = pg_result
682+
sp = None
669683
except Exception:
670684
# We add a defensive exception here, so the controller can
671685
# make progress even if the placement group isn't created.
@@ -731,7 +745,9 @@ def _schedule_replica(
731745
placement_group = scheduling_strategy.placement_group
732746

733747
scheduling_request.status = ReplicaSchedulingRequestStatus.SUCCEEDED
734-
scheduling_request.on_scheduled(actor_handle, placement_group=placement_group)
748+
scheduling_request.on_scheduled(
749+
actor_handle, placement_group=placement_group, placement_group_manager=sp
750+
)
735751
return True
736752

737753
@abstractmethod
@@ -869,6 +885,10 @@ def _prepare_gangs_for_deployment(
869885
fallback_strategy=fallback_strategy,
870886
)
871887
)
888+
from ray.serve._private.default_impl import _ReplicaPlacementGroup
889+
890+
if isinstance(pg, _ReplicaPlacementGroup):
891+
pg = pg.placement_group
872892
gang_pgs.append(pg)
873893
gang_ids.append(gang_id)
874894
gang_pg_names.append(pg_name)

0 commit comments

Comments
 (0)