Skip to content

Commit 4989ff0

Browse files
aireenmei authored and
copybara-github committed
add snapshot_controller in elastic manager for resume dataloading with grain
PiperOrigin-RevId: 747478080
1 parent 42640da commit 4989ff0

1 file changed

Lines changed: 37 additions & 8 deletions

File tree

pathwaysutils/elastic/manager.py

Lines changed: 37 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,10 @@
2525
- Resharding the snapshot.
2626
"""
2727

28+
import sys
2829
import collections
2930
from collections.abc import Callable, Mapping, Sequence
31+
import copy
3032
import itertools
3133
import logging
3234
import traceback
@@ -357,7 +359,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
357359
)
358360

359361
# TODO b/407772100 - Support multiple snapshots.
360-
def pop_snapshot(self) -> tuple[int, PyTree]:
362+
def pop_snapshot(self) -> tuple[int, PyTree, Any]:
361363
"""Pops next snapshot.
362364
363365
This function is used to get the next snapshot and remove it from
@@ -375,9 +377,13 @@ def pop_snapshot(self) -> tuple[int, PyTree]:
375377

376378
step = self._snapshot.pop("step")
377379
snapshot = self._snapshot.pop("snapshot")
380+
if "snapshot_controller" in self._snapshot:
381+
snapshot_controller = self._snapshot.pop("snapshot_controller")
382+
else:
383+
snapshot_controller = None
378384
self._snapshot = None
379385

380-
return step, snapshot
386+
return step, snapshot, snapshot_controller
381387

382388
@staticmethod
383389
def _get_snapshot_size(snapshot: PyTree) -> int:
@@ -412,11 +418,18 @@ def _put_snapshot_on_host(
412418
may_alias=False,
413419
)
414420

421+
@staticmethod
422+
def _put_snapshot_on_controller(
423+
snapshot: PyTree,
424+
) -> PyTree:
425+
return copy.deepcopy(snapshot)
426+
415427
@timing.timeit
416428
def maybe_snapshot(
417429
self,
418430
step: int,
419431
snapshot: PyTree,
432+
snapshot_controller: Mapping[str, Any] | None = None,
420433
force: bool = False,
421434
block: bool = False,
422435
) -> None:
@@ -438,7 +451,7 @@ def maybe_snapshot(
438451

439452
total_nbytes = self._get_snapshot_size(snapshot)
440453

441-
_logger.info("Saving a snapshot of %s bytes", total_nbytes)
454+
_logger.info("Saving a snapshot of %s bytes on host", total_nbytes)
442455

443456
snapshot_host = self._put_snapshot_on_host(snapshot)
444457
_logger.info("Snapshot dispatched")
@@ -448,12 +461,22 @@ def maybe_snapshot(
448461
_logger.info("Snapshot completed")
449462

450463
# TODO b/407772100 - Support multiple snapshots.
451-
self._snapshot = {"step": step, "snapshot": snapshot_host}
464+
self._snapshot = {
465+
"step": step,
466+
"snapshot": snapshot_host,
467+
}
468+
if snapshot_controller is not None:
469+
total_nbytes = sys.getsizeof(snapshot_controller)
470+
_logger.info("Saving a snapshot of %s bytes on controller", total_nbytes)
471+
snapshot_on_controller = self._put_snapshot_on_controller(
472+
snapshot_controller
473+
)
474+
self._snapshot["snapshot_controller"] = snapshot_on_controller
452475

453476
@timing.timeit
454477
def get_resharded_snapshot(
455478
self, mesh: jax.sharding.Mesh
456-
) -> tuple[int, Mapping[str, int | PyTree]]:
479+
) -> tuple[int, Mapping[str, int | PyTree], Any]:
457480
"""Get the resharded snapshot.
458481
459482
The snapshot on pinned memory is resharded to the new mesh. This snapshot is
@@ -466,7 +489,7 @@ def get_resharded_snapshot(
466489
Returns:
467490
The next step and snapshot resharded to the new mesh.
468491
"""
469-
step, snapshot = self.pop_snapshot()
492+
step, snapshot, snapshot_controller = self.pop_snapshot()
470493

471494
sharding_pinned_host = jax.tree.map(
472495
lambda x: jax.sharding.NamedSharding(
@@ -480,7 +503,11 @@ def get_resharded_snapshot(
480503
donate=True,
481504
may_alias=False,
482505
)
483-
self._snapshot = {"step": step, "snapshot": resharded_snapshot_pinned_host}
506+
self._snapshot = {
507+
"step": step,
508+
"snapshot": resharded_snapshot_pinned_host,
509+
"snapshot_controller": snapshot_controller,
510+
}
484511

485512
sharding_device = jax.tree.map(
486513
lambda x: x.sharding.with_memory_kind("device"),
@@ -493,7 +520,7 @@ def get_resharded_snapshot(
493520
may_alias=False,
494521
)
495522

496-
return step, resharded_snapshot_device
523+
return step, resharded_snapshot_device, snapshot_controller
497524

498525
@timing.timeit
499526
def maybe_reshard_down(
@@ -561,6 +588,7 @@ def maybe_reshard_up(
561588
step: int,
562589
snapshot: Mapping[str, int | PyTree],
563590
elastic_handler: Callable[..., Any],
591+
snapshot_controller: Mapping[str, Any] | None = None,
564592
handler_args: tuple[Any, ...] | None = None,
565593
handler_kwargs: Mapping[str, Any] | None = None,
566594
) -> Any:
@@ -596,6 +624,7 @@ def maybe_reshard_up(
596624
self.maybe_snapshot(
597625
step=step,
598626
snapshot=snapshot,
627+
snapshot_controller=snapshot_controller,
599628
force=True,
600629
block=True,
601630
)

0 commit comments

Comments
 (0)