2525- Resharding the snapshot.
2626"""
2727
28+ import sys
2829import collections
2930from collections .abc import Callable , Mapping , Sequence
31+ import copy
3032import itertools
3133import logging
3234import traceback
@@ -356,8 +358,8 @@ def _slice_down(self, reshard_retry: bool = False) -> None:
356358 f"Max reshard retry count reached { self .max_reshard_retry_count = } "
357359 )
358360
359- # TODO b/407772100 - Support multiple snapshots.
360- def pop_snapshot (self ) -> tuple [int , PyTree ]:
361+ # TODO: b/407772100 - Support multiple snapshots.
362+ def pop_snapshot (self ) -> tuple [int , PyTree | None , PyTree | None ]:
361363 """Pops next snapshot.
362364
363365 This function is used to get the next snapshot and remove it from
@@ -373,50 +375,60 @@ def pop_snapshot(self) -> tuple[int, PyTree]:
373375 if self ._snapshot is None :
374376 raise ElasticRuntimeError ("No snapshot to pop." )
375377
376- step = self ._snapshot .pop ("step" )
377- snapshot = self ._snapshot .pop ("snapshot" )
378+ step , snapshot_jax_arrays , snapshot_controller = (
379+ self ._snapshot .pop (key )
380+ for key in ["step" , "snapshot_jax_arrays" , "snapshot_controller" ]
381+ )
378382 self ._snapshot = None
379383
380- return step , snapshot
384+ return step , snapshot_jax_arrays , snapshot_controller
381385
382386 @staticmethod
383- def _get_snapshot_size ( snapshot : PyTree ) -> int :
387+ def _get_snapshot_jax_arrays_size ( snapshot_jax_arrays : PyTree | None ) -> int :
384388 """Returns the size of a snapshot.
385389
386390 Args:
387391 snapshot: The snapshot to get the size of.
388392 """
389- return sum (leaf .nbytes for leaf in jax .tree .leaves (snapshot ))
393+ return sum (leaf .nbytes for leaf in jax .tree .leaves (snapshot_jax_arrays ))
390394
391395 @staticmethod
392- def _put_snapshot_on_host (
393- snapshot : PyTree ,
394- ) -> PyTree :
396+ def _put_snapshot_jax_arrays_on_host (
397+ snapshot_jax_arrays : PyTree | None ,
398+ ) -> PyTree | None :
395399 """Puts a copy of the snapshot on the host.
396400
397401 Args:
398402 snapshot: The snapshot to move to the host. Must be a PyTree of JAX
399- arrays.
403+ arrays or None .
400404
401405 Returns:
402406 A copy of the snapshot on the host.
403407 """
404408
405409 sharding_pinned_host = jax .tree .map (
406- lambda x : x .sharding .with_memory_kind ("pinned_host" ), snapshot
410+ lambda x : x .sharding .with_memory_kind ("pinned_host" ), snapshot_jax_arrays
407411 )
408412 return jax .device_put (
409- snapshot ,
413+ snapshot_jax_arrays ,
410414 sharding_pinned_host ,
411415 donate = False ,
412416 may_alias = False ,
413417 )
414418
419+ @staticmethod
420+ def _put_snapshot_on_controller (
421+ snapshot : PyTree | None ,
422+ ) -> PyTree | None :
423+ return copy .deepcopy (snapshot )
424+
425+ # TODO: b/407772100 - Support multiple snapshots.
415426 @timing .timeit
416427 def maybe_snapshot (
417428 self ,
418429 step : int ,
419- snapshot : PyTree ,
430+ snapshot_jax_arrays : PyTree | None = None ,
431+ snapshot_controller : PyTree | None = None ,
420432 force : bool = False ,
421433 block : bool = False ,
422434 ) -> None :
@@ -436,24 +448,28 @@ def maybe_snapshot(
436448 _logger .info ("Not saving a snapshot" )
437449 return
438450
439- total_nbytes = self ._get_snapshot_size ( snapshot )
451+ total_nbytes = self ._get_snapshot_jax_arrays_size ( snapshot_jax_arrays )
440452
441- _logger .info ("Saving a snapshot of %s bytes" , total_nbytes )
453+ _logger .info ("Saving a snapshot of %s bytes on host " , total_nbytes )
442454
443- snapshot_host = self ._put_snapshot_on_host ( snapshot )
455+ snapshot_jax_arrays_host = self ._put_snapshot_jax_arrays_on_host ( snapshot_jax_arrays )
444456 _logger .info ("Snapshot dispatched" )
445457
446458 if block :
447- jax .block_until_ready (snapshot_host )
459+ jax .block_until_ready (snapshot_jax_arrays_host )
448460 _logger .info ("Snapshot completed" )
449461
450- # TODO b/407772100 - Support multiple snapshots.
451- self ._snapshot = {"step" : step , "snapshot" : snapshot_host }
462+ snapshot_on_controller = self ._put_snapshot_on_controller (snapshot_controller )
463+ self ._snapshot = {
464+ "step" : step ,
465+ "snapshot_jax_arrays" : snapshot_jax_arrays_host ,
466+ "snapshot_controller" : snapshot_on_controller ,
467+ }
452468
453469 @timing .timeit
454470 def get_resharded_snapshot (
455471 self , mesh : jax .sharding .Mesh
456- ) -> tuple [int , Mapping [ str , int | PyTree ] ]:
472+ ) -> tuple [int , PyTree | None , PyTree | None ]:
457473 """Get the resharded snapshot.
458474
459475 The snapshot on pinned memory is resharded to the new mesh. This snapshot is
@@ -466,34 +482,41 @@ def get_resharded_snapshot(
466482 Returns:
467483 The next step and snapshot resharded to the new mesh.
468484 """
469- step , snapshot = self .pop_snapshot ()
485+ step , snapshot_jax_arrays , snapshot_controller = self .pop_snapshot ()
470486
471487 sharding_pinned_host = jax .tree .map (
472488 lambda x : jax .sharding .NamedSharding (
473489 mesh , x .sharding .spec , memory_kind = "pinned_host"
474490 ),
475- snapshot ,
491+ snapshot_jax_arrays ,
476492 )
477- resharded_snapshot_pinned_host = jax .device_put (
478- snapshot ,
493+ resharded_jax_arrays_pinned_host = jax .device_put (
494+ snapshot_jax_arrays ,
479495 sharding_pinned_host ,
480496 donate = True ,
481497 may_alias = False ,
482498 )
483- self ._snapshot = {"step" : step , "snapshot" : resharded_snapshot_pinned_host }
484499
485500 sharding_device = jax .tree .map (
486501 lambda x : x .sharding .with_memory_kind ("device" ),
487- resharded_snapshot_pinned_host ,
502+ resharded_jax_arrays_pinned_host ,
488503 )
489- resharded_snapshot_device = jax .device_put (
490- resharded_snapshot_pinned_host ,
504+ resharded_jax_arrays_device = jax .device_put (
505+ resharded_jax_arrays_pinned_host ,
491506 sharding_device ,
492507 donate = False ,
493508 may_alias = False ,
494509 )
495510
496- return step , resharded_snapshot_device
511+ snapshot_on_controller = self ._put_snapshot_on_controller (snapshot_controller )
512+
513+ self ._snapshot = {
514+ "step" : step ,
515+ "snapshot_jax_arrays" : resharded_jax_arrays_pinned_host ,
516+ "snapshot_controller" : snapshot_on_controller ,
517+ }
518+
519+ return step , resharded_jax_arrays_device , snapshot_controller
497520
498521 @timing .timeit
499522 def maybe_reshard_down (
@@ -559,8 +582,9 @@ def maybe_reshard_down(
559582 def maybe_reshard_up (
560583 self ,
561584 step : int ,
562- snapshot : Mapping [str , int | PyTree ],
563585 elastic_handler : Callable [..., Any ],
586+ snapshot_jax_arrays : PyTree | None = None ,
587+ snapshot_controller : PyTree | None = None ,
564588 handler_args : tuple [Any , ...] | None = None ,
565589 handler_kwargs : Mapping [str , Any ] | None = None ,
566590 ) -> Any :
@@ -595,7 +619,8 @@ def maybe_reshard_up(
595619
596620 self .maybe_snapshot (
597621 step = step ,
598- snapshot = snapshot ,
622+ snapshot_jax_arrays = snapshot_jax_arrays ,
623+ snapshot_controller = snapshot_controller ,
599624 force = True ,
600625 block = True ,
601626 )
0 commit comments