4343
4444import grain
4545from grain .python import PyGrainCheckpointHandler
46+ from grain .experimental import ElasticIterator
4647
4748CheckpointManager = ocp .CheckpointManager
4849CheckpointManagerOptions = ocp .CheckpointManagerOptions
@@ -68,6 +69,22 @@ def save(
6869 """Saves the given iterator to the checkpoint in `directory`."""
6970 item = item or args .item # pytype:disable=attribute-error
7071
72+ # RemoteIteratorWrapper handles checkpointing via colocated python
73+ if isinstance (item , RemoteIteratorWrapper ):
74+ step = int (directory .parent .name )
75+ item .save_state (step )
76+ return
77+
78+ # ElasticIterator state is a single global scalar shared by all shards,
79+ # so we write one fixed `process_0.json` from process 0 only. This file
80+ # layout survives changes in `jax.process_count()`.
81+ if isinstance (item , ElasticIterator ):
82+ if jax .process_index () == 0 :
83+ directory .mkdir (parents = True , exist_ok = True )
84+ filename = directory / "process_0.json"
85+ filename .write_text (json .dumps (item .get_state (), indent = 4 ))
86+ return
87+
7188 def save_single_process (item , process_index , process_count ):
7289 filename = directory / f"process_{ process_index } -of-{ process_count } .json"
7390 if isinstance (item , grain .DatasetIterator ):
@@ -94,6 +111,20 @@ def restore(
94111 process_index = getattr (args , "process_index" , None )
95112 process_count = getattr (args , "process_count" , None )
96113
114+ # RemoteIteratorWrapper handles checkpointing via colocated python
115+ if isinstance (item , RemoteIteratorWrapper ):
116+ step = int (directory .parent .name )
117+ item .restore_state (step )
118+ return item
119+
120+ # ElasticIterator: every process reads the same shared `process_0.json`.
121+ if isinstance (item , ElasticIterator ):
122+ filename = directory / "process_0.json"
123+ if not filename .exists ():
124+ raise ValueError (f"File { filename } does not exist." )
125+ item .set_state (json .loads (filename .read_text ()))
126+ return item
127+
97128 def restore_single_process (item , process_index , process_count ):
98129 filename = directory / f"process_{ process_index } -of-{ process_count } .json"
99130 if not filename .exists ():
@@ -131,15 +162,6 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
131162 process_count : Optional [int ] = None
132163
133164
134- def _is_remote_iterator (data_iterator ):
135- """Check if data_iterator is a RemoteIteratorWrapper or contains RemoteIteratorWrapper instances."""
136- if isinstance (data_iterator , RemoteIteratorWrapper ):
137- return True
138- if isinstance (data_iterator , list ):
139- return any (isinstance (item , RemoteIteratorWrapper ) for item in data_iterator )
140- return False
141-
142-
143165def _load_full_state_from_path (
144166 path ,
145167 abstract_unboxed_pre_state ,
@@ -481,6 +503,17 @@ def _restore_grain_iterator(
481503 This function dispatches to the correct restore strategy based on
482504 the number of stored checkpoint files vs. current JAX processes.
483505 """
506+ if isinstance (data_iterator , RemoteIteratorWrapper ):
507+ grain_restore_args = GrainCheckpointRestore (item = data_iterator )
508+ restored_state = checkpoint_manager .restore (step , args = Composite (items = checkpoint_args , iter = grain_restore_args ))
509+ return (restored_state , None )
510+
511+ # ElasticIterator: one shared `process_0.json` regardless of shard count.
512+ if not isinstance (data_iterator , list ) and isinstance (data_iterator .local_iterator , ElasticIterator ):
513+ grain_restore_args = GrainCheckpointRestore (item = data_iterator .local_iterator )
514+ restored_state = checkpoint_manager .restore (step , args = Composite (items = checkpoint_args , iter = grain_restore_args ))
515+ return (restored_state , None )
516+
484517 directory = checkpoint_manager .directory / str (step ) / "iter"
485518 process_count_jax = jax .process_count ()
486519
@@ -619,7 +652,7 @@ def map_to_pspec(data):
619652 None ,
620653 )
621654 # Case 2: Matches if dataset type is "grain" and the data iterator is not a
622- # PlaceHolderDataIterator or RemoteIteratorWrapper and a specific checkpoint file exists for the iterator
655+ # PlaceHolderDataIterator and the `iter` checkpoint directory exists for the given step
623656 case (
624657 checkpoint_manager ,
625658 dataset_type ,
@@ -628,7 +661,6 @@ def map_to_pspec(data):
628661 dataset_type == "grain"
629662 and data_iterator
630663 and not isinstance (data_iterator , PlaceHolderDataIterator )
631- and not _is_remote_iterator (data_iterator )
632664 and (checkpoint_manager .directory / str (step ) / "iter" ).exists ()
633665 ):
634666 return _restore_grain_iterator (
@@ -790,22 +822,24 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
790822 )
791823 save_args_composite = {"items" : checkpoint_args }
792824
793- if (
794- config
795- and config .dataset_type == "grain"
796- and not isinstance (data_iterator , PlaceHolderDataIterator )
797- and not _is_remote_iterator (data_iterator )
798- ):
825+ if config and config .dataset_type == "grain" and not isinstance (data_iterator , PlaceHolderDataIterator ):
799826 if not isinstance (data_iterator , list ):
800827 data_iterator = [data_iterator ]
801- grain_iters_to_save = []
802- process_count_total = jax .process_count () * len (data_iterator )
803- if config .expansion_factor_real_data > 1 :
804- process_count_total = process_count_total // config .expansion_factor_real_data
805- for i , data_iter in enumerate (data_iterator ):
806- process_index = jax .process_index () + i * jax .process_count ()
807- grain_iters_to_save .append ((data_iter .local_iterator , process_index , process_count_total ))
808- save_args_composite ["iter" ] = GrainCheckpointSave (item = grain_iters_to_save )
828+ if isinstance (data_iterator [0 ], RemoteIteratorWrapper ):
829+ # Pass the wrapper directly; GrainCheckpointHandler will call save_state with the step
830+ save_args_composite ["iter" ] = GrainCheckpointSave (item = data_iterator [0 ])
831+ elif isinstance (data_iterator [0 ].local_iterator , ElasticIterator ):
832+ # ElasticIterator checkpoints a single global scalar shared by all shards.
833+ save_args_composite ["iter" ] = GrainCheckpointSave (item = data_iterator [0 ].local_iterator )
834+ else :
835+ grain_iters_to_save = []
836+ process_count_total = jax .process_count () * len (data_iterator )
837+ if config .expansion_factor_real_data > 1 :
838+ process_count_total = process_count_total // config .expansion_factor_real_data
839+ for i , data_iter in enumerate (data_iterator ):
840+ process_index = jax .process_index () + i * jax .process_count ()
841+ grain_iters_to_save .append ((data_iter .local_iterator , process_index , process_count_total ))
842+ save_args_composite ["iter" ] = GrainCheckpointSave (item = grain_iters_to_save )
809843
810844 match (checkpoint_manager , config , data_iterator ):
811845 case (checkpoint_manager , _, _) if isinstance (
0 commit comments