Skip to content

Commit 6f172f3

Browse files
committed
use Serve's TPUAcceleratorConfig in Ray LLM
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com> Update Ray LLM to use TPUAcceleratorConfig Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent 8eab85a commit 6f172f3

6 files changed

Lines changed: 191 additions & 106 deletions

File tree

python/ray/llm/_internal/serve/core/configs/accelerators.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import ray.util.accelerators.accelerators as accelerators
99
from ray.llm._internal.serve.observability.logging import get_logger
10+
from ray.serve.config import TPUAcceleratorConfig
1011
from ray.util.placement_group import PlacementGroup, placement_group
1112
from ray.util.tpu import get_tpu_version_from_type, slice_placement_group
1213

@@ -111,6 +112,15 @@ def get_remote_options(self, accelerator_type_str: str = None) -> Dict[str, Any]
111112
"""Returns the hardware-specific kwargs for ray.remote().options()."""
112113
pass
113114

115+
def get_deployment_options(
116+
self,
117+
*,
118+
accelerator_type: Optional[str] = None,
119+
placement_group_config: Optional[Dict[str, Any]] = None,
120+
) -> Dict[str, Any]:
121+
"""Returns Serve deployment options specific to this accelerator."""
122+
return {}
123+
114124
def shutdown(self) -> None:
115125
"""Release any resources owned by this backend. Idempotent."""
116126
return
@@ -261,6 +271,34 @@ def get_remote_options(self, accelerator_type_str: str = None):
261271
options["accelerator_type"] = accelerator_type_str
262272
return options
263273

274+
def get_deployment_options(
275+
self,
276+
*,
277+
accelerator_type: Optional[str] = None,
278+
placement_group_config: Optional[Dict[str, Any]] = None,
279+
) -> Dict[str, Any]:
280+
if not self._config.topology:
281+
return {}
282+
283+
if not accelerator_type:
284+
raise ValueError(
285+
"accelerator_type must be specified when "
286+
"accelerator_config is a TPUConfig with topology."
287+
)
288+
289+
version = get_tpu_version_from_type(accelerator_type)
290+
resources_per_bundle = (placement_group_config or {}).get("bundle_per_worker")
291+
292+
return {
293+
"accelerator_config": TPUAcceleratorConfig(
294+
topology=self._config.topology,
295+
accelerator_version=version,
296+
num_slices=getattr(self._config, "num_slices", 1),
297+
chips_per_vm=getattr(self._config, "chips_per_vm", None),
298+
resources_per_bundle=resources_per_bundle,
299+
)
300+
}
301+
264302
def shutdown(self):
265303
if self._slice_pg_wrapper is not None:
266304
try:

python/ray/llm/_internal/serve/core/server/llm_server.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,4 +740,12 @@ def get_deployment_options(cls, llm_config: "LLMConfig"):
740740
}
741741
deployment_options["ray_actor_options"] = ray_actor_options
742742

743+
# Let the accelerator backend populate hardware-specific deployment options.
744+
deployment_options.update(
745+
engine_config.accelerator.get_deployment_options(
746+
accelerator_type=llm_config.accelerator_type,
747+
placement_group_config=llm_config.placement_group_config,
748+
)
749+
)
750+
743751
return deployment_options

python/ray/llm/tests/serve/cpu/deployments/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def ray_tpu_cluster():
2020
"""
2121
Simulates a Ray cluster with a multi-host TPU v6e-16 slice (4x4 topology).
2222
"""
23+
if ray.is_initialized():
24+
ray.shutdown()
25+
2326
pod_type = "v6e-16"
2427
topology = "4x4"
2528

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py

Lines changed: 119 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import time
23

34
import pytest
45

@@ -30,25 +31,28 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster):
3031
)
3132

3233
engine_config = llm_config.get_engine_config()
33-
pg = engine_config.get_or_create_pg()
34-
35-
assert isinstance(pg, PlacementGroup)
34+
pg = None
35+
try:
36+
pg = engine_config.get_or_create_pg()
3637

37-
pg_table = placement_group_table(pg)
38-
assert pg_table["strategy"] == "PACK"
38+
assert isinstance(pg, PlacementGroup)
3939

40-
# 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
41-
assert len(pg_table["bundles"]) == 16
42-
for bundle in pg_table["bundles"].values():
43-
assert "TPU" in bundle
44-
assert bundle["TPU"] == 1
40+
pg_table = placement_group_table(pg)
41+
assert pg_table["strategy"] == "PACK"
4542

46-
# Let the backend tear down its own resources if it has any
47-
engine_config.accelerator.shutdown()
48-
try:
49-
ray.util.remove_placement_group(pg)
50-
except Exception:
51-
pass # Already cleaned up by the wrapper
43+
# 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
44+
assert len(pg_table["bundles"]) == 16
45+
for bundle in pg_table["bundles"].values():
46+
assert "TPU" in bundle
47+
assert bundle["TPU"] == 1
48+
finally:
49+
# Let the backend tear down its own resources if it has any
50+
engine_config.accelerator.shutdown()
51+
if pg is not None:
52+
try:
53+
ray.util.remove_placement_group(pg)
54+
except Exception:
55+
pass
5256

