Skip to content

Commit 012573b

Browse files
lukebaumann and copybara-github
authored and committed
Generalize TPU Slice Health Checks for elasticity
- Create DefaultSliceHealthChecker.
- Move slice_to_devices to __init__.
- Bring details from simple executions' docstrings to DefaultSliceHealthChecker.

PiperOrigin-RevId: 896196196
1 parent 66e9754 commit 012573b

1 file changed

Lines changed: 84 additions & 72 deletions

File tree

pathwaysutils/elastic/elastic.py

Lines changed: 84 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,13 @@
1818
events. It also provides a utility for waiting for slices to become active.
1919
"""
2020

21+
from abc import ABC, abstractmethod
2122
import collections
22-
from collections.abc import Mapping, Sequence, Set
23+
from collections.abc import Callable, Mapping, Sequence, Set
24+
import itertools
2325
import logging
2426
import time
27+
from typing import Any
2528

2629
import jax
2730
import numpy as np
@@ -41,44 +44,87 @@
4144
])
4245

4346

44-
def _plus_one(x: jax.Array) -> jax.Array:
45-
"""Adds one to each element in the array.
47+
class SliceHealthChecker(ABC):
48+
"""Base class for slice health checkers.
4649
47-
Used to test if a slice is active.
50+
Implementations must provide `dispatch` and `validate` methods.
51+
`dispatch` starts the health check operations, and `validate` waits for them
52+
and returns the active slice indices.
53+
"""
4854

49-
Args:
50-
x: The array to add one to.
55+
def __init__(self, slice_to_devices: Mapping[int, Sequence[jax.Device]]):
56+
self.slice_to_devices = slice_to_devices
5157

52-
Returns:
53-
The array with one added to each element.
54-
"""
55-
return x + 1
58+
@abstractmethod
59+
def dispatch(self) -> None:
60+
"""Dispatches the JAX operations in the background."""
5661

62+
@abstractmethod
63+
def validate(self) -> Set[int]:
64+
"""Blocks on results and returns the set of active slice indices.
5765
58-
def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array:
59-
"""Simple execution to test if a slice is active.
66+
Returns:
67+
Set of active slice indices.
68+
"""
6069

61-
This function is used to test if a slice is active. It executes a simple
62-
computation on the devices and returns the result. If any of the devices are
63-
not active, the returned array will fail with a JaxRuntimeError used.
6470

65-
Simply executing this function is not enough to determine if the slice is
66-
active. We also need to check the value of the returned array.
71+
class DefaultSliceHealthChecker(SliceHealthChecker):
72+
"""Default implementation that checks each slice individually.
6773
68-
Args:
69-
devices: The devices to execute on.
74+
It executes a simple computation on each slice independently to verify
75+
that the devices are active and responsive.
7076
71-
Returns:
72-
The result of the execution.
77+
The computation involves creating a zero array of the same size as the
78+
number of devices in the slice, adding (_SIMPLE_EXECUTION_TEST_VALUE - 1)
79+
to it, and then running a pmap to add 1 to each element. This verifies
80+
that the devices are active and can perform computations.
7381
"""
74-
if not devices:
75-
raise ValueError("No devices")
7682

77-
test_input = np.zeros(len(devices), dtype=float) + (
78-
_SIMPLE_EXECUTION_TEST_VALUE - 1
79-
)
83+
def _plus_one(self, x: jax.Array) -> jax.Array:
84+
return x + 1
85+
86+
def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array:
87+
if not devices:
88+
raise ValueError("No devices")
89+
90+
test_input = np.zeros(len(devices), dtype=float) + (
91+
_SIMPLE_EXECUTION_TEST_VALUE - 1
92+
)
93+
94+
return jax.pmap(self._plus_one, devices=devices)(test_input)
95+
96+
def dispatch(self) -> None:
97+
self.results = {
98+
slice_index: self._simple_execution(devices)
99+
for slice_index, devices in self.slice_to_devices.items()
100+
}
101+
102+
def validate(self) -> Set[int]:
103+
active_slice_indices = set()
104+
for slice_index, x in self.results.items():
105+
expected = (
106+
np.zeros(len(self.slice_to_devices[slice_index]), dtype=float)
107+
+ _SIMPLE_EXECUTION_TEST_VALUE
108+
)
109+
try:
110+
jax.block_until_ready(x)
111+
if np.allclose(x, expected):
112+
active_slice_indices.add(slice_index)
113+
else:
114+
msg = (
115+
f"Error with _simple_execution for slice_index={slice_index}. "
116+
f"Expected: {expected!r}, Actual: {x!r}"
117+
)
118+
_logger.error(msg)
119+
raise ValueError(msg)
120+
except jax.errors.JaxRuntimeError as error:
121+
_logger.debug(
122+
"Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
123+
)
124+
if not is_error_due_to_slice_down(error):
125+
raise
126+
return active_slice_indices
80127

