Skip to content

Commit 812ef63

Browse files
committed
Update Ray LLM to use TPUAcceleratorConfig
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent f96ef1e commit 812ef63

4 files changed

Lines changed: 131 additions & 98 deletions

File tree

python/ray/llm/_internal/serve/core/server/llm_server.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
MODEL_RESPONSE_BATCH_TIMEOUT_MS,
2323
RAYLLM_VLLM_ENGINE_CLS_ENV,
2424
)
25+
from ray.llm._internal.serve.core.configs.accelerators import TPUConfig
2526
from ray.llm._internal.serve.core.configs.llm_config import (
2627
DiskMultiplexConfig,
2728
LLMConfig,
@@ -39,6 +40,8 @@
3940
from ray.llm._internal.serve.utils.server_utils import (
4041
get_serve_request_id,
4142
)
43+
from ray.serve.config import TPUAcceleratorConfig
44+
from ray.util.tpu import get_tpu_version_from_type
4245

4346
if TYPE_CHECKING:
4447
from ray.llm._internal.serve.core.configs.openai_api_models import (
@@ -740,4 +743,22 @@ def get_deployment_options(cls, llm_config: "LLMConfig"):
740743
}
741744
deployment_options["ray_actor_options"] = ray_actor_options
742745

746+
if (
747+
llm_config.accelerator_config is not None
748+
and isinstance(llm_config.accelerator_config, TPUConfig)
749+
and llm_config.accelerator_config.topology
750+
):
751+
if not llm_config.accelerator_type:
752+
raise ValueError(
753+
"llm_config.accelerator_type must be specified when "
754+
"accelerator_config is a TPUConfig."
755+
)
756+
version = get_tpu_version_from_type(llm_config.accelerator_type)
757+
758+
deployment_options["accelerator_config"] = TPUAcceleratorConfig(
759+
topology=llm_config.accelerator_config.topology,
760+
accelerator_version=version,
761+
num_slices=1,
762+
)
763+
743764
return deployment_options

python/ray/llm/tests/serve/cpu/deployments/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def ray_tpu_cluster():
2020
"""
2121
Simulates a Ray cluster with a multi-host TPU v6e-16 slice (4x4 topology).
2222
"""
23+
if ray.is_initialized():
24+
ray.shutdown()
25+
2326
pod_type = "v6e-16"
2427
topology = "4x4"
2528

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py

Lines changed: 107 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,28 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster):
3030
)
3131

3232
engine_config = llm_config.get_engine_config()
33-
pg = engine_config.get_or_create_pg()
34-
35-
assert isinstance(pg, PlacementGroup)
33+
pg = None
34+
try:
35+
pg = engine_config.get_or_create_pg()
3636

37-
pg_table = placement_group_table(pg)
38-
assert pg_table["strategy"] == "PACK"
37+
assert isinstance(pg, PlacementGroup)
3938

40-
# 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
41-
assert len(pg_table["bundles"]) == 16
42-
for bundle in pg_table["bundles"].values():
43-
assert "TPU" in bundle
44-
assert bundle["TPU"] == 1
39+
pg_table = placement_group_table(pg)
40+
assert pg_table["strategy"] == "PACK"
4541

46-
# Let the backend tear down its own resources if it has any
47-
engine_config.accelerator.shutdown()
48-
try:
49-
ray.util.remove_placement_group(pg)
50-
except Exception:
51-
pass # Already cleaned up by the wrapper
42+
# 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
43+
assert len(pg_table["bundles"]) == 16
44+
for bundle in pg_table["bundles"].values():
45+
assert "TPU" in bundle
46+
assert bundle["TPU"] == 1
47+
finally:
48+
# Let the backend tear down its own resources if it has any
49+
engine_config.accelerator.shutdown()
50+
if pg is not None:
51+
try:
52+
ray.util.remove_placement_group(pg)
53+
except Exception:
54+
pass
5255

5356

5457
def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
@@ -67,24 +70,27 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
6770
)
6871

