diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 84059e3..acb303f 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -61,6 +61,11 @@ class Manager: _snapshot: PyTree _SIMPLE_EXECUTION_TEST_VALUE = 100 + _ELASTIC_DOWN_ERROR_TYPES = [ + "DATA_LOSS", + "NOT_FOUND", + "INTERNAL", + ] def __init__( self, @@ -138,22 +143,21 @@ def slice_device_count(self, slice_index: int) -> int: f"Slice {slice_index=} not found in {self.slice_to_devices=}" ) from error - @staticmethod - def _is_error_due_to_slice_down(error: Exception) -> bool: + @classmethod + def _is_error_due_to_slice_down(cls, error: Exception) -> bool: """Check if the error is due to slice down.""" - if "DATA_LOSS" in str(error): - _logger.debug("Caught JaxRuntimeError DATA_LOSS exception") - elif "NOT_FOUND" in str(error): - _logger.debug("Caught JaxRuntimeError NOT_FOUND exception") - elif "INTERNAL" in str(error): - _logger.debug("Caught JaxRuntimeError INTERNAL exception") - + return_value = any( + error_type in str(error) + for error_type in cls._ELASTIC_DOWN_ERROR_TYPES + ) + if return_value: + _logger.info("Caught an error due to slice down") else: - _logger.debug("Unknown JaxRuntimeError") - return False + _logger.info("Caught an error not due to slice down") _logger.debug("\n".join(traceback.format_exception(error))) - return True + + return return_value @classmethod def _simple_execution(cls, devices: Sequence[jax.Device]) -> jax.Array: @@ -192,7 +196,7 @@ def get_slice_availability(self) -> set[int]: } for slice_index, x in results.items(): - _logger.debug("Checking slice_index=%s", slice_index) + _logger.info("Checking slice_index=%s", slice_index) expected = ( np.zeros(self.slice_device_count(slice_index), dtype=float) + self._SIMPLE_EXECUTION_TEST_VALUE @@ -202,7 +206,7 @@ def get_slice_availability(self) -> set[int]: jax.block_until_ready(x) if np.allclose(x, expected): good_slice_indices.add(slice_index) - _logger.debug("slice_index=%s good", slice_index) + _logger.info("slice_index=%s good", slice_index) else: _logger.error( "Error with _simple_execution for slice_index=%s. " @@ -217,9 +221,9 @@ def get_slice_availability(self) -> set[int]: except jax.errors.JaxRuntimeError as error: if not self._is_error_due_to_slice_down(error): raise - _logger.debug("slice_index=%s bad", slice_index) + _logger.info("slice_index=%s bad", slice_index) - _logger.debug("good_slice_indices=%s", good_slice_indices) + _logger.info("good_slice_indices=%s", good_slice_indices) return good_slice_indices @@ -245,13 +249,13 @@ def _is_ready_to_reshard(self, step: int) -> bool: if len(good_slice_indices) == len(self.good_slice_indices): return False - _logger.debug("New slice available.") - _logger.debug( + _logger.info("New slice available.") + _logger.info( "Previous good slice indices: self.good_slice_indices=%s", self.good_slice_indices, ) - _logger.debug( - "Current good slice indices: good_slice_indices=%s", good_slice_indices + _logger.info( + "Current good slice indices: %s", good_slice_indices ) self.good_slice_indices = good_slice_indices @@ -317,7 +321,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None: ElasticRuntimeError: If the maximum number of elastic down events or reshard retries is reached. """ - _logger.debug("Slice down") + _logger.info("Slice down") self.good_slice_indices = self.get_slice_availability() self.elastic_down_event_count += 1 if reshard_retry: @@ -325,7 +329,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None: else: self.reshard_retry_count = 0 - _logger.debug( + _logger.info( "elastic_down_event_count=%s max_elastic_down_event_count=%s", self.elastic_down_event_count, self.max_elastic_down_event_count, @@ -339,7 +343,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None: f" {self.max_elastic_down_event_count}" ) - _logger.debug( + _logger.info( "self.reshard_retry_count=%s self.max_reshard_retry_count=%s", self.reshard_retry_count, self.max_reshard_retry_count, @@ -429,19 +433,19 @@ def maybe_snapshot( block: If True, block until the snapshot is ready. """ if not force and step % self.snapshot_period: - _logger.debug("Not saving a snapshot") + _logger.info("Not saving a snapshot") return total_nbytes = self._get_snapshot_size(snapshot) - _logger.debug("Saving a snapshot of %s bytes", total_nbytes) + _logger.info("Saving a snapshot of %s bytes", total_nbytes) snapshot_host = self._put_snapshot_on_host(snapshot) - _logger.debug("Snapshot dispatched") + _logger.info("Snapshot dispatched") if block: jax.block_until_ready(snapshot_host) - _logger.debug("Snapshot completed") + _logger.info("Snapshot completed") # TODO b/407772100 - Support multiple snapshots. self._snapshot = {"step": step, "snapshot": snapshot_host} @@ -532,23 +536,23 @@ def maybe_reshard_down( while True: if not self._is_error_due_to_slice_down(error): - _logger.debug( + _logger.info( "Not resharding down because the error is not due to a slice down." ) raise error from error.__cause__ - _logger.debug("Resharding down") + _logger.info("Resharding down") self._slice_down(reshard_retry) try: handler_return_values = elastic_handler(*handler_args, **handler_kwargs) break except jax.errors.JaxRuntimeError as e: - _logger.debug("Elastic handler raised an error.") + _logger.info("Elastic handler raised an error.") error = e reshard_retry = True - _logger.debug("Successfully resharded down") + _logger.info("Successfully resharded down") return handler_return_values @timing.timeit @@ -586,7 +590,7 @@ def maybe_reshard_up( handler_kwargs = {} if not self._is_ready_to_reshard(step): - _logger.debug("Not resharding up since it is not time to reshard.") + _logger.info("Not resharding up since it is not time to reshard.") return self.maybe_snapshot( @@ -599,7 +603,7 @@ def maybe_reshard_up( try: handler_return_values = elastic_handler(*handler_args, **handler_kwargs) except jax.errors.JaxRuntimeError as error: - _logger.debug("Elastic handler failed. Trying again") + _logger.info("Elastic handler failed. Trying again") handler_return_values = self.maybe_reshard_down( error=error, elastic_handler=elastic_handler, @@ -608,5 +612,5 @@ def maybe_reshard_up( reshard_retry=True, ) - _logger.debug("Finished resharding up") + _logger.info("Finished resharding up") return handler_return_values