Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 38 additions & 34 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class Manager:
_snapshot: PyTree

_SIMPLE_EXECUTION_TEST_VALUE = 100
_ELASTIC_DOWN_ERROR_TYPES = [
"DATA_LOSS",
"NOT_FOUND",
"INTERNAL",
]

def __init__(
self,
Expand Down Expand Up @@ -138,22 +143,21 @@ def slice_device_count(self, slice_index: int) -> int:
f"Slice {slice_index=} not found in {self.slice_to_devices=}"
) from error

@classmethod
def _is_error_due_to_slice_down(cls, error: Exception) -> bool:
  """Check if the error is due to slice down.

  The string form of `error` is searched for each marker listed in
  `cls._ELASTIC_DOWN_ERROR_TYPES`; a single match is enough to classify the
  error as a slice-down error.

  Args:
    error: The exception to classify.

  Returns:
    True if the error message contains any of the markers in
    `cls._ELASTIC_DOWN_ERROR_TYPES`, False otherwise.
  """
  # Stringify once rather than once per marker.
  error_message = str(error)
  return_value = any(
      error_type in error_message
      for error_type in cls._ELASTIC_DOWN_ERROR_TYPES
  )
  if return_value:
    _logger.info("Caught an error due to slice down")
  else:
    _logger.info("Caught an error not due to slice down")

  # The full traceback is emitted for both outcomes, at debug level only.
  _logger.debug("\n".join(traceback.format_exception(error)))

  return return_value

@classmethod
def _simple_execution(cls, devices: Sequence[jax.Device]) -> jax.Array:
Expand Down Expand Up @@ -192,7 +196,7 @@ def get_slice_availability(self) -> set[int]:
}

for slice_index, x in results.items():
_logger.debug("Checking slice_index=%s", slice_index)
_logger.info("Checking slice_index=%s", slice_index)
expected = (
np.zeros(self.slice_device_count(slice_index), dtype=float)
+ self._SIMPLE_EXECUTION_TEST_VALUE
Expand All @@ -202,7 +206,7 @@ def get_slice_availability(self) -> set[int]:
jax.block_until_ready(x)
if np.allclose(x, expected):
good_slice_indices.add(slice_index)
_logger.debug("slice_index=%s good", slice_index)
_logger.info("slice_index=%s good", slice_index)
else:
_logger.error(
"Error with _simple_execution for slice_index=%s. "
Expand All @@ -217,9 +221,9 @@ def get_slice_availability(self) -> set[int]:
except jax.errors.JaxRuntimeError as error:
if not self._is_error_due_to_slice_down(error):
raise
_logger.debug("slice_index=%s bad", slice_index)
_logger.info("slice_index=%s bad", slice_index)

_logger.debug("good_slice_indices=%s", good_slice_indices)
_logger.info("good_slice_indices=%s", good_slice_indices)

return good_slice_indices

Expand All @@ -245,13 +249,13 @@ def _is_ready_to_reshard(self, step: int) -> bool:
if len(good_slice_indices) == len(self.good_slice_indices):
return False

_logger.debug("New slice available.")
_logger.debug(
_logger.info("New slice available.")
_logger.info(
"Previous good slice indices: self.good_slice_indices=%s",
self.good_slice_indices,
)
_logger.debug(
"Current good slice indices: good_slice_indices=%s", good_slice_indices
_logger.info(
"Current good slice indices: %s", good_slice_indices
)

self.good_slice_indices = good_slice_indices
Expand Down Expand Up @@ -317,15 +321,15 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
ElasticRuntimeError: If the maximum number of elastic down events or
reshard retries is reached.
"""
_logger.debug("Slice down")
_logger.info("Slice down")
self.good_slice_indices = self.get_slice_availability()
self.elastic_down_event_count += 1
if reshard_retry:
self.reshard_retry_count += 1
else:
self.reshard_retry_count = 0

_logger.debug(
_logger.info(
"elastic_down_event_count=%s max_elastic_down_event_count=%s",
self.elastic_down_event_count,
self.max_elastic_down_event_count,
Expand All @@ -339,7 +343,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
f" {self.max_elastic_down_event_count}"
)

_logger.debug(
_logger.info(
"self.reshard_retry_count=%s self.max_reshard_retry_count=%s",
self.reshard_retry_count,
self.max_reshard_retry_count,
Expand Down Expand Up @@ -429,19 +433,19 @@ def maybe_snapshot(
block: If True, block until the snapshot is ready.
"""
if not force and step % self.snapshot_period:
_logger.debug("Not saving a snapshot")
_logger.info("Not saving a snapshot")
return

total_nbytes = self._get_snapshot_size(snapshot)

_logger.debug("Saving a snapshot of %s bytes", total_nbytes)
_logger.info("Saving a snapshot of %s bytes", total_nbytes)

snapshot_host = self._put_snapshot_on_host(snapshot)
_logger.debug("Snapshot dispatched")
_logger.info("Snapshot dispatched")

if block:
jax.block_until_ready(snapshot_host)
_logger.debug("Snapshot completed")
_logger.info("Snapshot completed")

# TODO b/407772100 - Support multiple snapshots.
self._snapshot = {"step": step, "snapshot": snapshot_host}
Expand Down Expand Up @@ -532,23 +536,23 @@ def maybe_reshard_down(

while True:
if not self._is_error_due_to_slice_down(error):
_logger.debug(
_logger.info(
"Not resharding down because the error is not due to a slice down."
)
raise error from error.__cause__

_logger.debug("Resharding down")
_logger.info("Resharding down")
self._slice_down(reshard_retry)

try:
handler_return_values = elastic_handler(*handler_args, **handler_kwargs)
break
except jax.errors.JaxRuntimeError as e:
_logger.debug("Elastic handler raised an error.")
_logger.info("Elastic handler raised an error.")
error = e
reshard_retry = True

_logger.debug("Successfully resharded down")
_logger.info("Successfully resharded down")
return handler_return_values

@timing.timeit
Expand Down Expand Up @@ -586,7 +590,7 @@ def maybe_reshard_up(
handler_kwargs = {}

if not self._is_ready_to_reshard(step):
_logger.debug("Not resharding up since it is not time to reshard.")
_logger.info("Not resharding up since it is not time to reshard.")
return

self.maybe_snapshot(
Expand All @@ -599,7 +603,7 @@ def maybe_reshard_up(
try:
handler_return_values = elastic_handler(*handler_args, **handler_kwargs)
except jax.errors.JaxRuntimeError as error:
_logger.debug("Elastic handler failed. Trying again")
_logger.info("Elastic handler failed. Trying again")
handler_return_values = self.maybe_reshard_down(
error=error,
elastic_handler=elastic_handler,
Expand All @@ -608,5 +612,5 @@ def maybe_reshard_up(
reshard_retry=True,
)

_logger.debug("Finished resharding up")
_logger.info("Finished resharding up")
return handler_return_values