Skip to content

Commit c547e2a

Browse files
committed
Add AcceleratorConfig to Serve and fix gang scheduling
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent 6a1511d commit c547e2a

17 files changed

Lines changed: 673 additions & 245 deletions

File tree

python/ray/llm/_internal/serve/core/configs/accelerators.py

Lines changed: 13 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,9 @@
66
from typing_extensions import Annotated
77

88
import ray.util.accelerators.accelerators as accelerators
9-
from ray._private.accelerators.tpu import get_chips_per_host
109
from ray.llm._internal.serve.observability.logging import get_logger
1110
from ray.util.placement_group import PlacementGroup, placement_group
12-
from ray.util.tpu import (
13-
get_num_chips_from_topology,
14-
get_tpu_version_from_type,
15-
slice_placement_group,
16-
)
11+
from ray.util.tpu import get_tpu_version_from_type, slice_placement_group
1712

1813
logger = get_logger(__name__)
1914

@@ -32,21 +27,6 @@ def format_ray_accelerator_resource(accelerator_type_str: str) -> str:
3227
return f"accelerator_type:{accelerator_type_str}"
3328

3429

35-
def get_inferred_tensor_parallel_size(topology: Optional[str]) -> Optional[int]:
36-
"""Infers the tensor parallel size from the TPU topology."""
37-
if not topology:
38-
return None
39-
40-
try:
41-
return get_num_chips_from_topology(topology)
42-
except ValueError as e:
43-
logger.warning(
44-
f"Failed to infer tensor_parallel_size from topology '{topology}': {e}. "
45-
"Defaulting to None."
46-
)
47-
return None
48-
49-
5030
def infer_hardware_kind_from_bundles(
5131
placement_group_config: Optional[Dict[str, Any]]
5232
) -> Optional[str]:
@@ -200,35 +180,10 @@ def __init__(self, config: TPUConfig):
200180
def default_bundles(
201181
self, *, num_devices: int, accelerator_type_str: Optional[str] = None
202182
):
203-
if not self._config.topology:
204-
# Fallback to per-chip bundles if no topology is specified
205-
bundle = {"TPU": 1}
206-
if accelerator_type_str:
207-
bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001
208-
return [bundle.copy() for _ in range(num_devices)]
209-
210-
# Topology is specified, compute per-host bundles
211-
if not accelerator_type_str:
212-
raise ValueError(
213-
"`accelerator_type` must be specified when `topology` is present "
214-
"in order to compute TPU resource requirements."
215-
)
216-
version = get_tpu_version_from_type(accelerator_type_str)
217-
chips_per_host = get_chips_per_host(self._config.topology, version)
218-
219-
if num_devices > chips_per_host and num_devices % chips_per_host != 0:
220-
raise ValueError(
221-
f"num_devices ({num_devices}) must be a multiple of "
222-
f"chips_per_host ({chips_per_host}) for TPU topologies."
223-
)
224-
225-
num_hosts = max(1, num_devices // chips_per_host)
226-
227-
tpu_resources = min(num_devices, chips_per_host)
228-
bundle = {"TPU": tpu_resources}
229-
bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001
230-
231-
return [bundle.copy() for _ in range(num_hosts)]
183+
bundle = {"TPU": 1}
184+
if accelerator_type_str:
185+
bundle[format_ray_accelerator_resource(accelerator_type_str)] = 0.001
186+
return [bundle.copy() for _ in range(num_devices)]
232187

233188
def create_placement_group(
234189
self,
@@ -299,15 +254,11 @@ def requires_remote_initialization(self) -> bool:
299254
return True
300255

301256
def get_remote_options(self, accelerator_type_str: str = None):
302-
# The PlacementGroupSchedulingStrategy natively handles routing the task to
303-
# the correct hardware. We omit TPU resource requests to avoid consuming
304-
# chips that the model engine workers must use.
305-
options: Dict[str, Any] = {"resources": {}}
257+
# TPUs use custom resource strings rather than a native kwarg
258+
options: Dict[str, Any] = {"resources": {"TPU": 0.001}}
259+
306260
if accelerator_type_str:
307-
# Pin the task to the TPU accelerator to avoid scheduling on a CPU bundle.
308-
options["label_selector"] = {
309-
"ray.io/accelerator-type": accelerator_type_str
310-
}
261+
options["accelerator_type"] = accelerator_type_str
311262
return options
312263

313264
def shutdown(self):
@@ -319,3 +270,7 @@ def shutdown(self):
319270
logger.warning(f"Failed to shut down TPU slice PG: {e}")
320271
finally:
321272
self._slice_pg_wrapper = None
273+
274+
def __del__(self):
275+
"""Ensure placement groups are cleaned up when this backend is garbage collected."""
276+
self.shutdown()

python/ray/llm/_internal/serve/engines/vllm/vllm_models.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
TPUAccelerator,
2424
TPUConfig,
2525
format_ray_accelerator_resource,
26-
get_inferred_tensor_parallel_size,
2726
)
2827
from ray.llm._internal.serve.core.configs.llm_config import (
2928
AcceleratorType,
@@ -194,19 +193,6 @@ def from_llm_config(cls, llm_config: LLMConfig) -> "VLLMEngineConfig":
194193
mirror_config = llm_config.model_loading_config.model_source
195194

196195
all_engine_kwargs = llm_config.engine_kwargs.copy()
197-
198-
# If tensor_parallel_size is not specified, try to infer it from topology
199-
if "tensor_parallel_size" not in all_engine_kwargs:
200-
if isinstance(llm_config.accelerator_config, TPUConfig):
201-
total_chips = get_inferred_tensor_parallel_size(
202-
llm_config.accelerator_config.topology
203-
)
204-
if total_chips is not None:
205-
all_engine_kwargs["tensor_parallel_size"] = total_chips
206-
logger.info(
207-
f"Inferred tensor_parallel_size={total_chips} from TPUConfig."
208-
)
209-
210196
engine_kwargs = {}
211197
frontend_kwargs = {}
212198

python/ray/llm/tests/serve/cpu/configs/test_models.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -417,52 +417,6 @@ def test_requires_deferred_placement_group(self):
417417
tpu_accel_with_topo = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4"))
418418
assert tpu_accel_with_topo.requires_deferred_placement_group is True
419419

420-
@pytest.mark.parametrize(
421-
"topology,num_devices,accelerator_type_str,expected_bundles_count,expected_chips_per_host",
422-
[
423-
("1x1", 1, "TPU-V6E", 1, 1),
424-
("1x1", 1, "TPU-V7X", 1, 1),
425-
("4x4", 16, "TPU-V6E", 4, 4),
426-
("2x2x2", 8, "TPU-V5P", 2, 4),
427-
("2x2", 4, "TPU-V5LITEPOD", 1, 4),
428-
("2x2x1", 4, "TPU-V4", 1, 4),
429-
("2x4", 8, "TPU-V6E", 1, 8),
430-
],
431-
)
432-
def test_default_bundles_topology(
433-
self,
434-
topology,
435-
num_devices,
436-
accelerator_type_str,
437-
expected_bundles_count,
438-
expected_chips_per_host,
439-
):
440-
"""Test that different topologies return correct per-host bundles."""
441-
tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology=topology))
442-
bundles = tpu_accel.default_bundles(
443-
num_devices=num_devices, accelerator_type_str=accelerator_type_str
444-
)
445-
446-
assert len(bundles) == expected_bundles_count
447-
for bundle in bundles:
448-
assert bundle["TPU"] == expected_chips_per_host
449-
assert f"accelerator_type:{accelerator_type_str}" in bundle
450-
451-
def test_default_bundles_topology_missing_accelerator_type_raises(self):
452-
"""Test that ValueError is raised when topology is present but accelerator type is missing."""
453-
tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4"))
454-
with pytest.raises(
455-
ValueError,
456-
match="`accelerator_type` must be specified when `topology` is present",
457-
):
458-
tpu_accel.default_bundles(num_devices=16, accelerator_type_str=None)
459-
460-
def test_default_bundles_topology_non_multiple_num_devices_raises(self):
461-
"""Test that ValueError is raised when num_devices is not a multiple of chips_per_host."""
462-
tpu_accel = TPUAccelerator(TPUConfig(kind="tpu", topology="4x4"))
463-
with pytest.raises(ValueError, match="must be a multiple of chips_per_host"):
464-
tpu_accel.default_bundles(num_devices=6, accelerator_type_str="TPU-V6E")
465-
466420

