11import sys
2+ import time
23
34import pytest
45
@@ -30,25 +31,28 @@ def test_tpu_slice_placement_group_creation_default_resources(ray_tpu_cluster):
3031 )
3132
3233 engine_config = llm_config .get_engine_config ()
33- pg = engine_config . get_or_create_pg ()
34-
35- assert isinstance ( pg , PlacementGroup )
34+ pg = None
35+ try :
36+ pg = engine_config . get_or_create_pg ( )
3637
37- pg_table = placement_group_table (pg )
38- assert pg_table ["strategy" ] == "PACK"
38+ assert isinstance (pg , PlacementGroup )
3939
40- # 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
41- assert len (pg_table ["bundles" ]) == 16
42- for bundle in pg_table ["bundles" ].values ():
43- assert "TPU" in bundle
44- assert bundle ["TPU" ] == 1
40+ pg_table = placement_group_table (pg )
41+ assert pg_table ["strategy" ] == "PACK"
4542
46- # Let the backend tear down its own resources if it has any
47- engine_config .accelerator .shutdown ()
48- try :
49- ray .util .remove_placement_group (pg )
50- except Exception :
51- pass # Already cleaned up by the wrapper
43+ # 4x4 v6e = 16 chips. We default to 1 TPU chip per bundle.
44+ assert len (pg_table ["bundles" ]) == 16
45+ for bundle in pg_table ["bundles" ].values ():
46+ assert "TPU" in bundle
47+ assert bundle ["TPU" ] == 1
48+ finally :
49+ # Let the backend tear down its own resources if it has any
50+ engine_config .accelerator .shutdown ()
51+ if pg is not None :
52+ try :
53+ ray .util .remove_placement_group (pg )
54+ except Exception :
55+ pass
5256
5357
5458def test_tpu_slice_placement_group_creation_host_resources (ray_tpu_cluster ):
@@ -67,24 +71,27 @@ def test_tpu_slice_placement_group_creation_host_resources(ray_tpu_cluster):
6771 )
6872
6973 engine_config = llm_config .get_engine_config ()
70- pg = engine_config .get_or_create_pg ()
71-
72- assert isinstance (pg , PlacementGroup )
73-
74- pg_table = placement_group_table (pg )
75- assert pg_table ["strategy" ] == "STRICT_SPREAD"
76- # We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
77- assert len (pg_table ["bundles" ]) == 4
78- for bundle in pg_table ["bundles" ].values ():
79- assert "TPU" in bundle
80- assert bundle ["TPU" ] == 4
81-
82- # Let the backend tear down its own resources if it has any
83- engine_config .accelerator .shutdown ()
74+ pg = None
8475 try :
85- ray .util .remove_placement_group (pg )
86- except Exception :
87- pass # Already cleaned up by the wrapper
76+ pg = engine_config .get_or_create_pg ()
77+
78+ assert isinstance (pg , PlacementGroup )
79+
80+ pg_table = placement_group_table (pg )
81+ assert pg_table ["strategy" ] == "STRICT_SPREAD"
82+ # We should provision 4 host-level bundles instead of the default 16 chip-level bundles.
83+ assert len (pg_table ["bundles" ]) == 4
84+ for bundle in pg_table ["bundles" ].values ():
85+ assert "TPU" in bundle
86+ assert bundle ["TPU" ] == 4
87+ finally :
88+ # Let the backend tear down its own resources if it has any
89+ engine_config .accelerator .shutdown ()
90+ if pg is not None :
91+ try :
92+ ray .util .remove_placement_group (pg )
93+ except Exception :
94+ pass
8895
8996
9097def test_single_tpu_fallback (ray_tpu_cluster ):
@@ -98,20 +105,23 @@ def test_single_tpu_fallback(ray_tpu_cluster):
98105 )
99106
100107 engine_config = llm_config .get_engine_config ()
101- pg = engine_config . get_or_create_pg ()
102-
103- pg_table = placement_group_table ( pg )
108+ pg = None
109+ try :
110+ pg = engine_config . get_or_create_pg ( )
104111
105- # Verify it falls back to the default PACK strategy for 1 GPU/TPU
106- assert len (pg_table ["bundles" ]) == 1
107- assert pg_table ["strategy" ] == "PACK"
112+ pg_table = placement_group_table (pg )
108113
109- # Let the backend tear down its own resources if it has any
110- engine_config .accelerator .shutdown ()
111- try :
112- ray .util .remove_placement_group (pg )
113- except Exception :
114- pass # Already cleaned up by the wrapper
114+ # Verify it falls back to the default PACK strategy for 1 GPU/TPU
115+ assert len (pg_table ["bundles" ]) == 1
116+ assert pg_table ["strategy" ] == "PACK"
117+ finally :
118+ # Let the backend tear down its own resources if it has any
119+ engine_config .accelerator .shutdown ()
120+ if pg is not None :
121+ try :
122+ ray .util .remove_placement_group (pg )
123+ except Exception :
124+ pass
115125
116126
117127def test_tpu_slice_placement_group_creation_bundle_per_worker (ray_tpu_cluster ):
@@ -221,47 +231,49 @@ def test_tpu_slice_placement_group_creation_cpu_driver_homogeneous_tpu_bundles_p
221231 pass
222232
223233
224- def test_tpu_serve_deployment_default_chip_level_bundles (ray_tpu_cluster ):
234+ def test_tpu_serve_deployment_default_host_level_bundles (ray_tpu_cluster ):
225235 """
226236 Verifies that a Serve deployment created for a multi-host TPU slice defaults
227- to chip -level bundles when no placement_group_config is specified.
237+ to host -level bundles when no placement_group_config is specified.
228238 """
229239 llm_config = LLMConfig (
230240 model_loading_config = ModelLoadingConfig (model_id = "test-tpu-model" ),
231241 accelerator_type = "TPU-V6E" ,
232242 accelerator_config = {"kind" : "tpu" , "topology" : "4x4" },
233243 )
234244
235- app = serve .deployment (LLMServer ).bind (llm_config , engine_cls = PGCreationMockEngine )
236- serve .run (app )
237-
238- pg_table = ray .util .placement_group_table ()
239- active_pgs = list (
240- {k : v for k , v in pg_table .items () if v ["state" ] == "CREATED" }.values ()
245+ serve_options = LLMServer .get_deployment_options (llm_config )
246+ app = serve .deployment (** serve_options )(LLMServer ).bind (
247+ llm_config , engine_cls = PGCreationMockEngine
241248 )
242-
243- assert (
244- len (active_pgs ) == 2
245- ), "Expected 2 PGs - one for TPU Head, one for worker bundles"
246-
247- tpu_head_resource = "TPU-v6e-16-head"
248- head_pgs = [
249- pg
250- for pg in active_pgs
251- if len (pg ["bundles" ]) == 1
252- and tpu_head_resource in list (pg ["bundles" ].values ())[0 ]
253- ]
254- assert len (head_pgs ) == 1
255-
256- worker_pg = [pg for pg in active_pgs if pg not in head_pgs ][0 ]
257-
258- assert worker_pg ["strategy" ] == "PACK"
259- # 4x4 topology = 16 chips. Default is 16 bundles of 1 TPU.
260- assert len (worker_pg ["bundles" ]) == 16
261- for bundle in worker_pg ["bundles" ].values ():
262- assert bundle .get ("TPU" , 0 ) == 1
263-
264- serve .shutdown ()
249+ try :
250+ serve .run (app )
251+
252+ # Wait for the head PG to be removed (eventual consistency).
253+ start_time = time .time ()
254+ timeout = 10
255+ while time .time () - start_time < timeout :
256+ pg_table = ray .util .placement_group_table ()
257+ active_pgs = list (
258+ {k : v for k , v in pg_table .items () if v ["state" ] == "CREATED" }.values ()
259+ )
260+ if len (active_pgs ) == 1 :
261+ break
262+ time .sleep (0.5 )
263+
264+ assert (
265+ len (active_pgs ) == 1
266+ ), f"Expected exactly 1 active PG (the worker PG), but found { len (active_pgs )} . Head PG may not have been removed."
267+
268+ worker_pg = active_pgs [0 ]
269+
270+ assert worker_pg ["strategy" ] == "PACK"
271+ # 4x4 topology = 16 chips. Default is host-level bundles (4 bundles of 4 TPUs).
272+ assert len (worker_pg ["bundles" ]) == 4
273+ for bundle in worker_pg ["bundles" ].values ():
274+ assert bundle .get ("TPU" , 0 ) == 4
275+ finally :
276+ serve .shutdown ()
265277
266278
267279def test_tpu_serve_deployment_explicit_host_level_bundles (ray_tpu_cluster ):
@@ -276,36 +288,38 @@ def test_tpu_serve_deployment_explicit_host_level_bundles(ray_tpu_cluster):
276288 placement_group_config = {"bundle_per_worker" : {"TPU" : 4 }},
277289 )
278290
279- app = serve .deployment (LLMServer ).bind (llm_config , engine_cls = PGCreationMockEngine )
280- serve .run (app )
281-
282- pg_table = ray .util .placement_group_table ()
283- active_pgs = list (
284- {k : v for k , v in pg_table .items () if v ["state" ] == "CREATED" }.values ()
291+ serve_options = LLMServer .get_deployment_options (llm_config )
292+ app = serve .deployment (** serve_options )(LLMServer ).bind (
293+ llm_config , engine_cls = PGCreationMockEngine
285294 )
286-
287- assert (
288- len (active_pgs ) == 2
289- ), "Expected 2 PGs - one for TPU Head, one for worker bundles"
290-
291- tpu_head_resource = "TPU-v6e-16-head"
292- head_pgs = [
293- pg
294- for pg in active_pgs
295- if len (pg ["bundles" ]) == 1
296- and tpu_head_resource in list (pg ["bundles" ].values ())[0 ]
297- ]
298- assert len (head_pgs ) == 1
299-
300- worker_pg = [pg for pg in active_pgs if pg not in head_pgs ][0 ]
301-
302- assert worker_pg ["strategy" ] == "PACK"
303- # 4x4 topology = 16 chips. With 4 TPUs per bundle, expect exactly 4 bundles.
304- assert len (worker_pg ["bundles" ]) == 4
305- for bundle in worker_pg ["bundles" ].values ():
306- assert bundle .get ("TPU" , 0 ) == 4
307-
308- serve .shutdown ()
295+ try :
296+ serve .run (app )
297+
298+ # Wait for the head PG to be removed (eventual consistency).
299+ start_time = time .time ()
300+ timeout = 10
301+ while time .time () - start_time < timeout :
302+ pg_table = ray .util .placement_group_table ()
303+ active_pgs = list (
304+ {k : v for k , v in pg_table .items () if v ["state" ] == "CREATED" }.values ()
305+ )
306+ if len (active_pgs ) == 1 :
307+ break
308+ time .sleep (0.5 )
309+
310+ assert (
311+ len (active_pgs ) == 1
312+ ), f"Expected exactly 1 active PG (the worker PG), but found { len (active_pgs )} . Head PG may not have been removed."
313+
314+ worker_pg = active_pgs [0 ]
315+
316+ assert worker_pg ["strategy" ] == "PACK"
317+ # 4x4 topology = 16 chips. With 4 TPUs per bundle, expect exactly 4 bundles.
318+ assert len (worker_pg ["bundles" ]) == 4
319+ for bundle in worker_pg ["bundles" ].values ():
320+ assert bundle .get ("TPU" , 0 ) == 4
321+ finally :
322+ serve .shutdown ()
309323
310324
311325if __name__ == "__main__" :
0 commit comments