Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 84 additions & 72 deletions pathwaysutils/elastic/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
events. It also provides a utility for waiting for slices to become active.
"""

from abc import ABC, abstractmethod
import collections
from collections.abc import Mapping, Sequence, Set
from collections.abc import Callable, Mapping, Sequence, Set
import itertools
import logging
import time
from typing import Any

import jax
import numpy as np
Expand All @@ -41,44 +44,87 @@
])


def _plus_one(x: jax.Array) -> jax.Array:
"""Adds one to each element in the array.
class SliceHealthChecker(ABC):
"""Base class for slice health checkers.

Used to test if a slice is active.
Implementations must provide `dispatch` and `validate` methods.
`dispatch` starts the health check operations, and `validate` waits for them
and returns the active slice indices.
"""

Args:
x: The array to add one to.
def __init__(self, slice_to_devices: Mapping[int, Sequence[jax.Device]]):
self.slice_to_devices = slice_to_devices

Returns:
The array with one added to each element.
"""
return x + 1
@abstractmethod
def dispatch(self) -> None:
"""Dispatches the JAX operations in the background."""

@abstractmethod
def validate(self) -> Set[int]:
"""Blocks on results and returns the set of active slice indices.

def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array:
"""Simple execution to test if a slice is active.
Returns:
Set of active slice indices.
"""

This function is used to test if a slice is active. It executes a simple
computation on the devices and returns the result. If any of the devices are
not active, the returned array will fail with a JaxRuntimeError used.

Simply executing this function is not enough to determine if the slice is
active. We also need to check the value of the returned array.
class DefaultSliceHealthChecker(SliceHealthChecker):
"""Default implementation that checks each slice individually.

Args:
devices: The devices to execute on.
It executes a simple computation on each slice independently to verify
that the devices are active and responsive.

Returns:
The result of the execution.
The computation involves creating a zero array of the same size as the
number of devices in the slice, adding (_SIMPLE_EXECUTION_TEST_VALUE - 1)
to it, and then running a pmap to add 1 to each element. This verifies
that the devices are active and can perform computations.
"""
if not devices:
raise ValueError("No devices")

test_input = np.zeros(len(devices), dtype=float) + (
_SIMPLE_EXECUTION_TEST_VALUE - 1
)
def _plus_one(self, x: jax.Array) -> jax.Array:
return x + 1

def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array:
if not devices:
raise ValueError("No devices")

test_input = np.zeros(len(devices), dtype=float) + (
_SIMPLE_EXECUTION_TEST_VALUE - 1
)

return jax.pmap(self._plus_one, devices=devices)(test_input)

def dispatch(self) -> None:
  """Launches one probe per slice and records the returned arrays.

  Calls `_simple_execution` for every entry of `self.slice_to_devices` and
  stores the results in `self.results`, keyed by slice index. Nothing is
  blocked on here; `validate` waits on the stored arrays.
  """
  launched = {}
  for slice_index, devices in self.slice_to_devices.items():
    launched[slice_index] = self._simple_execution(devices)
  # Assign only after every launch succeeded, so a failure part-way through
  # leaves any previous `self.results` untouched.
  self.results = launched

def validate(self) -> Set[int]:
  """Waits on the dispatched probes and collects the slices that passed.

  For every entry in `self.results`, blocks until the array is ready and
  compares it against an array of `_SIMPLE_EXECUTION_TEST_VALUE` with one
  entry per device of that slice.

  Returns:
    The set of slice indices whose probe result matched the expected value.

  Raises:
    ValueError: If a probe completed but produced an unexpected value.
    jax.errors.JaxRuntimeError: If a probe failed for a reason that
      `is_error_due_to_slice_down` does not attribute to a slice being down.
  """
  healthy = set()
  for slice_index, x in self.results.items():
    device_count = len(self.slice_to_devices[slice_index])
    expected = (
        np.zeros(device_count, dtype=float)
        + _SIMPLE_EXECUTION_TEST_VALUE
    )
    try:
      jax.block_until_ready(x)
      if not np.allclose(x, expected):
        msg = (
            f"Error with _simple_execution for slice_index={slice_index}. "
            f"Expected: {expected!r}, Actual: {x!r}"
        )
        _logger.error(msg)
        raise ValueError(msg)
      healthy.add(slice_index)
    except jax.errors.JaxRuntimeError as error:
      _logger.debug(
          "Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
      )
      # A slice-down error just means this slice is not active; skip it.
      if is_error_due_to_slice_down(error):
        continue
      raise
  return healthy