5357

5458
def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
@@ -67,24 +71,27 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
6771
)
6872

6973
engine_config = llm_config.get_engine_config()
70-
pg = engine_config.get_or_create_pg()
71-
72-
assert isinstance(pg, PlacementGroup)
73-
74-
pg_table = placement_group_table(pg)
75-
assert pg_table["strategy"] == "STRICT_SPREAD"
76-
# We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
77-
assert len(pg_table["bundles"]) == 4
78-
for bundle in pg_table["bundles"].values():
79-
assert "TPU" in bundle
80-
assert bundle["TPU"] == 4
81-
82-
# Let the backend tear down its own resources if it has any
83-
engine_config.accelerator.shutdown()
74+
pg = None
8475
try:
85-
ray.util.remove_placement_group(pg)
86-
except Exception:
87-
pass # Already cleaned up by the wrapper
76+
pg = engine_config.get_or_create_pg()
77+
78+
assert isinstance(pg, PlacementGroup)
79+
80+
pg_table = placement_group_table(pg)
81+
assert pg_table["strategy"] == "STRICT_SPREAD"
82+
# We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
83+
assert len(pg_table["bundles"]) == 4
84+
for bundle in pg_table["bundles"].values():
85+
assert "TPU" in bundle
86+
assert bundle["TPU"] == 4
87+
finally:
88+
# Let the backend tear down its own resources if it has any
89+
engine_config.accelerator.shutdown()
90+
if pg is not None:
91+
try:
92+
ray.util.remove_placement_group(pg)
93+
except Exception:
94+
pass
8895

8996

9097
def test_single_tpu_fallback(ray_tpu_cluster):
@@ -98,20 +105,23 @@ def test_single_tpu_fallback(ray_tpu_cluster):
98105
)
99106

100107
engine_config = llm_config.get_engine_config()
101-
pg = engine_config.get_or_create_pg()
102-
103-
pg_table = placement_group_table(pg)
108+
pg = None
109+
try:
110+
pg = engine_config.get_or_create_pg()
104111

105-
# Verify it falls back to the default PACK strategy for 1 GPU/TPU
106-
assert len(pg_table["bundles"]) == 1
107-
assert pg_table["strategy"] == "PACK"
112+
pg_table = placement_group_table(pg)
108113

109-
# Let the backend tear down its own resources if it has any
110-
engine_config.accelerator.shutdown()
111-
try:
112-
ray.util.remove_placement_group(pg)
113-
except Exception:
114-
pass # Already cleaned up by the wrapper
114+
# Verify it falls back to the default PACK strategy for 1 GPU/TPU
115+
assert len(pg_table["bundles"]) == 1
116+
assert pg_table["strategy"] == "PACK"
117+
finally:
118+
# Let the backend tear down its own resources if it has any
119+
engine_config.accelerator.shutdown()
120+
if pg is not None:
121+
try:
122+
ray.util.remove_placement_group(pg)
123+
except Exception:
124+
pass
115125

116126

117127
def test_tpu_slice_placement_group_creation_bundle_per_worker(ray_tpu_cluster):
@@ -221,47 +231,49 @@ def test_tpu_slice_placement_group_creation_cpu_driver_homogeneous_tpu_bundles_p
221231
pass
222232

223233

224-
def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster):
234+
def test_tpu_serve_deployment_default_host_level_bundles(ray_tpu_cluster):
225235
"""
226236
Verifies that a Serve deployment created for a multi-host TPU slice defaults
227-
to chip-level bundles when no placement_group_config is specified.
237+
to host-level bundles when no placement_group_config is specified.
228238
"""
229239
llm_config = LLMConfig(
230240
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
231241
accelerator_type="TPU-V6E",
232242
accelerator_config={"kind": "tpu", "topology": "4x4"},
233243
)
234244

235-
app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine)
236-
serve.run(app)
237-
238-
pg_table = ray.util.placement_group_table()
239-
active_pgs = list(
240-
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
245+
serve_options = LLMServer.get_deployment_options(llm_config)
246+
app = serve.deployment(**serve_options)(LLMServer).bind(
247+
llm_config, engine_cls=PGCreationMockEngine
241248
)
242-
243-
assert (
244-
len(active_pgs) == 2
245-
), "Expected 2 PGs - one for TPU Head, one for worker bundles"
246-
247-
tpu_head_resource = "TPU-v6e-16-head"
248-
head_pgs = [
249-
pg
250-
for pg in active_pgs
251-
if len(pg["bundles"]) == 1
252-
and tpu_head_resource in list(pg["bundles"].values())[0]
253-
]
254-
assert len(head_pgs) == 1
255-
256-
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
257-
258-
assert worker_pg["strategy"] == "PACK"
259-
# 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU.
260-
assert len(worker_pg["bundles"]) == 16
261-
for bundle in worker_pg["bundles"].values():
262-
assert bundle.get("TPU", 0) == 1
263-
264-
serve.shutdown()
249+
try:
250+
serve.run(app)
251+
252+
# Wait for the head PG to be removed (eventual consistency).
253+
start_time = time.time()
254+
timeout = 10
255+
while time.time() - start_time < timeout:
256+
pg_table = ray.util.placement_group_table()
257+
active_pgs = list(
258+
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
259+
)
260+
if len(active_pgs) == 1:
261+
break
262+
time.sleep(0.5)
263+
264+
assert (
265+
len(active_pgs) == 1
266+
), f"Expected exactly 1 active PG (the worker PG), but found {len(active_pgs)}. Head PG may not have been removed."
267+
268+
worker_pg = active_pgs[0]
269+
270+
assert worker_pg["strategy"] == "PACK"
271+
# 4x4 topology = 16 chips. Default is host-level bundles (4 bundles of 4 TPUs).
272+
assert len(worker_pg["bundles"]) == 4
273+
for bundle in worker_pg["bundles"].values():
274+
assert bundle.get("TPU", 0) == 4
275+
finally:
276+
serve.shutdown()
265277

266278

267279
def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster):
@@ -276,36 +288,38 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster):
276288
placement_group_config={"bundle_per_worker": {"TPU": 4}},
277289
)
278290