467421
if __name__ == "__main__":
468422
sys.exit(pytest.main(["-v", __file__]))

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py

Lines changed: 41 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -26,33 +26,29 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster):
2626
llm_config = LLMConfig(
2727
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
2828
accelerator_type="TPU-V6E",
29-
accelerator_config=TPUConfig(kind="tpu", topology="4x4"),
29+
accelerator_config={"kind": "tpu", "topology": "4x4"},
3030
)
3131

3232
engine_config = llm_config.get_engine_config()
33+
pg = engine_config.get_or_create_pg()
3334

34-
pg = None
35-
try:
36-
pg = engine_config.get_or_create_pg()
35+
assert isinstance(pg, PlacementGroup)
3736

38-
assert isinstance(pg, PlacementGroup)
37+
pg_table = placement_group_table(pg)
38+
assert pg_table["strategy"] == "PACK"
3939

40-
pg_table = placement_group_table(pg)
41-
assert pg_table["strategy"] == "PACK"
40+
# 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
41+
assert len(pg_table["bundles"]) == 16
42+
for bundle in pg_table["bundles"].values():
43+
assert "TPU" in bundle
44+
assert bundle["TPU"] == 1
4245

43-
# 4x4 v6e = 16 chips. We default to 4 TPU chips per bundle (per-host).
44-
assert len(pg_table["bundles"]) == 4
45-
for bundle in pg_table["bundles"].values():
46-
assert "TPU" in bundle
47-
assert bundle["TPU"] == 4.0
48-
finally:
49-
# Let the backend tear down its own resources if it has any
50-
engine_config.accelerator.shutdown()
51-
if pg is not None:
52-
try:
53-
ray.util.remove_placement_group(pg)
54-
except Exception:
55-
pass
46+
# Let the backend tear down its own resources if it has any
47+
engine_config.accelerator.shutdown()
48+
try:
49+
ray.util.remove_placement_group(pg)
50+
except Exception:
51+
pass # Already cleaned up by the wrapper
5652