81-
return jax.pmap(_plus_one, devices=devices)(test_input)
82128

83129

84130
def get_slice_to_devices(
@@ -94,13 +140,16 @@ def get_slice_to_devices(
94140
@timing.timeit
95141
def get_active_slice_indices(
96142
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
143+
checker: SliceHealthChecker | None = None,
97144
) -> Set[int]:
98145
"""Returns the set of active slices indices.
99146
100147
Args:
101148
slice_to_devices: A mapping from slice index to devices. If None,
102149
`get_slice_to_devices(jax.devices())` is used to gather all available
103150
devices and group them by slice.
151+
checker: A SliceHealthChecker instance. If None, DefaultSliceHealthChecker
152+
is used.
104153
105154
Returns:
106155
A set of integers representing the indices of the active slices.
@@ -109,62 +158,24 @@ def get_active_slice_indices(
109158
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
110159
slice_to_devices = get_slice_to_devices(tuple(jax.devices()))
111160

161+
if checker is None:
162+
checker = DefaultSliceHealthChecker(slice_to_devices)
163+
112164
_logger.debug(
113165
"Getting active slice indices for slices: %s",
114166
sorted(list(slice_to_devices.keys())),
115167
)
116168

117-
active_slice_indices = set()
118-
119-
results = {
120-
slice_index: _simple_execution(devices)
121-
for slice_index, devices in slice_to_devices.items()
122-
}
123-
124-
for slice_index, x in results.items():
125-
_logger.debug("Checking slice_index=%s", slice_index)
126-
expected = (
127-
np.zeros(len(slice_to_devices[slice_index]), dtype=float)
128-
+ _SIMPLE_EXECUTION_TEST_VALUE
129-
)
130-
try:
131-
with timing.Timer(f"Checking {slice_index=}"):
132-
_logger.debug("Blocking until ready for slice_index=%s", slice_index)
133-
jax.block_until_ready(x)
134-
_logger.debug("Execution finished for slice_index=%s", slice_index)
135-
if np.allclose(x, expected):
136-
active_slice_indices.add(slice_index)
137-
_logger.debug("slice_index=%s active", slice_index)
138-
else:
139-
_logger.error(
140-
"Error with _simple_execution for slice_index=%s. "
141-
"This should never happen. Expected: %r, Actual: %r",
142-
slice_index,
143-
expected,
144-
x,
145-
)
146-
raise ValueError(
147-
f"Error with _simple_execution for slice_index={slice_index}."
148-
)
149-
except jax.errors.JaxRuntimeError as error:
150-
_logger.debug(
151-
"Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
152-
)
153-
if not is_error_due_to_slice_down(error):
154-
_logger.info("Re-raising error for slice_index=%s", slice_index)
155-
raise
156-
_logger.debug("slice_index=%s bad", slice_index)
157-
158-
_logger.debug("active_slice_indices=%s", active_slice_indices)
159-
160-
return active_slice_indices
169+
checker.dispatch()
170+
return checker.validate()
161171

162172

163173
def wait_for_slices(
164174
slice_count: int,
165175
poll_interval: float | int = 10,
166176
timeout: float | int | None = None,
167177
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
178+
checker: SliceHealthChecker | None = None,
168179
) -> Set[int]:
169180
"""Waits until after at least `slice_count` slices become active.
170181
@@ -177,6 +188,7 @@ def wait_for_slices(
177188
timeout.
178189
slice_to_devices: A mapping from slice index to devices. If None,
179190
`get_slice_to_devices(jax.devices())` is used.
191+
checker: A SliceHealthChecker instance to use for checking health.
180192
181193
Returns:
182194
The active slice indices
@@ -201,7 +213,7 @@ def wait_for_slices(
201213
check_start_time = time.time()
202214

203215
_logger.debug("Checking active slices...")
204-
active_slice_indices = get_active_slice_indices(slice_to_devices)
216+
active_slice_indices = get_active_slice_indices(slice_to_devices, checker)
205217
if len(active_slice_indices) >= slice_count:
206218
_logger.info(
207219
"Sufficient slices active: %s >= %s. Active indices: %s",

0 commit comments

Comments
 (0)