From 012573bd418911bab408bc7a3e6b770380527c78 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 7 Apr 2026 18:54:44 -0700 Subject: [PATCH] Generalize TPU Slice Health Checks for elasticity - Create DefaultSliceHealthChecker. - Move slice_to_devices to __init__. - Bring details from simple executions' docstrings to DefaultSliceHealthChecker. PiperOrigin-RevId: 896196196 --- pathwaysutils/elastic/elastic.py | 156 +++++++++++++++++-------------- 1 file changed, 84 insertions(+), 72 deletions(-) diff --git a/pathwaysutils/elastic/elastic.py b/pathwaysutils/elastic/elastic.py index e7243d3..1529e1e 100644 --- a/pathwaysutils/elastic/elastic.py +++ b/pathwaysutils/elastic/elastic.py @@ -18,10 +18,13 @@ events. It also provides a utility for waiting for slices to become active. """ +from abc import ABC, abstractmethod import collections -from collections.abc import Mapping, Sequence, Set +from collections.abc import Callable, Mapping, Sequence, Set +import itertools import logging import time +from typing import Any import jax import numpy as np @@ -41,44 +44,87 @@ ]) -def _plus_one(x: jax.Array) -> jax.Array: - """Adds one to each element in the array. +class SliceHealthChecker(ABC): + """Base class for slice health checkers. - Used to test if a slice is active. + Implementations must provide `dispatch` and `validate` methods. + `dispatch` starts the health check operations, and `validate` waits for them + and returns the active slice indices. + """ - Args: - x: The array to add one to. + def __init__(self, slice_to_devices: Mapping[int, Sequence[jax.Device]]): + self.slice_to_devices = slice_to_devices - Returns: - The array with one added to each element. - """ - return x + 1 + @abstractmethod + def dispatch(self) -> None: + """Dispatches the JAX operations in the background.""" + @abstractmethod + def validate(self) -> Set[int]: + """Blocks on results and returns the set of active slice indices. -def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array: - """Simple execution to test if a slice is active. + Returns: + Set of active slice indices. + """ - This function is used to test if a slice is active. It executes a simple - computation on the devices and returns the result. If any of the devices are - not active, the returned array will fail with a JaxRuntimeError used. - Simply executing this function is not enough to determine if the slice is - active. We also need to check the value of the returned array. +class DefaultSliceHealthChecker(SliceHealthChecker): + """Default implementation that checks each slice individually. - Args: - devices: The devices to execute on. + It executes a simple computation on each slice independently to verify + that the devices are active and responsive. - Returns: - The result of the execution. + The computation involves creating a zero array of the same size as the + number of devices in the slice, adding (_SIMPLE_EXECUTION_TEST_VALUE - 1) + to it, and then running a pmap to add 1 to each element. This verifies + that the devices are active and can perform computations. """ - if not devices: - raise ValueError("No devices") - test_input = np.zeros(len(devices), dtype=float) + ( - _SIMPLE_EXECUTION_TEST_VALUE - 1 - ) + def _plus_one(self, x: jax.Array) -> jax.Array: + return x + 1 + + def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array: + if not devices: + raise ValueError("No devices") + + test_input = np.zeros(len(devices), dtype=float) + ( + _SIMPLE_EXECUTION_TEST_VALUE - 1 + ) + + return jax.pmap(self._plus_one, devices=devices)(test_input) + + def dispatch(self) -> None: + self.results = { + slice_index: self._simple_execution(devices) + for slice_index, devices in self.slice_to_devices.items() + } + + def validate(self) -> Set[int]: + active_slice_indices = set() + for slice_index, x in self.results.items(): + expected = ( + np.zeros(len(self.slice_to_devices[slice_index]), dtype=float) + + _SIMPLE_EXECUTION_TEST_VALUE + ) + try: + jax.block_until_ready(x) + if np.allclose(x, expected): + active_slice_indices.add(slice_index) + else: + msg = ( + f"Error with _simple_execution for slice_index={slice_index}. " + f"Expected: {expected!r}, Actual: {x!r}" + ) + _logger.error(msg) + raise ValueError(msg) + except jax.errors.JaxRuntimeError as error: + _logger.debug( + "Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error + ) + if not is_error_due_to_slice_down(error): + raise + return active_slice_indices - return jax.pmap(_plus_one, devices=devices)(test_input) def get_slice_to_devices( @@ -94,6 +140,7 @@ def get_slice_to_devices( @timing.timeit def get_active_slice_indices( slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, + checker: SliceHealthChecker | None = None, ) -> Set[int]: """Returns the set of active slices indices. @@ -101,6 +148,8 @@ def get_active_slice_indices( slice_to_devices: A mapping from slice index to devices. If None, `get_slice_to_devices(jax.devices())` is used to gather all available devices and group them by slice. + checker: A SliceHealthChecker instance. If None, DefaultSliceHealthChecker + is used. Returns: A set of integers representing the indices of the active slices. @@ -109,55 +158,16 @@ def get_active_slice_indices( _logger.debug("slice_to_devices is None. Getting from jax.devices().") slice_to_devices = get_slice_to_devices(tuple(jax.devices())) + if checker is None: + checker = DefaultSliceHealthChecker(slice_to_devices) + _logger.debug( "Getting active slice indices for slices: %s", sorted(list(slice_to_devices.keys())), ) - active_slice_indices = set() - - results = { - slice_index: _simple_execution(devices) - for slice_index, devices in slice_to_devices.items() - } - - for slice_index, x in results.items(): - _logger.debug("Checking slice_index=%s", slice_index) - expected = ( - np.zeros(len(slice_to_devices[slice_index]), dtype=float) - + _SIMPLE_EXECUTION_TEST_VALUE - ) - try: - with timing.Timer(f"Checking {slice_index=}"): - _logger.debug("Blocking until ready for slice_index=%s", slice_index) - jax.block_until_ready(x) - _logger.debug("Execution finished for slice_index=%s", slice_index) - if np.allclose(x, expected): - active_slice_indices.add(slice_index) - _logger.debug("slice_index=%s active", slice_index) - else: - _logger.error( - "Error with _simple_execution for slice_index=%s. " - "This should never happen. Expected: %r, Actual: %r", - slice_index, - expected, - x, - ) - raise ValueError( - f"Error with _simple_execution for slice_index={slice_index}." - ) - except jax.errors.JaxRuntimeError as error: - _logger.debug( - "Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error - ) - if not is_error_due_to_slice_down(error): - _logger.info("Re-raising error for slice_index=%s", slice_index) - raise - _logger.debug("slice_index=%s bad", slice_index) - - _logger.debug("active_slice_indices=%s", active_slice_indices) - - return active_slice_indices + checker.dispatch() + return checker.validate() def wait_for_slices( @@ -165,6 +175,7 @@ def wait_for_slices( poll_interval: float | int = 10, timeout: float | int | None = None, slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, + checker: SliceHealthChecker | None = None, ) -> Set[int]: """Waits until after at least `slice_count` slices become active. @@ -177,6 +188,7 @@ def wait_for_slices( timeout. slice_to_devices: A mapping from slice index to devices. If None, `get_slice_to_devices(jax.devices())` is used. + checker: A SliceHealthChecker instance to use for checking health. Returns: The active slice indices @@ -201,7 +213,7 @@ def wait_for_slices( check_start_time = time.time() _logger.debug("Checking active slices...") - active_slice_indices = get_active_slice_indices(slice_to_devices) + active_slice_indices = get_active_slice_indices(slice_to_devices, checker) if len(active_slice_indices) >= slice_count: _logger.info( "Sufficient slices active: %s >= %s. Active indices: %s",