Skip to content

Commit 28530ce

Browse files
aireenmei and copybara-github
authored and committed
Add snapshot_controller to the elastic manager so that data loading with Grain can be resumed; for usage, see AI-Hypercomputer/maxtext#1574
PiperOrigin-RevId: 748453130
1 parent 92cc542 commit 28530ce

1 file changed

Lines changed: 57 additions & 32 deletions

File tree

pathwaysutils/elastic/manager.py

Lines changed: 57 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,10 @@
2525
- Resharding the snapshot.
2626
"""
2727

28+
import sys
2829
import collections
2930
from collections.abc import Callable, Mapping, Sequence
31+
import copy
3032
import itertools
3133
import logging
3234
import traceback
@@ -356,8 +358,8 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
356358
f"Max reshard retry count reached {self.max_reshard_retry_count=}"
357359
)
358360

359-
# TODO b/407772100 - Support multiple snapshots.
360-
def pop_snapshot(self) -> tuple[int, PyTree]:
361+
# TODO: b/407772100 - Support multiple snapshots.
362+
def pop_snapshot(self) -> tuple[int, PyTree | None, PyTree | None]:
361363
"""Pops next snapshot.
362364
363365
This function is used to get the next snapshot and remove it from
@@ -373,50 +375,60 @@ def pop_snapshot(self) -> tuple[int, PyTree]:
373375
if self._snapshot is None:
374376
raise ElasticRuntimeError("No snapshot to pop.")
375377

376-
step = self._snapshot.pop("step")
377-
snapshot = self._snapshot.pop("snapshot")
378+
step, snapshot_jax_arrays, snapshot_controller = (
379+
self._snapshot.pop(key)
380+
for key in ["step", "snapshot_jax_arrays", "snapshot_controller"]
381+
)
378382
self._snapshot = None
379383

380-
return step, snapshot
384+
return step, snapshot_jax_arrays, snapshot_controller
381385

382386
@staticmethod
383-
def _get_snapshot_size(snapshot: PyTree) -> int:
387+
def _get_snapshot_jax_arrays_size(snapshot_jax_arrays: PyTree | None) -> int:
384388
"""Returns the size of a snapshot.
385389
386390
Args:
387391
snapshot: The snapshot to get the size of.
388392
"""
389-
return sum(leaf.nbytes for leaf in jax.tree.leaves(snapshot))
393+
return sum(leaf.nbytes for leaf in jax.tree.leaves(snapshot_jax_arrays))
390394

391395
@staticmethod
392-
def _put_snapshot_on_host(
393-
snapshot: PyTree,
394-
) -> PyTree:
396+
def _put_snapshot_jax_arrays_on_host(
397+
snapshot_jax_arrays: PyTree | None,
398+
) -> PyTree | None:
395399
"""Puts a copy of the snapshot on the host.
396400
397401
Args:
398402
snapshot: The snapshot to move to the host. Must be a PyTree of JAX
399-
arrays.
403+
arrays or None.
400404
401405
Returns:
402406
A copy of the snapshot on the host.
403407
"""
404408

405409
sharding_pinned_host = jax.tree.map(
406-
lambda x: x.sharding.with_memory_kind("pinned_host"), snapshot
410+
lambda x: x.sharding.with_memory_kind("pinned_host"), snapshot_jax_arrays
407411
)
408412
return jax.device_put(
409-
snapshot,
413+
snapshot_jax_arrays,
410414
sharding_pinned_host,
411415
donate=False,
412416
may_alias=False,
413417
)
414418

