Skip to content

Commit b81c07a

Browse files
lukebaumann and copybara-github
authored and committed
Use the info logging level within the elastic manager to reduce log spam.
PiperOrigin-RevId: 746061306
1 parent 1c1d553 commit b81c07a

1 file changed

Lines changed: 38 additions & 34 deletions

File tree

pathwaysutils/elastic/manager.py

Lines changed: 38 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,11 @@ class Manager:
6161
_snapshot: PyTree
6262

6363
_SIMPLE_EXECUTION_TEST_VALUE = 100
64+
_ELASTIC_DOWN_ERROR_TYPES = [
65+
"DATA_LOSS",
66+
"NOT_FOUND",
67+
"INTERNAL",
68+
]
6469

6570
def __init__(
6671
self,
@@ -138,22 +143,21 @@ def slice_device_count(self, slice_index: int) -> int:
138143
f"Slice {slice_index=} not found in {self.slice_to_devices=}"
139144
) from error
140145

141-
@staticmethod
142-
def _is_error_due_to_slice_down(error: Exception) -> bool:
146+
@classmethod
147+
def _is_error_due_to_slice_down(cls, error: Exception) -> bool:
143148
"""Check if the error is due to slice down."""
144-
if "DATA_LOSS" in str(error):
145-
_logger.debug("Caught JaxRuntimeError DATA_LOSS exception")
146-
elif "NOT_FOUND" in str(error):
147-
_logger.debug("Caught JaxRuntimeError NOT_FOUND exception")
148-
elif "INTERNAL" in str(error):
149-
_logger.debug("Caught JaxRuntimeError INTERNAL exception")
150-
149+
return_value = any(
150+
error_type in str(error)
151+
for error_type in cls._ELASTIC_DOWN_ERROR_TYPES
152+
)
153+
if return_value:
154+
_logger.info("Caught an error due to slice down")
151155
else:
152-
_logger.debug("Unknown JaxRuntimeError")
153-
return False
156+
_logger.info("Caught an error not due to slice down")
154157

155158
_logger.debug("\n".join(traceback.format_exception(error)))
156-
return True
159+
160+
return return_value
157161

158162
@classmethod
159163
def _simple_execution(cls, devices: Sequence[jax.Device]) -> jax.Array:
@@ -192,7 +196,7 @@ def get_slice_availability(self) -> set[int]:
192196
}
193197