6972
engine_config = llm_config.get_engine_config()
70-
pg = engine_config.get_or_create_pg()
71-
72-
assert isinstance(pg, PlacementGroup)
73-
74-
pg_table = placement_group_table(pg)
75-
assert pg_table["strategy"] == "STRICT_SPREAD"
76-
# We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
77-
assert len(pg_table["bundles"]) == 4
78-
for bundle in pg_table["bundles"].values():
79-
assert "TPU" in bundle
80-
assert bundle["TPU"] == 4
81-
82-
# Let the backend tear down its own resources if it has any
83-
engine_config.accelerator.shutdown()
73+
pg = None
8474
try:
85-
ray.util.remove_placement_group(pg)
86-
except Exception:
87-
pass # Already cleaned up by the wrapper
75+
pg = engine_config.get_or_create_pg()
76+
77+
assert isinstance(pg, PlacementGroup)
78+
79+
pg_table = placement_group_table(pg)
80+
assert pg_table["strategy"] == "STRICT_SPREAD"
81+
# We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
82+
assert len(pg_table["bundles"]) == 4
83+
for bundle in pg_table["bundles"].values():
84+
assert "TPU" in bundle
85+
assert bundle["TPU"] == 4
86+
finally:
87+
# Let the backend tear down its own resources if it has any
88+
engine_config.accelerator.shutdown()
89+
if pg is not None:
90+
try:
91+
ray.util.remove_placement_group(pg)
92+
except Exception:
93+
pass
8894

8995

9096
def test_single_tpu_fallback(ray_tpu_cluster):
@@ -98,20 +104,23 @@ def test_single_tpu_fallback(ray_tpu_cluster):
98104
)
99105

100106
engine_config = llm_config.get_engine_config()
101-
pg = engine_config.get_or_create_pg()
107+
pg = None
108+
try:
109+
pg = engine_config.get_or_create_pg()
102110

103-
pg_table = placement_group_table(pg)
111+
pg_table = placement_group_table(pg)
104112

105-
# Verify it falls back to the default PACK strategy for 1 GPU/TPU
106-
assert len(pg_table["bundles"]) == 1
107-
assert pg_table["strategy"] == "PACK"
108-
109-
# Let the backend tear down its own resources if it has any
110-
engine_config.accelerator.shutdown()
111-
try:
112-
ray.util.remove_placement_group(pg)
113-
except Exception:
114-
pass # Already cleaned up by the wrapper
113+
# Verify it falls back to the default PACK strategy for 1 GPU/TPU
114+
assert len(pg_table["bundles"]) == 1
115+
assert pg_table["strategy"] == "PACK"
116+
finally:
117+
# Let the backend tear down its own resources if it has any
118+
engine_config.accelerator.shutdown()
119+
if pg is not None:
120+
try:
121+
ray.util.remove_placement_group(pg)
122+
except Exception:
123+
pass
115124

116125

117126
def test_tpu_slice_placement_group_creation_bundle_per_worker(ray_tpu_cluster):
@@ -233,35 +242,36 @@ def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster):
233242
)
234243

235244
app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine)
236-
serve.run(app)
237-
238-
pg_table = ray.util.placement_group_table()
239-
active_pgs = list(
240-
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
241-
)
245+
try:
246+
serve.run(app)
242247

243-
assert (
244-
len(active_pgs) == 2
245-
), "Expected 2 PGs - one for TPU Head, one for worker bundles"
248+
pg_table = ray.util.placement_group_table()
249+
active_pgs = list(
250+
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
251+
)
246252

247-
tpu_head_resource = "TPU-v6e-16-head"
248-
head_pgs = [
249-
pg
250-
for pg in active_pgs
251-
if len(pg["bundles"]) == 1
252-
and tpu_head_resource in list(pg["bundles"].values())[0]
253-
]
254-
assert len(head_pgs) == 1
253+
assert (
254+
len(active_pgs) == 2
255+
), "Expected 2 PGs - one for TPU Head, one for worker bundles"
255256