5753

5854
def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
@@ -63,36 +59,32 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
6359
llm_config = LLMConfig(
6460
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
6561
accelerator_type="TPU-V6E",
66-
accelerator_config=TPUConfig(kind="tpu", topology="4x4"),
62+
accelerator_config={"kind": "tpu", "topology": "4x4"},
6763
placement_group_config={
6864
"strategy": "STRICT_SPREAD",
69-
"bundles": [{"TPU": 4}] * 4,
65+
"bundles": [{"TPU": 4}],
7066
},
7167
)
7268

7369
engine_config = llm_config.get_engine_config()
70+
pg = engine_config.get_or_create_pg()
71+
72+
assert isinstance(pg, PlacementGroup)
7473

75-
pg = None
74+
pg_table = placement_group_table(pg)
75+
assert pg_table["strategy"] == "STRICT_SPREAD"
76+
# We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
77+
assert len(pg_table["bundles"]) == 4
78+
for bundle in pg_table["bundles"].values():
79+
assert "TPU" in bundle
80+
assert bundle["TPU"] == 4
81+
82+
# Let the backend tear down its own resources if it has any
83+
engine_config.accelerator.shutdown()
7684
try:
77-
pg = engine_config.get_or_create_pg()
78-
79-
assert isinstance(pg, PlacementGroup)
80-
81-
pg_table = placement_group_table(pg)
82-
assert pg_table["strategy"] == "STRICT_SPREAD"
83-
# We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
84-
assert len(pg_table["bundles"]) == 4
85-
for bundle in pg_table["bundles"].values():
86-
assert "TPU" in bundle
87-
assert bundle["TPU"] == 4
88-
finally:
89-
# Let the backend tear down its own resources if it has any
90-
engine_config.accelerator.shutdown()
91-
if pg is not None:
92-
try:
93-
ray.util.remove_placement_group(pg)
94-
except Exception:
95-
pass
85+
ray.util.remove_placement_group(pg)
86+
except Exception:
87+
pass # Already cleaned up by the wrapper
9688

