Skip to content

Commit 35059d2

Browse files
committed
Update Ray LLM to use TPUAcceleratorConfig
1 parent 9d52483 commit 35059d2

2 files changed

Lines changed: 26 additions & 7 deletions

File tree

python/ray/llm/_internal/serve/core/server/llm_server.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
MODEL_RESPONSE_BATCH_TIMEOUT_MS,
2323
RAYLLM_VLLM_ENGINE_CLS_ENV,
2424
)
25+
from ray.llm._internal.serve.core.configs.accelerators import TPUConfig
2526
from ray.llm._internal.serve.core.configs.llm_config import (
2627
DiskMultiplexConfig,
2728
LLMConfig,
@@ -39,6 +40,8 @@
3940
from ray.llm._internal.serve.utils.server_utils import (
4041
get_serve_request_id,
4142
)
43+
from ray.serve.config import TPUAcceleratorConfig
44+
from ray.util.tpu import get_tpu_version_from_type
4245

4346
if TYPE_CHECKING:
4447
from ray.llm._internal.serve.core.configs.openai_api_models import (
@@ -737,4 +740,20 @@ def get_deployment_options(cls, llm_config: "LLMConfig"):
737740
}
738741
deployment_options["ray_actor_options"] = ray_actor_options
739742

743+
if llm_config.accelerator_config is not None and isinstance(
744+
llm_config.accelerator_config, TPUConfig
745+
):
746+
if not llm_config.accelerator_type:
747+
raise ValueError(
748+
"llm_config.accelerator_type must be specified when "
749+
"accelerator_config is a TPUConfig."
750+
)
751+
version = get_tpu_version_from_type(llm_config.accelerator_type)
752+
753+
deployment_options["accelerator_config"] = TPUAcceleratorConfig(
754+
topology=llm_config.accelerator_config.topology,
755+
accelerator_version=version,
756+
num_slices=1,
757+
)
758+
740759
return deployment_options

python/ray/llm/tests/serve/cpu/deployments/llm/test_llm_engine_tpu.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster):
3737
pg_table = placement_group_table(pg)
3838
assert pg_table["strategy"] == "PACK"
3939

40-
# 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
41-
assert len(pg_table["bundles"]) == 16
40+
# 4x4 v6e = 16 chips. We default to one bundle per host (4 TPU chips each), i.e. 16 / 4 = 4 bundles.
41+
assert len(pg_table["bundles"]) == 4
4242
for bundle in pg_table["bundles"].values():
4343
assert "TPU" in bundle
44-
assert bundle["TPU"] == 1
44+
assert bundle["TPU"] == 4.0
4545

4646
# Let the backend tear down its own resources if it has any
4747
engine_config.accelerator.shutdown()
@@ -62,7 +62,7 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
6262
accelerator_config={"kind": "tpu", "topology": "4x4"},
6363
placement_group_config={
6464
"strategy": "STRICT_SPREAD",
65-
"bundles": [{"TPU": 4}],
65+
"bundles": [{"TPU": 4}] * 4,
6666
},
6767
)
6868

@@ -256,10 +256,10 @@ def test_tpu_serve_deployment_default_chip_level_bundles(ray_tpu_cluster):
256256
worker_pg = [pg for pg in active_pgs if pg not in head_pgs][0]
257257

258258
assert worker_pg["strategy"] == "PACK"
259-
# 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU.
260-
assert len(worker_pg["bundles"]) == 16
259+
# 4x4 topology = 16 chips. Default is 4 bundles of 4 TPUs (per-host).
260+
assert len(worker_pg["bundles"]) == 4
261261
for bundle in worker_pg["bundles"].values():
262-
assert bundle.get("TPU", 0) == 1
262+
assert bundle.get("TPU", 0) == 4.0
263263

264264
serve.shutdown()
265265

0 commit comments

Comments
 (0)