2525- Resharding the snapshot.
2626"""
2727
28+ import sys
2829import collections
2930from collections .abc import Callable , Mapping , Sequence
31+ import copy
3032import itertools
3133import logging
3234import traceback
@@ -357,7 +359,7 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
357359 )
358360
359361 # TODO b/407772100 - Support multiple snapshots.
360- def pop_snapshot (self ) -> tuple [int , PyTree ]:
362+ def pop_snapshot (self ) -> tuple [int , PyTree , Any ]:
361363 """Pops next snapshot.
362364
363365 This function is used to get the next snapshot and remove it from
@@ -375,9 +377,13 @@ def pop_snapshot(self) -> tuple[int, PyTree]:
375377
376378 step = self ._snapshot .pop ("step" )
377379 snapshot = self ._snapshot .pop ("snapshot" )
380+ if "snapshot_controller" in self ._snapshot :
381+ snapshot_controller = self ._snapshot .pop ("snapshot_controller" )
382+ else :
383+ snapshot_controller = None
378384 self ._snapshot = None
379385
380- return step , snapshot
386+ return step , snapshot , snapshot_controller
381387
382388 @staticmethod
383389 def _get_snapshot_size (snapshot : PyTree ) -> int :
@@ -412,11 +418,18 @@ def _put_snapshot_on_host(
412418 may_alias = False ,
413419 )
414420
421+ @staticmethod
422+ def _put_snapshot_on_controller (
423+ snapshot : PyTree ,
424+ ) -> PyTree :
425+ return copy .deepcopy (snapshot )
426+
415427 @timing .timeit
416428 def maybe_snapshot (
417429 self ,
418430 step : int ,
419431 snapshot : PyTree ,
432+ snapshot_controller : Mapping [str , Any ] | None = None ,
420433 force : bool = False ,
421434 block : bool = False ,
422435 ) -> None :
@@ -438,7 +451,7 @@ def maybe_snapshot(
438451
439452 total_nbytes = self ._get_snapshot_size (snapshot )
440453
441- _logger .info ("Saving a snapshot of %s bytes" , total_nbytes )
454+ _logger .info ("Saving a snapshot of %s bytes on host " , total_nbytes )
442455
443456 snapshot_host = self ._put_snapshot_on_host (snapshot )
444457 _logger .info ("Snapshot dispatched" )
@@ -448,12 +461,22 @@ def maybe_snapshot(
448461 _logger .info ("Snapshot completed" )
449462
450463 # TODO b/407772100 - Support multiple snapshots.
451- self ._snapshot = {"step" : step , "snapshot" : snapshot_host }
464+ self ._snapshot = {
465+ "step" : step ,
466+ "snapshot" : snapshot_host ,
467+ }
468+ if snapshot_controller is not None :
469+ total_nbytes = sys .getsizeof (snapshot_controller )
470+ _logger .info ("Saving a snapshot of %s bytes on controller" , total_nbytes )
471+ snapshot_on_controller = self ._put_snapshot_on_controller (
472+ snapshot_controller
473+ )
474+ self ._snapshot ["snapshot_controller" ] = snapshot_on_controller
452475
453476 @timing .timeit
454477 def get_resharded_snapshot (
455478 self , mesh : jax .sharding .Mesh
456- ) -> tuple [int , Mapping [str , int | PyTree ]]:
479+ ) -> tuple [int , Mapping [str , int | PyTree ], Any ]:
457480 """Get the resharded snapshot.
458481
459482 The snapshot on pinned memory is resharded to the new mesh. This snapshot is
@@ -466,7 +489,7 @@ def get_resharded_snapshot(
466489 Returns:
467490 The next step and snapshot resharded to the new mesh.
468491 """
469- step , snapshot = self .pop_snapshot ()
492+ step , snapshot , snapshot_controller = self .pop_snapshot ()
470493
471494 sharding_pinned_host = jax .tree .map (
472495 lambda x : jax .sharding .NamedSharding (
@@ -480,7 +503,11 @@ def get_resharded_snapshot(
480503 donate = True ,
481504 may_alias = False ,
482505 )
483- self ._snapshot = {"step" : step , "snapshot" : resharded_snapshot_pinned_host }
506+ self ._snapshot = {
507+ "step" : step ,
508+ "snapshot" : resharded_snapshot_pinned_host ,
509+ "snapshot_controller" : snapshot_controller ,
510+ }
484511
485512 sharding_device = jax .tree .map (
486513 lambda x : x .sharding .with_memory_kind ("device" ),
@@ -493,7 +520,7 @@ def get_resharded_snapshot(
493520 may_alias = False ,
494521 )
495522
496- return step , resharded_snapshot_device
523+ return step , resharded_snapshot_device , snapshot_controller
497524
498525 @timing .timeit
499526 def maybe_reshard_down (
@@ -561,6 +588,7 @@ def maybe_reshard_up(
561588 step : int ,
562589 snapshot : Mapping [str , int | PyTree ],
563590 elastic_handler : Callable [..., Any ],
591+ snapshot_controller : Mapping [str , Any ] | None = None ,
564592 handler_args : tuple [Any , ...] | None = None ,
565593 handler_kwargs : Mapping [str , Any ] | None = None ,
566594 ) -> Any :
@@ -596,6 +624,7 @@ def maybe_reshard_up(
596624 self .maybe_snapshot (
597625 step = step ,
598626 snapshot = snapshot ,
627+ snapshot_controller = snapshot_controller ,
599628 force = True ,
600629 block = True ,
601630 )
0 commit comments