279-
app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine)
280-
serve.run(app)
281-
282-
pg_table = ray.util.placement_group_table()
283-
active_pgs = list(
284-
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
291+
serve_options = LLMServer.get_deployment_options(llm_config)
292+
app = serve.deployment(**serve_options)(LLMServer).bind(
293+
llm_config, engine_cls=PGCreationMockEngine
285294
)
286-
287-
assert (
288-
len(active_pgs) == 2
289-
), "Expected 2 PGs - one for TPU Head, one for worker bundles"
290-
291-
tpu_head_resource = "TPU-v6e-16-head"
292-
head_pgs = [
293-
pg
294-
for pg in active_pgs
295-
if len(pg["bundles"]) == 1
296-
and tpu_head_resource in list(pg["bundles"].values())[0]
297-
]
298-
assert len(head_pgs) == 1
299-
300-
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
301-
302-
assert worker_pg["strategy"] == "PACK"
303-
# 4x4 topology = 16 chips. With 4 TPUs per bundle, expect exactly 4 bundles.
304-
assert len(worker_pg["bundles"]) == 4
305-
for bundle in worker_pg["bundles"].values():
306-
assert bundle.get("TPU", 0) == 4
307-
308-
serve.shutdown()
295+
try:
296+
serve.run(app)
297+
298+
# Wait for the head PG to be removed (eventual consistency).
299+
start_time = time.time()
300+
timeout = 10
301+
while time.time() - start_time < timeout:
302+
pg_table = ray.util.placement_group_table()
303+
active_pgs = list(
304+
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
305+
)
306+
if len(active_pgs) == 1:
307+
break
308+
time.sleep(0.5)
309+
310+
assert (
311+
len(active_pgs) == 1
312+
), f"Expected exactly 1 active PG (the worker PG), but found {len(active_pgs)}. Head PG may not have been removed."
313+
314+
worker_pg = active_pgs[0]
315+
316+
assert worker_pg["strategy"] == "PACK"
317+
# 4x4 topology = 16 chips. With 4 TPUs per bundle, expect exactly 4 bundles.
318+
assert len(worker_pg["bundles"]) == 4
319+
for bundle in worker_pg["bundles"].values():
320+
assert bundle.get("TPU", 0) == 4
321+
finally:
322+
serve.shutdown()
309323

310324

311325
if __name__ == "__main__":

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_server.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from ray import serve
11+
from ray.llm._internal.serve.core.configs.accelerators import TPUConfig
1112
from ray.llm._internal.serve.core.configs.llm_config import (
1213
LLMConfig,
1314
LoraConfig,
@@ -692,6 +693,28 @@ def test_deferred_placement_group_for_tpu_topology(self):
692693
assert "placement_group_bundles" not in serve_options
693694
assert "placement_group_strategy" not in serve_options
694695

696+
def test_tpu_accelerator_config_translation(self):
697+
"""Test that TPUConfig is correctly translated to Serve TPUAcceleratorConfig."""
698+
699+
llm_config = LLMConfig(
700+
model_loading_config=ModelLoadingConfig(model_id="test-tpu-model"),
701+
accelerator_type="TPU-V6E",
702+
accelerator_config=TPUConfig(kind="tpu", topology="4x4"),
703+
placement_group_config={"bundle_per_worker": {"TPU": 1}},
704+
llm_engine="vLLM",
705+
)
706+
707+
serve_options = LLMServer.get_deployment_options(llm_config)
708+
709+
assert "placement_group_bundles" not in serve_options
710+
assert "placement_group_strategy" not in serve_options
711+
712+
assert "accelerator_config" in serve_options
713+
acc_config = serve_options["accelerator_config"]
714+
assert acc_config.topology == "4x4"
715+
assert acc_config.accelerator_version == "v6e"
716+
assert acc_config.resources_per_bundle == {"TPU": 1}
717+
695718

696719
if __name__ == "__main__":
697720
sys.exit(pytest.main(["-v", __file__]))

0 commit comments

Comments
 (0)