@@ -420,13 +420,13 @@ def test_requires_deferred_placement_group(self):
420420 @pytest .mark .parametrize (
421421 "topology,num_devices,accelerator_type_str,expected_bundles_count,expected_chips_per_host" ,
422422 [
423- ("1x1" , 1 , "v6e" , 1 , 1.0 ),
424- ("1x1" , 1 , "v7x" , 1 , 4.0 ),
425- ("4x4" , 16 , "v6e" , 4 , 4.0 ),
426- ("2x2x2" , 8 , "v5p" , 2 , 4.0 ),
427- ("2x2" , 4 , "v5litepod" , 1 , 4.0 ),
428- ("2x2x1" , 4 , "v4" , 1 , 4.0 ),
429- ("2x4" , 8 , "v6e" , 1 , 8.0 ),
423+ ("1x1" , 1 , "v6e" , 1 , 1 ),
424+ ("1x1" , 1 , "v7x" , 1 , 4 ),
425+ ("4x4" , 16 , "v6e" , 4 , 4 ),
426+ ("2x2x2" , 8 , "v5p" , 2 , 4 ),
427+ ("2x2" , 4 , "v5litepod" , 1 , 4 ),
428+ ("2x2x1" , 4 , "v4" , 1 , 4 ),
429+ ("2x4" , 8 , "v6e" , 1 , 8 ),
430430 ],
431431 )
432432 def test_default_bundles_topology (
@@ -457,6 +457,40 @@ def test_default_bundles_topology_missing_accelerator_type_raises(self):
457457 ):
458458 tpu_accel .default_bundles (num_devices = 16 , accelerator_type_str = None )
459459
460+ def test_default_bundles_v6e_4x4 (self ):
461+ """Test that v6e 4x4 topology returns per-host bundles."""
462+ tpu_accel = TPUAccelerator (TPUConfig (kind = "tpu" , topology = "4x4" ))
463+ bundles = tpu_accel .default_bundles (num_devices = 16 , accelerator_type_str = "v6e" )
464+
465+ # 4x4 v6e = 16 chips. 4 chips per host -> 4 hosts.
466+ assert len (bundles ) == 4
467+ for bundle in bundles :
468+ assert bundle ["TPU" ] == 4.0
469+ assert "accelerator_type:v6e" in bundle
470+
471+ def test_default_bundles_v5p_2x2x2 (self ):
472+ """Test that v5p 2x2x2 topology returns per-host bundles."""
473+ tpu_accel = TPUAccelerator (TPUConfig (kind = "tpu" , topology = "2x2x2" ))
474+ bundles = tpu_accel .default_bundles (num_devices = 8 , accelerator_type_str = "v5p" )
475+
476+ # 2x2x2 v5p = 8 chips. 4 chips per host -> 2 hosts.
477+ assert len (bundles ) == 2
478+ for bundle in bundles :
479+ assert bundle ["TPU" ] == 4.0
480+ assert "accelerator_type:v5p" in bundle
481+
482+ def test_default_bundles_single_host_topology (self ):
483+ """Test that a single-host topology returns a single bundle."""
484+ tpu_accel = TPUAccelerator (TPUConfig (kind = "tpu" , topology = "2x2" ))
485+ bundles = tpu_accel .default_bundles (
486+ num_devices = 4 , accelerator_type_str = "v5litepod"
487+ )
488+
489+ # 2x2 v5litepod = 4 chips on 1 host.
490+ assert len (bundles ) == 1
491+ assert bundles [0 ]["TPU" ] == 4.0
492+ assert "accelerator_type:v5litepod" in bundles [0 ]
493+
460494
461495if __name__ == "__main__" :
462496 sys .exit (pytest .main (["-v" , __file__ ]))
0 commit comments