419+
@staticmethod
420+
def _put_snapshot_on_controller(
421+
snapshot: PyTree | None,
422+
) -> PyTree | None:
423+
return copy.deepcopy(snapshot)
424+
425+
# TODO: b/407772100 - Support multiple snapshots.
415426
@timing.timeit
416427
def maybe_snapshot(
417428
self,
418429
step: int,
419-
snapshot: PyTree,
430+
snapshot_jax_arrays: PyTree | None = None,
431+
snapshot_controller: PyTree | None = None,
420432
force: bool = False,
421433
block: bool = False,
422434
) -> None:
@@ -436,24 +448,28 @@ def maybe_snapshot(
436448
_logger.info("Not saving a snapshot")
437449
return
438450

439-
total_nbytes = self._get_snapshot_size(snapshot)
451+
total_nbytes = self._get_snapshot_jax_arrays_size(snapshot_jax_arrays)
440452

441-
_logger.info("Saving a snapshot of %s bytes", total_nbytes)
453+
_logger.info("Saving a snapshot of %s bytes on host", total_nbytes)
442454

443-
snapshot_host = self._put_snapshot_on_host(snapshot)
455+
snapshot_jax_arrays_host = self._put_snapshot_jax_arrays_on_host(snapshot_jax_arrays)
444456
_logger.info("Snapshot dispatched")
445457

446458
if block:
447-
jax.block_until_ready(snapshot_host)
459+
jax.block_until_ready(snapshot_jax_arrays_host)
448460
_logger.info("Snapshot completed")
449461

450-
# TODO b/407772100 - Support multiple snapshots.
451-
self._snapshot = {"step": step, "snapshot": snapshot_host}
462+
snapshot_on_controller = self._put_snapshot_on_controller(snapshot_controller)
463+
self._snapshot = {
464+
"step": step,
465+
"snapshot_jax_arrays": snapshot_jax_arrays_host,
466+
"snapshot_controller": snapshot_on_controller,
467+
}
452468

453469
@timing.timeit
454470
def get_resharded_snapshot(
455471
self, mesh: jax.sharding.Mesh
456-
) -> tuple[int, Mapping[str, int | PyTree]]:
472+
) -> tuple[int, PyTree | None, PyTree | None]:
457473
"""Get the resharded snapshot.
458474
459475
The snapshot on pinned memory is resharded to the new mesh. This snapshot is
@@ -466,34 +482,41 @@ def get_resharded_snapshot(
466482
Returns:
467483
The next step and snapshot resharded to the new mesh.
468484
"""
469-
step, snapshot = self.pop_snapshot()
485+
step, snapshot_jax_arrays, snapshot_controller = self.pop_snapshot()
470486

471487
sharding_pinned_host = jax.tree.map(
472488
lambda x: jax.sharding.NamedSharding(
473489
mesh, x.sharding.spec, memory_kind="pinned_host"
474490
),
475-
snapshot,
491+
snapshot_jax_arrays,
476492
)
477-
resharded_snapshot_pinned_host = jax.device_put(
478-
snapshot,
493+
resharded_jax_arrays_pinned_host = jax.device_put(
494+
snapshot_jax_arrays,
479495
sharding_pinned_host,
480496
donate=True,
481497
may_alias=False,
482498
)
483-
self._snapshot = {"step": step, "snapshot": resharded_snapshot_pinned_host}
484499

485500
sharding_device = jax.tree.map(
486501
lambda x: x.sharding.with_memory_kind("device"),
487-
resharded_snapshot_pinned_host,
502+
resharded_jax_arrays_pinned_host,
488503
)
489-
resharded_snapshot_device = jax.device_put(
490-
resharded_snapshot_pinned_host,
504+
resharded_jax_arrays_device = jax.device_put(
505+
resharded_jax_arrays_pinned_host,
491506
sharding_device,
492507
donate=False,
493508
may_alias=False,
494509
)
495510

496-
return step, resharded_snapshot_device
511+
snapshot_on_controller = self._put_snapshot_on_controller(snapshot_controller)
512+
513+
self._snapshot = {
514+
"step": step,
515+
"snapshot_jax_arrays": resharded_jax_arrays_pinned_host,
516+
"snapshot_controller": snapshot_on_controller,
517+
}
518+
519+
return step, resharded_jax_arrays_device, snapshot_controller
497520

498521
@timing.timeit
499522
def maybe_reshard_down(
@@ -559,8 +582,9 @@ def maybe_reshard_down(
559582
def maybe_reshard_up(
560583
self,
561584
step: int,
562-
snapshot: Mapping[str, int | PyTree],
563585
elastic_handler: Callable[..., Any],
586+
snapshot_jax_arrays: PyTree | None = None,
587+
snapshot_controller: PyTree | None = None,
564588
handler_args: tuple[Any, ...] | None = None,
565589
handler_kwargs: Mapping[str, Any] | None = None,
566590
) -> Any:
@@ -595,7 +619,8 @@ def maybe_reshard_up(
595619

596620
self.maybe_snapshot(
597621
step=step,
598-
snapshot=snapshot,
622+
snapshot_jax_arrays=snapshot_jax_arrays,
623+
snapshot_controller=snapshot_controller,
599624
force=True,
600625
block=True,
601626
)

0 commit comments

Comments
 (0)