9789

9890
def test_single_tpu_fallback(ray_tpu_cluster):
@@ -229,17 +221,15 @@ def test_tpu_slice_placement_group_creation_cpu_driver_homogeneous_tpu_bundles_p
229221
pass
230222

231223

232-
def test_tpu_serve_deployment_default_host_level_bundles(ray_tpu_cluster):
224+
def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster):
233225
"""
234226
Verifies that a Serve deployment created for a multi-host TPU slice defaults
235-
to host-level bundles when no placement_group_config is specified.
227+
to chip-level bundles when no placement_group_config is specified.
236228
"""
237-
from ray.llm._internal.serve.core.configs.accelerators import TPUConfig
238-
239229
llm_config = LLMConfig(
240230
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
241231
accelerator_type="TPU-V6E",
242-
accelerator_config=TPUConfig(kind="tpu", topology="4x4"),
232+
accelerator_config={"kind": "tpu", "topology": "4x4"},
243233
)
244234

245235
app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine)
@@ -266,10 +256,10 @@ def test_tpu_serve_deployment_default_host_level_bundles(ray_tpu_cluster):
266256
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
267257

268258
assert worker_pg["strategy"] == "PACK"
269-
# 4x4 topology = 16 chips. Default is 4 bundles of 4 TPUs (per-host).
270-
assert len(worker_pg["bundles"]) == 4
259+
# 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU.
260+
assert len(worker_pg["bundles"]) == 16
271261
for bundle in worker_pg["bundles"].values():
272-
assert bundle.get("TPU", 0) == 4.0
262+
assert bundle.get("TPU", 0) == 1
273263

274264
serve.shutdown()
275265

@@ -282,7 +272,7 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster):
282272
llm_config = LLMConfig(
283273
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
284274
accelerator_type="TPU-V6E",
285-
accelerator_config=TPUConfig(kind="tpu", topology="4x4"),
275+
accelerator_config={"kind": "tpu", "topology": "4x4"},
286276
placement_group_config={"bundle_per_worker": {"TPU": 4}},
287277
)
288278

@@ -318,52 +308,5 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster):
318308
serve.shutdown()
319309

320310

321-
def test_tpu_serve_deployment_explicit_per_chip_bundles(ray_tpu_cluster):
322-
"""
323-
Verifies that a user can explicitly request chip-level bundles (1 TPU per bundle)
324-
for a full multi-host TPU slice via placement_group_config.
325-
"""
326-
from ray.llm._internal.serve.core.configs.accelerators import TPUConfig
327-
328-
llm_config = LLMConfig(
329-
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
330-
accelerator_type="TPU-V6E",
331-
accelerator_config=TPUConfig(kind="tpu", topology="4x4"),
332-
placement_group_config={"bundle_per_worker": {"TPU": 1}},
333-
engine_kwargs={"tensor_parallel_size": 16},
334-
)
335-
336-
app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine)
337-
serve.run(app)
338-
339-
pg_table = ray.util.placement_group_table()
340-
active_pgs = list(
341-
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
342-
)
343-
344-
assert (
345-
len(active_pgs) == 2
346-
), "Expected 2 PGs - one for TPU Head, one for worker bundles"
347-
348-
tpu_head_resource = "TPU-v6e-16-head"
349-
head_pgs = [
350-
pg
351-
for pg in active_pgs
352-
if len(pg["bundles"]) == 1
353-
and tpu_head_resource in list(pg["bundles"].values())[0]
354-
]
355-
assert len(head_pgs) == 1
356-
357-
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
358-
359-
assert worker_pg["strategy"] == "PACK"
360-
# 4x4 topology = 16 chips. Explicitly requested 16 bundles of 1 TPU.
361-
assert len(worker_pg["bundles"]) == 16
362-
for bundle in worker_pg["bundles"].values():
363-
assert bundle.get("TPU", 0) == 1.0
364-
365-
serve.shutdown()
366-
367-
368311
if __name__ == "__main__":
369312
sys.exit(pytest.main(["-v", __file__]))

0 commit comments

Comments
 (0)