256-
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
257+
tpu_head_resource = "TPU-v6e-16-head"
258+
head_pgs = [
259+
pg
260+
for pg in active_pgs
261+
if len(pg["bundles"]) == 1
262+
and tpu_head_resource in list(pg["bundles"].values())[0]
263+
]
264+
assert len(head_pgs) == 1
257265

258-
assert worker_pg["strategy"] == "PACK"
259-
# 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU.
260-
assert len(worker_pg["bundles"]) == 16
261-
for bundle in worker_pg["bundles"].values():
262-
assert bundle.get("TPU", 0) == 1
266+
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
263267

264-
serve.shutdown()
268+
assert worker_pg["strategy"] == "PACK"
269+
# 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU.
270+
assert len(worker_pg["bundles"]) == 16
271+
for bundle in worker_pg["bundles"].values():
272+
assert bundle.get("TPU", 0) == 1
273+
finally:
274+
serve.shutdown()
265275

266276

267277
def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster):
@@ -277,35 +287,36 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster):
277287
)
278288

279289
app = serve.deployment(LLMServer).bind(llm_config, engine_cls=PGCreationMockEngine)
280-
serve.run(app)
281-
282-
pg_table = ray.util.placement_group_table()
283-
active_pgs = list(
284-
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
285-
)
286-
287-
assert (
288-
len(active_pgs) == 2
289-
), "Expected 2 PGs - one for TPU Head, one for worker bundles"
290-
291-
tpu_head_resource = "TPU-v6e-16-head"
292-
head_pgs = [
293-
pg
294-
for pg in active_pgs
295-
if len(pg["bundles"]) == 1
296-
and tpu_head_resource in list(pg["bundles"].values())[0]
297-
]
298-
assert len(head_pgs) == 1
299-
300-
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
290+
try:
291+
serve.run(app)
301292

302-
assert worker_pg["strategy"] == "PACK"
303-
# 4x4 topology = 16 chips. With 4 TPUs per bundle, expect exactly 4 bundles.
304-
assert len(worker_pg["bundles"]) == 4
305-
for bundle in worker_pg["bundles"].values():
306-
assert bundle.get("TPU", 0) == 4
293+
pg_table = ray.util.placement_group_table()
294+
active_pgs = list(
295+
{k: v for k, v in pg_table.items() if v["state"] == "CREATED"}.values()
296+
)
307297

308-
serve.shutdown()
298+
assert (
299+
len(active_pgs) == 2
300+
), "Expected 2 PGs - one for TPU Head, one for worker bundles"
301+
302+
tpu_head_resource = "TPU-v6e-16-head"
303+
head_pgs = [
304+
pg
305+
for pg in active_pgs
306+
if len(pg["bundles"]) == 1
307+
and tpu_head_resource in list(pg["bundles"].values())[0]
308+
]
309+
assert len(head_pgs) == 1
310+
311+
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
312+
313+
assert worker_pg["strategy"] == "PACK"
314+
# 4x4 topology = 16 chips. With 4 TPUs per bundle, expect exactly 4 bundles.
315+
assert len(worker_pg["bundles"]) == 4
316+
for bundle in worker_pg["bundles"].values():
317+
assert bundle.get("TPU", 0) == 4
318+
finally:
319+
serve.shutdown()
309320

310321

311322
if __name__ == "__main__":

python/ray/serve/tests/test_accelerator_config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def mock_tpu_cluster():
6060

6161
def test_tpu_accelerator_config_integration(mock_tpu_cluster):
6262
"""Test that AcceleratorConfig correctly creates SlicePlacementGroup in a mock cluster."""
63-
6463
tpu_config = TPUAcceleratorConfig(topology="4x4", accelerator_version="v6e")
6564

6665
request = CreatePlacementGroupRequest(
@@ -89,7 +88,6 @@ def test_tpu_accelerator_config_integration(mock_tpu_cluster):
8988
replica_pg.shutdown()
9089
assert replica_pg._slice_pg is None
9190

92-
9391
def test_tpu_accelerator_config_partial_failure_cleanup(mock_tpu_cluster):
9492
"""Test that SlicePlacementGroup cleans up head PGs if a multi-slice reservation fails."""
9593

0 commit comments

Comments
 (0)