1818events. It also provides a utility for waiting for slices to become active.
1919"""
2020
21+ from abc import ABC , abstractmethod
2122import collections
22- from collections .abc import Mapping , Sequence , Set
23+ from collections .abc import Callable , Mapping , Sequence , Set
24+ import itertools
2325import logging
2426import time
27+ from typing import Any
2528
2629import jax
2730import numpy as np
4144])
4245
4346
44- def _plus_one ( x : jax . Array ) -> jax . Array :
45- """Adds one to each element in the array .
47+ class SliceHealthChecker ( ABC ) :
48+ """Base class for slice health checkers .
4649
47- Used to test if a slice is active.
50+ Implementations must provide `dispatch` and `validate` methods.
51+ `dispatch` starts the health check operations, and `validate` waits for them
52+ and returns the active slice indices.
53+ """
4854
49- Args :
50- x: The array to add one to.
55+ def __init__ ( self , slice_to_devices : Mapping [ int , Sequence [ jax . Device ]]) :
56+ self . slice_to_devices = slice_to_devices
5157
52- Returns:
53- The array with one added to each element.
54- """
55- return x + 1
58+ @abstractmethod
59+ def dispatch (self ) -> None :
60+ """Dispatches the JAX operations in the background."""
5661
62+ @abstractmethod
63+ def validate (self ) -> Set [int ]:
64+ """Blocks on results and returns the set of active slice indices.
5765
58- def _simple_execution (devices : Sequence [jax .Device ]) -> jax .Array :
59- """Simple execution to test if a slice is active.
66+ Returns:
67+ Set of active slice indices.
68+ """
6069
61- This function is used to test if a slice is active. It executes a simple
62- computation on the devices and returns the result. If any of the devices are
63- not active, the returned array will fail with a JaxRuntimeError used.
6470
65- Simply executing this function is not enough to determine if the slice is
66- active. We also need to check the value of the returned array .
71+ class DefaultSliceHealthChecker ( SliceHealthChecker ):
72+ """Default implementation that checks each slice individually .
6773
68- Args:
69- devices: The devices to execute on .
74+ It executes a simple computation on each slice independently to verify
75+ that the devices are active and responsive .
7076
71- Returns:
72- The result of the execution.
77+ The computation involves creating a zero array of the same size as the
78+ number of devices in the slice, adding (_SIMPLE_EXECUTION_TEST_VALUE - 1)
79+ to it, and then running a pmap to add 1 to each element. This verifies
80+ that the devices are active and can perform computations.
7381 """
74- if not devices :
75- raise ValueError ("No devices" )
7682
77- test_input = np .zeros (len (devices ), dtype = float ) + (
78- _SIMPLE_EXECUTION_TEST_VALUE - 1
79- )
83+ def _plus_one (self , x : jax .Array ) -> jax .Array :
84+ return x + 1
85+
86+ def _simple_execution (self , devices : Sequence [jax .Device ]) -> jax .Array :
87+ if not devices :
88+ raise ValueError ("No devices" )
89+
90+ test_input = np .zeros (len (devices ), dtype = float ) + (
91+ _SIMPLE_EXECUTION_TEST_VALUE - 1
92+ )
93+
94+ return jax .pmap (self ._plus_one , devices = devices )(test_input )
95+
96+ def dispatch (self ) -> None :
97+ self .results = {
98+ slice_index : self ._simple_execution (devices )
99+ for slice_index , devices in self .slice_to_devices .items ()
100+ }
101+
102+ def validate (self ) -> Set [int ]:
103+ active_slice_indices = set ()
104+ for slice_index , x in self .results .items ():
105+ expected = (
106+ np .zeros (len (self .slice_to_devices [slice_index ]), dtype = float )
107+ + _SIMPLE_EXECUTION_TEST_VALUE
108+ )
109+ try :
110+ jax .block_until_ready (x )
111+ if np .allclose (x , expected ):
112+ active_slice_indices .add (slice_index )
113+ else :
114+ msg = (
115+ f"Error with _simple_execution for slice_index={ slice_index } . "
116+ f"Expected: { expected !r} , Actual: { x !r} "
117+ )
118+ _logger .error (msg )
119+ raise ValueError (msg )
120+ except jax .errors .JaxRuntimeError as error :
121+ _logger .debug (
122+ "Caught JaxRuntimeError for slice_index=%s: %s" , slice_index , error
123+ )
124+ if not is_error_due_to_slice_down (error ):
125+ raise
126+ return active_slice_indices
80127
81- return jax .pmap (_plus_one , devices = devices )(test_input )
82128
83129
84130def get_slice_to_devices (
@@ -94,13 +140,16 @@ def get_slice_to_devices(
94140@timing .timeit
95141def get_active_slice_indices (
96142 slice_to_devices : Mapping [int , Sequence [jax .Device ]] | None = None ,
143+ checker : SliceHealthChecker | None = None ,
97144) -> Set [int ]:
98145 """Returns the set of active slices indices.
99146
100147 Args:
101148 slice_to_devices: A mapping from slice index to devices. If None,
102149 `get_slice_to_devices(jax.devices())` is used to gather all available
103150 devices and group them by slice.
151+ checker: A SliceHealthChecker instance. If None, DefaultSliceHealthChecker
152+ is used.
104153
105154 Returns:
106155 A set of integers representing the indices of the active slices.
@@ -109,62 +158,24 @@ def get_active_slice_indices(
109158 _logger .debug ("slice_to_devices is None. Getting from jax.devices()." )
110159 slice_to_devices = get_slice_to_devices (tuple (jax .devices ()))
111160
161+ if checker is None :
162+ checker = DefaultSliceHealthChecker (slice_to_devices )
163+
112164 _logger .debug (
113165 "Getting active slice indices for slices: %s" ,
114166 sorted (list (slice_to_devices .keys ())),
115167 )
116168
117- active_slice_indices = set ()
118-
119- results = {
120- slice_index : _simple_execution (devices )
121- for slice_index , devices in slice_to_devices .items ()
122- }
123-
124- for slice_index , x in results .items ():
125- _logger .debug ("Checking slice_index=%s" , slice_index )
126- expected = (
127- np .zeros (len (slice_to_devices [slice_index ]), dtype = float )
128- + _SIMPLE_EXECUTION_TEST_VALUE
129- )
130- try :
131- with timing .Timer (f"Checking { slice_index = } " ):
132- _logger .debug ("Blocking until ready for slice_index=%s" , slice_index )
133- jax .block_until_ready (x )
134- _logger .debug ("Execution finished for slice_index=%s" , slice_index )
135- if np .allclose (x , expected ):
136- active_slice_indices .add (slice_index )
137- _logger .debug ("slice_index=%s active" , slice_index )
138- else :
139- _logger .error (
140- "Error with _simple_execution for slice_index=%s. "
141- "This should never happen. Expected: %r, Actual: %r" ,
142- slice_index ,
143- expected ,
144- x ,
145- )
146- raise ValueError (
147- f"Error with _simple_execution for slice_index={ slice_index } ."
148- )
149- except jax .errors .JaxRuntimeError as error :
150- _logger .debug (
151- "Caught JaxRuntimeError for slice_index=%s: %s" , slice_index , error
152- )
153- if not is_error_due_to_slice_down (error ):
154- _logger .info ("Re-raising error for slice_index=%s" , slice_index )
155- raise
156- _logger .debug ("slice_index=%s bad" , slice_index )
157-
158- _logger .debug ("active_slice_indices=%s" , active_slice_indices )
159-
160- return active_slice_indices
169+ checker .dispatch ()
170+ return checker .validate ()
161171
162172
163173def wait_for_slices (
164174 slice_count : int ,
165175 poll_interval : float | int = 10 ,
166176 timeout : float | int | None = None ,
167177 slice_to_devices : Mapping [int , Sequence [jax .Device ]] | None = None ,
178+ checker : SliceHealthChecker | None = None ,
168179) -> Set [int ]:
169180 """Waits until after at least `slice_count` slices become active.
170181
@@ -177,6 +188,7 @@ def wait_for_slices(
177188 timeout.
178189 slice_to_devices: A mapping from slice index to devices. If None,
179190 `get_slice_to_devices(jax.devices())` is used.
191+ checker: A SliceHealthChecker instance to use for checking health.
180192
181193 Returns:
182194 The active slice indices
@@ -201,7 +213,7 @@ def wait_for_slices(
201213 check_start_time = time .time ()
202214
203215 _logger .debug ("Checking active slices..." )
204- active_slice_indices = get_active_slice_indices (slice_to_devices )
216+ active_slice_indices = get_active_slice_indices (slice_to_devices , checker )
205217 if len (active_slice_indices ) >= slice_count :
206218 _logger .info (
207219 "Sufficient slices active: %s >= %s. Active indices: %s" ,
0 commit comments