Skip to content

Commit 6a1511d

Browse files
committed
Improve lifecycle handling of SlicePlacementGroup and support explicit bundle_label_selector
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
1 parent f98d5af commit 6a1511d

2 files changed

Lines changed: 188 additions & 18 deletions

File tree

python/ray/tests/test_tpu.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,5 +839,114 @@ def test_slice_placement_group_chips_per_vm_override(ray_v6e_tpu_cluster):
839839
assert override_pg.bundle_resources["TPU"] == 4
840840

841841

842+
def test_user_bundle_label_selector_merged(ray_tpu_cluster):
843+
"""Verifies that user-passed bundle_label_selector is merged with dynamic TPU labels."""
844+
user_selectors = [{"env": "prod"}, {"env": "test"}]
845+
846+
# 2x2x2 v4 = 2 hosts = 2 bundles
847+
slice_pg = SlicePlacementGroup(
848+
topology="2x2x2", accelerator_version="v4", bundle_label_selector=user_selectors
849+
)
850+
851+
assert len(slice_pg._bundle_label_selector) == 2
852+
853+
# Verify slice 0
854+
assert slice_pg._bundle_label_selector[0]["env"] == "prod"
855+
assert ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY in slice_pg._bundle_label_selector[0]
856+
857+
# Verify slice 1
858+
assert slice_pg._bundle_label_selector[1]["env"] == "test"
859+
assert ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY in slice_pg._bundle_label_selector[1]
860+
861+
862+
def test_user_bundle_label_selector_collision_dynamic_wins(ray_v6e_tpu_cluster):
863+
"""Verifies that dynamic TPU labels take precedence on collision."""
864+
user_selectors = [{ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: "user-requested-slice"}]
865+
866+
# v6e-8 is single host (1 bundle)
867+
slice_pg = SlicePlacementGroup(
868+
topology="2x4", accelerator_version="v6e", bundle_label_selector=user_selectors
869+
)
870+
871+
assert len(slice_pg._bundle_label_selector) == 1
872+
# The dynamic value should win (it generates test-v6e-slice-N)
873+
actual_val = slice_pg._bundle_label_selector[0][
874+
ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY
875+
]
876+
assert actual_val != "user-requested-slice"
877+
assert "test-v6e-slice-" in actual_val
878+
879+
880+
def test_user_bundle_label_selector_length_mismatch_raises():
881+
"""Verifies that providing wrong length of selector list raises ValueError."""
882+
user_selectors = [{"env": "prod"}] # Only 1 provided but 2x2x2 v4 has 2 hosts
883+
884+
with pytest.raises(ValueError, match="bundle_label_selector length"):
885+
SlicePlacementGroup(
886+
topology="2x2x2",
887+
accelerator_version="v4",
888+
bundle_label_selector=user_selectors,
889+
)
890+
891+
892+
def test_release_head_pgs_idempotent(ray_tpu_cluster):
893+
"""Verifies that release_head_pgs() is idempotent."""
894+
slice_pg = SlicePlacementGroup(topology="2x2x2", accelerator_version="v4")
895+
896+
assert len(slice_pg.head_placement_groups) == 1
897+
898+
slice_pg.release_head_pgs()
899+
assert len(slice_pg.head_placement_groups) == 0
900+
901+
# Call again, should not raise
902+
slice_pg.release_head_pgs()
903+
assert len(slice_pg.head_placement_groups) == 0
904+
905+
906+
def test_shutdown_idempotent(ray_tpu_cluster):
907+
"""Verifies that shutdown() is idempotent."""
908+
slice_pg = SlicePlacementGroup(topology="2x2x2", accelerator_version="v4")
909+
910+
slice_pg.shutdown()
911+
assert slice_pg.placement_group is None
912+
assert len(slice_pg.head_placement_groups) == 0
913+
914+
# Call again, should not raise
915+
slice_pg.shutdown()
916+
917+
918+
def test_shutdown_safe_after_construction_failure():
919+
"""Verifies that shutdown() is safe to call on a partially-constructed instance."""
920+
with patch(
921+
"ray.util.tpu.SlicePlacementGroup._reserve_slice",
922+
side_effect=RuntimeError("Test failure"),
923+
):
924+
with pytest.raises(RuntimeError, match="Test failure"):
925+
SlicePlacementGroup(topology="2x2x2", accelerator_version="v4")
926+
927+
# If the above didn't crash or leak resources, we are good.
928+
# We can also manually construct a partial instance and call shutdown.
929+
partial_pg = SlicePlacementGroup.__new__(SlicePlacementGroup)
930+
partial_pg._head_pgs = []
931+
partial_pg._placement_group = None
932+
933+
# Should not raise even though it's missing attributes
934+
partial_pg.shutdown()
935+
936+
937+
def test_release_head_pgs_after_ready_then_shutdown(ray_tpu_cluster):
938+
"""Validates Slice PG lifecycle: wait until ready, release head PGs, then shutdown."""
939+
slice_pg = SlicePlacementGroup(topology="2x2x2", accelerator_version="v4")
940+
941+
# Wait for ready
942+
ray.get(slice_pg.placement_group.ready())
943+
944+
slice_pg.release_head_pgs()
945+
assert len(slice_pg.head_placement_groups) == 0
946+
947+
slice_pg.shutdown()
948+
assert slice_pg.placement_group is None
949+
950+
842951
if __name__ == "__main__":
843952
sys.exit(pytest.main(["-sv", __file__]))