194198
for slice_index, x in results.items():
195-
_logger.debug("Checking slice_index=%s", slice_index)
199+
_logger.info("Checking slice_index=%s", slice_index)
196200
expected = (
197201
np.zeros(self.slice_device_count(slice_index), dtype=float)
198202
+ self._SIMPLE_EXECUTION_TEST_VALUE
@@ -202,7 +206,7 @@ def get_slice_availability(self) -> set[int]:
202206
jax.block_until_ready(x)
203207
if np.allclose(x, expected):
204208
good_slice_indices.add(slice_index)
205-
_logger.debug("slice_index=%s good", slice_index)
209+
_logger.info("slice_index=%s good", slice_index)
206210
else:
207211
_logger.error(
208212
"Error with _simple_execution for slice_index=%s. "
@@ -217,9 +221,9 @@ def get_slice_availability(self) -> set[int]:
217221
except jax.errors.JaxRuntimeError as error:
218222
if not self._is_error_due_to_slice_down(error):
219223
raise
220-
_logger.debug("slice_index=%s bad", slice_index)
224+
_logger.info("slice_index=%s bad", slice_index)
221225

222-
_logger.debug("good_slice_indices=%s", good_slice_indices)
226+
_logger.info("good_slice_indices=%s", good_slice_indices)
223227

224228
return good_slice_indices
225229

@@ -245,13 +249,13 @@ def _is_ready_to_reshard(self, step: int) -> bool:
245249
if len(good_slice_indices) == len(self.good_slice_indices):
246250
return False
247251

248-
_logger.debug("New slice available.")
249-
_logger.debug(
252+
_logger.info("New slice available.")
253+
_logger.info(
250254
"Previous good slice indices: self.good_slice_indices=%s",
251255
self.good_slice_indices,
252256
)
253-
_logger.debug(
254-
"Current good slice indices: good_slice_indices=%s", good_slice_indices
257+
_logger.info(
258+
"Current good slice indices: %s", good_slice_indices
255259
)
256260

257261
self.good_slice_indices = good_slice_indices
@@ -317,15 +321,15 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
317321
ElasticRuntimeError: If the maximum number of elastic down events or
318322
reshard retries is reached.
319323
"""
320-
_logger.debug("Slice down")
324+
_logger.info("Slice down")
321325
self.good_slice_indices = self.get_slice_availability()
322326
self.elastic_down_event_count += 1
323327
if reshard_retry:
324328
self.reshard_retry_count += 1
325329
else:
326330
self.reshard_retry_count = 0
327331

328-
_logger.debug(
332+
_logger.info(
329333
"elastic_down_event_count=%s max_elastic_down_event_count=%s",
330334
self.elastic_down_event_count,
331335
self.max_elastic_down_event_count,
@@ -339,7 +343,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
339343
f" {self.max_elastic_down_event_count}"
340344
)
341345

342-
_logger.debug(
346+
_logger.info(
343347
"self.reshard_retry_count=%s self.max_reshard_retry_count=%s",
344348
self.reshard_retry_count,
345349
self.max_reshard_retry_count,
@@ -429,19 +433,19 @@ def maybe_snapshot(
429433
block: If True, block until the snapshot is ready.
430434
"""
431435
if not force and step % self.snapshot_period:
432-
_logger.debug("Not saving a snapshot")
436+
_logger.info("Not saving a snapshot")
433437
return
434438

435439
total_nbytes = self._get_snapshot_size(snapshot)
436440

437-
_logger.debug("Saving a snapshot of %s bytes", total_nbytes)
441+
_logger.info("Saving a snapshot of %s bytes", total_nbytes)
438442

439443
snapshot_host = self._put_snapshot_on_host(snapshot)
440-
_logger.debug("Snapshot dispatched")
444+
_logger.info("Snapshot dispatched")
441445

442446
if block:
443447
jax.block_until_ready(snapshot_host)
444-
_logger.debug("Snapshot completed")
448+
_logger.info("Snapshot completed")
445449

446450
# TODO b/407772100 - Support multiple snapshots.
447451
self._snapshot = {"step": step, "snapshot": snapshot_host}
@@ -532,23 +536,23 @@ def maybe_reshard_down(
532536

533537
while True:
534538
if not self._is_error_due_to_slice_down(error):
535-
_logger.debug(
539+
_logger.info(
536540
"Not resharding down because the error is not due to a slice down."
537541
)
538542
raise error from error.__cause__
539543

540-
_logger.debug("Resharding down")
544+
_logger.info("Resharding down")
541545
self._slice_down(reshard_retry)
542546

543547
try:
544548
handler_return_values = elastic_handler(*handler_args, **handler_kwargs)
545549
break
546550
except jax.errors.JaxRuntimeError as e:
547-
_logger.debug("Elastic handler raised an error.")
551+
_logger.info("Elastic handler raised an error.")
548552
error = e
549553
reshard_retry = True
550554

551-
_logger.debug("Successfully resharded down")
555+
_logger.info("Successfully resharded down")
552556
return handler_return_values
553557

554558
@timing.timeit
@@ -586,7 +590,7 @@ def maybe_reshard_up(
586590
handler_kwargs = {}
587591

588592
if not self._is_ready_to_reshard(step):
589-
_logger.debug("Not resharding up since it is not time to reshard.")
593+
_logger.info("Not resharding up since it is not time to reshard.")
590594
return
591595

592596
self.maybe_snapshot(
@@ -599,7 +603,7 @@ def maybe_reshard_up(
599603
try:
600604
handler_return_values = elastic_handler(*handler_args, **handler_kwargs)
601605
except jax.errors.JaxRuntimeError as error:
602-
_logger.debug("Elastic handler failed. Trying again")
606+
_logger.info("Elastic handler failed. Trying again")
603607
handler_return_values = self.maybe_reshard_down(
604608
error=error,
605609
elastic_handler=elastic_handler,
@@ -608,5 +612,5 @@ def maybe_reshard_up(
608612
reshard_retry=True,
609613
)
610614

611-
_logger.debug("Finished resharding up")
615+
_logger.info("Finished resharding up")
612616
return handler_return_values

0 commit comments

Comments
 (0)