return jax.pmap(_plus_one, devices=devices)(test_input)


def get_slice_to_devices(
Expand All @@ -94,13 +140,16 @@ def get_slice_to_devices(
@timing.timeit
def get_active_slice_indices(
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
checker: SliceHealthChecker | None = None,
) -> Set[int]:
"""Returns the set of active slices indices.

Args:
slice_to_devices: A mapping from slice index to devices. If None,
`get_slice_to_devices(jax.devices())` is used to gather all available
devices and group them by slice.
checker: A SliceHealthChecker instance. If None, DefaultSliceHealthChecker
is used.

Returns:
A set of integers representing the indices of the active slices.
Expand All @@ -109,62 +158,24 @@ def get_active_slice_indices(
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
slice_to_devices = get_slice_to_devices(tuple(jax.devices()))

if checker is None:
checker = DefaultSliceHealthChecker(slice_to_devices)

_logger.debug(
"Getting active slice indices for slices: %s",
sorted(list(slice_to_devices.keys())),
)

active_slice_indices = set()

results = {
slice_index: _simple_execution(devices)
for slice_index, devices in slice_to_devices.items()
}

for slice_index, x in results.items():
_logger.debug("Checking slice_index=%s", slice_index)
expected = (
np.zeros(len(slice_to_devices[slice_index]), dtype=float)
+ _SIMPLE_EXECUTION_TEST_VALUE
)
try:
with timing.Timer(f"Checking {slice_index=}"):
_logger.debug("Blocking until ready for slice_index=%s", slice_index)
jax.block_until_ready(x)
_logger.debug("Execution finished for slice_index=%s", slice_index)
if np.allclose(x, expected):
active_slice_indices.add(slice_index)
_logger.debug("slice_index=%s active", slice_index)
else:
_logger.error(
"Error with _simple_execution for slice_index=%s. "
"This should never happen. Expected: %r, Actual: %r",
slice_index,
expected,
x,
)
raise ValueError(
f"Error with _simple_execution for slice_index={slice_index}."
)
except jax.errors.JaxRuntimeError as error:
_logger.debug(
"Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
)
if not is_error_due_to_slice_down(error):
_logger.info("Re-raising error for slice_index=%s", slice_index)
raise
_logger.debug("slice_index=%s bad", slice_index)

_logger.debug("active_slice_indices=%s", active_slice_indices)

return active_slice_indices
checker.dispatch()
return checker.validate()


def wait_for_slices(
slice_count: int,
poll_interval: float | int = 10,
timeout: float | int | None = None,
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
checker: SliceHealthChecker | None = None,
) -> Set[int]:
"""Waits until after at least `slice_count` slices become active.

Expand All @@ -177,6 +188,7 @@ def wait_for_slices(
timeout.
slice_to_devices: A mapping from slice index to devices. If None,
`get_slice_to_devices(jax.devices())` is used.
checker: A SliceHealthChecker instance to use for checking health.

Returns:
The active slice indices
Expand All @@ -201,7 +213,7 @@ def wait_for_slices(
check_start_time = time.time()

_logger.debug("Checking active slices...")
active_slice_indices = get_active_slice_indices(slice_to_devices)
active_slice_indices = get_active_slice_indices(slice_to_devices, checker)
if len(active_slice_indices) >= slice_count:
_logger.info(
"Sufficient slices active: %s >= %s. Active indices: %s",
Expand Down
Loading