python/ray/util/tpu.py

Lines changed: 79 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,10 @@ def get_tpu_worker_resources(
160160
"""
161161
accelerator_version = get_tpu_version_from_type(accelerator_type)
162162

163-
resolved_chips_per_vm = chips_per_vm or get_chips_per_host(
164-
topology, accelerator_version
163+
resolved_chips_per_vm = (
164+
chips_per_vm
165+
if chips_per_vm is not None
166+
else get_chips_per_host(topology, accelerator_version)
165167
)
166168
total_chips_per_slice = get_num_chips_from_topology(topology)
167169

@@ -447,6 +449,8 @@ class SlicePlacementGroup:
447449
TPU head placement group to become ready. Defaults to
448450
``DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S``. Pass ``None`` to wait
449451
indefinitely.
452+
bundle_label_selector: Optional list of label selectors to apply per bundle. These label
453+
selectors are applied in addition to dynamic TPU slice name labels, which take precedence.
450454
451455
Examples:
452456
@@ -490,7 +494,13 @@ def __init__(
490494
head_reservation_timeout_s: Optional[float] = (
491495
DEFAULT_TPU_HEAD_RESERVATION_TIMEOUT_S
492496
),
497+
bundle_label_selector: Optional[List[Dict[str, str]]] = None,
493498
):
499+
self._head_pgs: List[PlacementGroup] = []
500+
self._bundle_label_selector: List[Dict[str, str]] = []
501+
self._placement_group: Optional[PlacementGroup] = None
502+
self._user_bundle_label_selector = bundle_label_selector or []
503+
494504
self._topology = topology.strip().lower()
495505
self._accelerator_version = get_tpu_version_from_type(
496506
accelerator_version.strip()
@@ -508,8 +518,10 @@ def __init__(
508518
chips_per_vm=chips_per_vm,
509519
)
510520

511-
self._chips_per_host = chips_per_vm or get_chips_per_host(
512-
self._topology, self._accelerator_version
521+
self._chips_per_host = (
522+
chips_per_vm
523+
if chips_per_vm is not None
524+
else get_chips_per_host(self._topology, self._accelerator_version)
513525
)
514526

515527
# Within Ray, a "host" corresponds to a user-visible compute VM.
@@ -518,10 +530,7 @@ def __init__(
518530
hosts_per_slice = max(1, total_chips // self._chips_per_host)
519531
self._num_hosts = hosts_per_slice * self._num_slices
520532

521-
self._head_pgs: List[PlacementGroup] = []
522-
self._bundle_label_selector: List[Dict[str, str]] = []
523533
self._validate_tpu_config()
524-
self._placement_group = None
525534

526535
# Reserve a TPU slice of the provided accelerator version and topology.
527536
self._placement_group = self._reserve_slice(
@@ -549,6 +558,15 @@ def _reserve_slice(
549558
lifetime: Optional[str] = None,
550559
) -> PlacementGroup:
551560
"""Performs the two-step scheduling to reserve a TPU slice."""
561+
if (
562+
self._user_bundle_label_selector
563+
and len(self._user_bundle_label_selector) != self._num_bundles
564+
):
565+
raise ValueError(
566+
f"bundle_label_selector length ({len(self._user_bundle_label_selector)}) must "
567+
f"match the number of bundles ({self._num_bundles})."
568+
)
569+
552570
self._bundle_label_selector = []
553571
bundles = []
554572
bundles_per_slice = self._num_bundles // self._num_slices
@@ -557,7 +575,7 @@ def _reserve_slice(
557575
accelerator_type = "TPU-" + self.accelerator_version.upper()
558576

559577
try:
560-
for _ in range(self.num_slices):
578+
for slice_idx in range(self.num_slices):
561579
reservation = reserve_tpu_slice(
562580
self._topology,
563581
accelerator_type,
@@ -575,10 +593,20 @@ def _reserve_slice(
575593
slice_name, head_pg = reservation
576594
self._head_pgs.append(head_pg)
577595

578-
# Reserving a slice is done through constructing num_hosts bundles, each with a label selector for
579-
# the unique name of an available TPU slice.
580-
selector = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name}
581-
self._bundle_label_selector.extend([selector] * bundles_per_slice)
596+
dynamic_labels = {ray._raylet.RAY_NODE_TPU_SLICE_NAME_KEY: slice_name}
597+
598+
for bundle_idx in range(bundles_per_slice):
599+
global_bundle_idx = slice_idx * bundles_per_slice + bundle_idx
600+
601+
user_labels = (
602+
self._user_bundle_label_selector[global_bundle_idx]
603+
if global_bundle_idx < len(self._user_bundle_label_selector)
604+
else {}
605+
)
606+
# Dynamic TPU slice labels take precedence; user labels fill in the rest.
607+
merged_labels = {**user_labels, **dynamic_labels}
608+
self._bundle_label_selector.append(merged_labels)
609+
582610
bundles += [
583611
self._bundle_resources.copy() for _ in range(bundles_per_slice)
584612
]
@@ -647,14 +675,47 @@ def bundle_resources(self) -> Dict[str, float]:
647675
"""The resources that are assigned to each bundle."""
648676
return self._bundle_resources
649677

678+
@DeveloperAPI(stability="alpha")
679+
def release_head_pgs(self) -> None:
680+
"""Remove all internal head placement groups.
681+
682+
The head PGs exist only to atomically claim a TPU slice's label during
683+
the race window between slice selection and worker-PG construction.
684+
Once the worker PG's bundles are scheduled, the worker PG holds the TPU
685+
resources on every host in the slice and the head PGs are redundant.
686+
687+
Callers should invoke this idempotent call after `self.placement_group.ready()`
688+
resolves successfully.
689+
"""
690+
head_pgs = getattr(self, "_head_pgs", [])
691+
self._head_pgs = []
692+
for head_pg in head_pgs:
693+
try:
694+
remove_placement_group(head_pg)
695+
except Exception:
696+
logger.exception(
697+
"Failed to remove TPU head placement group %s; the "
698+
"slice reservation marker may leak until the creator "
699+
"process exits.",
700+
getattr(head_pg, "id", head_pg),
701+
)
702+
650703
def shutdown(self):
651-
"""Removes the worker placement group and all internal head PGs."""
652-
if self._placement_group:
653-
remove_placement_group(self._placement_group)
704+
"""Remove the worker placement group and all internal head PGs.
705+
706+
Idempotent. Safe to call on a partially-constructed instance.
707+
"""
708+
worker_pg = getattr(self, "_placement_group", None)
709+
if worker_pg is not None:
654710
self._placement_group = None
655-
for head_pg in self._head_pgs:
656-
remove_placement_group(head_pg)
657-
self._head_pgs = []
711+
try:
712+
remove_placement_group(worker_pg)
713+
except Exception:
714+
logger.exception(
715+
"Failed to remove TPU worker placement group %s.",
716+
getattr(worker_pg, "id", worker_pg),
717+
)
718+
self.release_head_pgs()
658719

659720

660721
@PublicAPI(stability="alpha")

0 commit comments

Comments
 (0)