@@ -160,8 +160,10 @@ def get_tpu_worker_resources(
160160 """
161161 accelerator_version = get_tpu_version_from_type (accelerator_type )
162162
163- resolved_chips_per_vm = chips_per_vm or get_chips_per_host (
164- topology , accelerator_version
163+ resolved_chips_per_vm = (
164+ chips_per_vm
165+ if chips_per_vm is not None
166+ else get_chips_per_host (topology , accelerator_version )
165167 )
166168 total_chips_per_slice = get_num_chips_from_topology (topology )
167169
@@ -447,6 +449,8 @@ class SlicePlacementGroup:
447449 TPU head placement group to become ready. Defaults to
448450 ``DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S``. Pass ``None`` to wait
449451 indefinitely.
452+ bundle_label_selector: Optional list of label selectors, one per bundle. These label
453+ selectors are merged with the dynamic TPU slice name labels, which take precedence on key conflicts.
450454
451455 Examples:
452456
@@ -490,7 +494,13 @@ def __init__(
490494 head_reservation_timeout_s : Optional [float ] = (
491495 DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S
492496 ),
497+ bundle_label_selector : Optional [List [Dict [str , str ]]] = None ,
493498 ):
499+ self ._head_pgs : List [PlacementGroup ] = []
500+ self ._bundle_label_selector : List [Dict [str , str ]] = []
501+ self ._placement_group : Optional [PlacementGroup ] = None
502+ self ._user_bundle_label_selector = bundle_label_selector or []
503+
494504 self ._topology = topology .strip ().lower ()
495505 self ._accelerator_version = get_tpu_version_from_type (
496506 accelerator_version .strip ()
@@ -508,8 +518,10 @@ def __init__(
508518 chips_per_vm = chips_per_vm ,
509519 )
510520
511- self ._chips_per_host = chips_per_vm or get_chips_per_host (
512- self ._topology , self ._accelerator_version
521+ self ._chips_per_host = (
522+ chips_per_vm
523+ if chips_per_vm is not None
524+ else get_chips_per_host (self ._topology , self ._accelerator_version )
513525 )
514526
515527 # Within Ray, a "host" corresponds to a user-visible compute VM.
@@ -518,10 +530,7 @@ def __init__(
518530 hosts_per_slice = max (1 , total_chips // self ._chips_per_host )
519531 self ._num_hosts = hosts_per_slice * self ._num_slices
520532
521- self ._head_pgs : List [PlacementGroup ] = []
522- self ._bundle_label_selector : List [Dict [str , str ]] = []
523533 self ._validate_tpu_config ()
524- self ._placement_group = None
525534
526535 # Reserve a TPU slice of the provided accelerator version and topology.
527536 self ._placement_group = self ._reserve_slice (
@@ -549,6 +558,15 @@ def _reserve_slice(
549558 lifetime : Optional [str ] = None ,
550559 ) -> PlacementGroup :
551560 """Performs the two-step scheduling to reserve a TPU slice."""
561+ if (
562+ self ._user_bundle_label_selector
563+ and len (self ._user_bundle_label_selector ) != self ._num_bundles
564+ ):
565+ raise ValueError (
566+ f"bundle_label_selector length ({ len (self ._user_bundle_label_selector )} ) must "
567+ f"match the number of bundles ({ self ._num_bundles } )."
568+ )
569+
552570 self ._bundle_label_selector = []
553571 bundles = []
554572 bundles_per_slice = self ._num_bundles // self ._num_slices
@@ -557,7 +575,7 @@ def _reserve_slice(
557575 accelerator_type = "TPU-" + self .accelerator_version .upper ()
558576
559577 try :
560- for _ in range (self .num_slices ):
578+ for slice_idx in range (self .num_slices ):
561579 reservation = reserve_tpu_slice (
562580 self ._topology ,
563581 accelerator_type ,
@@ -575,10 +593,20 @@ def _reserve_slice(
575593 slice_name , head_pg = reservation
576594 self ._head_pgs .append (head_pg )
577595
578- # Reserving a slice is done through constructing num_hosts bundles, each with a label selector for
579- # the unique name of an available TPU slice.
580- selector = {ray ._raylet .RAY_NODE_TPU_SLICE_NAME_KEY : slice_name }
581- self ._bundle_label_selector .extend ([selector ] * bundles_per_slice )
596+ dynamic_labels = {ray ._raylet .RAY_NODE_TPU_SLICE_NAME_KEY : slice_name }
597+
598+ for bundle_idx in range (bundles_per_slice ):
599+ global_bundle_idx = slice_idx * bundles_per_slice + bundle_idx
600+
601+ user_labels = (
602+ self ._user_bundle_label_selector [global_bundle_idx ]
603+ if global_bundle_idx < len (self ._user_bundle_label_selector )
604+ else {}
605+ )
606+ # Dynamic TPU slice labels take precedence; user labels fill in the rest.
607+ merged_labels = {** user_labels , ** dynamic_labels }
608+ self ._bundle_label_selector .append (merged_labels )
609+
582610 bundles += [
583611 self ._bundle_resources .copy () for _ in range (bundles_per_slice )
584612 ]
@@ -647,14 +675,47 @@ def bundle_resources(self) -> Dict[str, float]:
647675 """The resources that are assigned to each bundle."""
648676 return self ._bundle_resources
649677
678+ @DeveloperAPI (stability = "alpha" )
679+ def release_head_pgs (self ) -> None :
680+ """Remove all internal head placement groups.
681+
682+ The head PGs exist only to atomically claim a TPU slice's label during
683+ the race window between slice selection and worker-PG construction.
684+ Once the worker PG's bundles are scheduled, the worker PG holds the TPU
685+ resources on every host in the slice and the head PGs are redundant.
686+
687+ This method is idempotent. Callers should invoke it after
688+ `self.placement_group.ready()` resolves successfully.
689+ """
690+ head_pgs = getattr (self , "_head_pgs" , [])
691+ self ._head_pgs = []
692+ for head_pg in head_pgs :
693+ try :
694+ remove_placement_group (head_pg )
695+ except Exception :
696+ logger .exception (
697+ "Failed to remove TPU head placement group %s; the "
698+ "slice reservation marker may leak until the creator "
699+ "process exits." ,
700+ getattr (head_pg , "id" , head_pg ),
701+ )
702+
650703 def shutdown (self ):
651- """Removes the worker placement group and all internal head PGs."""
652- if self ._placement_group :
653- remove_placement_group (self ._placement_group )
704+ """Remove the worker placement group and all internal head PGs.
705+
706+ Idempotent. Safe to call on a partially-constructed instance.
707+ """
708+ worker_pg = getattr (self , "_placement_group" , None )
709+ if worker_pg is not None :
654710 self ._placement_group = None
655- for head_pg in self ._head_pgs :
656- remove_placement_group (head_pg )
657- self ._head_pgs = []
711+ try :
712+ remove_placement_group (worker_pg )
713+ except Exception :
714+ logger .exception (
715+ "Failed to remove TPU worker placement group %s." ,
716+ getattr (worker_pg , "id" , worker_pg ),
717+ )
718+ self .release_head_pgs ()
658719
659720
660721@PublicAPI (stability = "alpha" )
0 commit comments