@@ -61,6 +61,11 @@ class Manager:
6161 _snapshot : PyTree
6262
6363 _SIMPLE_EXECUTION_TEST_VALUE = 100
64+ _ELASTIC_DOWN_ERROR_TYPES = [
65+ "DATA_LOSS" ,
66+ "NOT_FOUND" ,
67+ "INTERNAL" ,
68+ ]
6469
6570 def __init__ (
6671 self ,
@@ -138,22 +143,21 @@ def slice_device_count(self, slice_index: int) -> int:
138143 f"Slice { slice_index = } not found in { self .slice_to_devices = } "
139144 ) from error
140145
141- @staticmethod
142- def _is_error_due_to_slice_down (error : Exception ) -> bool :
146+ @classmethod
147+ def _is_error_due_to_slice_down (cls , error : Exception ) -> bool :
143148 """Check if the error is due to slice down."""
144- if "DATA_LOSS" in str (error ):
145- _logger .debug ("Caught JaxRuntimeError DATA_LOSS exception" )
146- elif "NOT_FOUND" in str (error ):
147- _logger .debug ("Caught JaxRuntimeError NOT_FOUND exception" )
148- elif "INTERNAL" in str (error ):
149- _logger .debug ("Caught JaxRuntimeError INTERNAL exception" )
150-
149+ return_value = any (
150+ error_type in str (error )
151+ for error_type in cls ._ELASTIC_DOWN_ERROR_TYPES
152+ )
153+ if return_value :
154+ _logger .info ("Caught an error due to slice down" )
151155 else :
152- _logger .debug ("Unknown JaxRuntimeError" )
153- return False
156+ _logger .info ("Caught an error not due to slice down" )
154157
155158 _logger .debug ("\n " .join (traceback .format_exception (error )))
156- return True
159+
160+ return return_value
157161
158162 @classmethod
159163 def _simple_execution (cls , devices : Sequence [jax .Device ]) -> jax .Array :
@@ -192,7 +196,7 @@ def get_slice_availability(self) -> set[int]:
192196 }
193197
194198 for slice_index , x in results .items ():
195- _logger .debug ("Checking slice_index=%s" , slice_index )
199+ _logger .info ("Checking slice_index=%s" , slice_index )
196200 expected = (
197201 np .zeros (self .slice_device_count (slice_index ), dtype = float )
198202 + self ._SIMPLE_EXECUTION_TEST_VALUE
@@ -202,7 +206,7 @@ def get_slice_availability(self) -> set[int]:
202206 jax .block_until_ready (x )
203207 if np .allclose (x , expected ):
204208 good_slice_indices .add (slice_index )
205- _logger .debug ("slice_index=%s good" , slice_index )
209+ _logger .info ("slice_index=%s good" , slice_index )
206210 else :
207211 _logger .error (
208212 "Error with _simple_execution for slice_index=%s. "
@@ -217,9 +221,9 @@ def get_slice_availability(self) -> set[int]:
217221 except jax .errors .JaxRuntimeError as error :
218222 if not self ._is_error_due_to_slice_down (error ):
219223 raise
220- _logger .debug ("slice_index=%s bad" , slice_index )
224+ _logger .info ("slice_index=%s bad" , slice_index )
221225
222- _logger .debug ("good_slice_indices=%s" , good_slice_indices )
226+ _logger .info ("good_slice_indices=%s" , good_slice_indices )
223227
224228 return good_slice_indices
225229
@@ -245,13 +249,13 @@ def _is_ready_to_reshard(self, step: int) -> bool:
245249 if len (good_slice_indices ) == len (self .good_slice_indices ):
246250 return False
247251
248- _logger .debug ("New slice available." )
249- _logger .debug (
252+ _logger .info ("New slice available." )
253+ _logger .info (
250254 "Previous good slice indices: self.good_slice_indices=%s" ,
251255 self .good_slice_indices ,
252256 )
253- _logger .debug (
254- "Current good slice indices: good_slice_indices= %s" , good_slice_indices
257+ _logger .info (
258+ "Current good slice indices: %s" , good_slice_indices
255259 )
256260
257261 self .good_slice_indices = good_slice_indices
@@ -317,15 +321,15 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
317321 ElasticRuntimeError: If the maximum number of elastic down events or
318322 reshard retries is reached.
319323 """
320- _logger .debug ("Slice down" )
324+ _logger .info ("Slice down" )
321325 self .good_slice_indices = self .get_slice_availability ()
322326 self .elastic_down_event_count += 1
323327 if reshard_retry :
324328 self .reshard_retry_count += 1
325329 else :
326330 self .reshard_retry_count = 0
327331
328- _logger .debug (
332+ _logger .info (
329333 "elastic_down_event_count=%s max_elastic_down_event_count=%s" ,
330334 self .elastic_down_event_count ,
331335 self .max_elastic_down_event_count ,
@@ -339,7 +343,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
339343 f" { self .max_elastic_down_event_count } "
340344 )
341345
342- _logger .debug (
346+ _logger .info (
343347 "self.reshard_retry_count=%s self.max_reshard_retry_count=%s" ,
344348 self .reshard_retry_count ,
345349 self .max_reshard_retry_count ,
@@ -429,19 +433,19 @@ def maybe_snapshot(
429433 block: If True, block until the snapshot is ready.
430434 """
431435 if not force and step % self .snapshot_period :
432- _logger .debug ("Not saving a snapshot" )
436+ _logger .info ("Not saving a snapshot" )
433437 return
434438
435439 total_nbytes = self ._get_snapshot_size (snapshot )
436440
437- _logger .debug ("Saving a snapshot of %s bytes" , total_nbytes )
441+ _logger .info ("Saving a snapshot of %s bytes" , total_nbytes )
438442
439443 snapshot_host = self ._put_snapshot_on_host (snapshot )
440- _logger .debug ("Snapshot dispatched" )
444+ _logger .info ("Snapshot dispatched" )
441445
442446 if block :
443447 jax .block_until_ready (snapshot_host )
444- _logger .debug ("Snapshot completed" )
448+ _logger .info ("Snapshot completed" )
445449
446450 # TODO b/407772100 - Support multiple snapshots.
447451 self ._snapshot = {"step" : step , "snapshot" : snapshot_host }
@@ -532,23 +536,23 @@ def maybe_reshard_down(
532536
533537 while True :
534538 if not self ._is_error_due_to_slice_down (error ):
535- _logger .debug (
539+ _logger .info (
536540 "Not resharding down because the error is not due to a slice down."
537541 )
538542 raise error from error .__cause__
539543
540- _logger .debug ("Resharding down" )
544+ _logger .info ("Resharding down" )
541545 self ._slice_down (reshard_retry )
542546
543547 try :
544548 handler_return_values = elastic_handler (* handler_args , ** handler_kwargs )
545549 break
546550 except jax .errors .JaxRuntimeError as e :
547- _logger .debug ("Elastic handler raised an error." )
551+ _logger .info ("Elastic handler raised an error." )
548552 error = e
549553 reshard_retry = True
550554
551- _logger .debug ("Successfully resharded down" )
555+ _logger .info ("Successfully resharded down" )
552556 return handler_return_values
553557
554558 @timing .timeit
@@ -586,7 +590,7 @@ def maybe_reshard_up(
586590 handler_kwargs = {}
587591
588592 if not self ._is_ready_to_reshard (step ):
589- _logger .debug ("Not resharding up since it is not time to reshard." )
593+ _logger .info ("Not resharding up since it is not time to reshard." )
590594 return
591595
592596 self .maybe_snapshot (
@@ -599,7 +603,7 @@ def maybe_reshard_up(
599603 try :
600604 handler_return_values = elastic_handler (* handler_args , ** handler_kwargs )
601605 except jax .errors .JaxRuntimeError as error :
602- _logger .debug ("Elastic handler failed. Trying again" )
606+ _logger .info ("Elastic handler failed. Trying again" )
603607 handler_return_values = self .maybe_reshard_down (
604608 error = error ,
605609 elastic_handler = elastic_handler ,
@@ -608,5 +612,5 @@ def maybe_reshard_up(
608612 reshard_retry = True ,
609613 )
610614
611- _logger .debug ("Finished resharding up" )
615+ _logger .info ("Finished resharding up" )
612616 return handler_return_values
0 commit comments