2424import jax
2525from maxtext .utils .globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
2626from maxtext .input_pipeline .multihost_dataloading import MultiHostDataLoadIterator
27- from maxtext .input_pipeline .multihost_dataloading import RemoteIterator
27+ from maxtext .input_pipeline .multihost_dataloading import RemoteIteratorWrapper
2828from maxtext .input_pipeline .synthetic_data_processing import PlaceHolderDataIterator
2929from maxtext .utils import exceptions
3030from maxtext .utils import max_logging
4444
4545import grain
4646from grain .python import PyGrainCheckpointHandler
47+ from grain .experimental import ElasticIterator
4748
4849CheckpointManager = ocp .CheckpointManager
4950CheckpointManagerOptions = ocp .CheckpointManagerOptions
@@ -69,6 +70,22 @@ def save(
6970 """Saves the given iterator to the checkpoint in `directory`."""
7071 item = item or args .item # pytype:disable=attribute-error
7172
73+ # RemoteIteratorWrapper handles checkpointing via colocated python
74+ if isinstance (item , RemoteIteratorWrapper ):
75+ step = int (directory .parent .name )
76+ item .save_state (step )
77+ return
78+
79+ # ElasticIterator state is a single global scalar shared by all shards,
80+ # so we write one fixed `process_0.json` from process 0 only. This file
81+ # layout survives changes in `jax.process_count()`.
82+ if isinstance (item , ElasticIterator ):
83+ if jax .process_index () == 0 :
84+ directory .mkdir (parents = True , exist_ok = True )
85+ filename = directory / "process_0.json"
86+ filename .write_text (json .dumps (item .get_state (), indent = 4 ))
87+ return
88+
7289 def save_single_process (item , process_index , process_count ):
7390 filename = directory / f"process_{ process_index } -of-{ process_count } .json"
7491 if isinstance (item , grain .DatasetIterator ):
@@ -95,6 +112,21 @@ def restore(
95112 process_index = getattr (args , "process_index" , None )
96113 process_count = getattr (args , "process_count" , None )
97114
115+ # In Pathways + colocated_python environment, RemoteIteratorWrapper handles checkpointing
116+ if isinstance (item , RemoteIteratorWrapper ):
117+ step = int (directory .parent .name )
118+ item .restore_state (step )
119+ return item
120+
121+ # McJax and Pathways through controller cases
122+ # ElasticIterator: every process reads the same shared `process_0.json`.
123+ if isinstance (item , ElasticIterator ):
124+ filename = directory / "process_0.json"
125+ if not filename .exists ():
126+ raise ValueError (f"File { filename } does not exist." )
127+ item .set_state (json .loads (filename .read_text ()))
128+ return item
129+
98130 def restore_single_process (item , process_index , process_count ):
99131 filename = directory / f"process_{ process_index } -of-{ process_count } .json"
100132 if not filename .exists ():
@@ -132,15 +164,6 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
132164 process_count : Optional [int ] = None
133165
134166
135- def _is_remote_iterator (data_iterator ):
136- """Check if data_iterator is a RemoteIterator or contains RemoteIterator instances."""
137- if isinstance (data_iterator , RemoteIterator ):
138- return True
139- if isinstance (data_iterator , list ):
140- return any (isinstance (item , RemoteIterator ) for item in data_iterator )
141- return False
142-
143-
144167def _load_full_state_from_path (
145168 path ,
146169 abstract_unboxed_pre_state ,
@@ -482,6 +505,17 @@ def _restore_grain_iterator(
482505 This function dispatches to the correct restore strategy based on
483506 the number of stored checkpoint files vs. current JAX processes.
484507 """
508+ if isinstance (data_iterator , RemoteIteratorWrapper ):
509+ grain_restore_args = GrainCheckpointRestore (item = data_iterator )
510+ restored_state = checkpoint_manager .restore (step , args = Composite (items = checkpoint_args , iter = grain_restore_args ))
511+ return (restored_state , None )
512+
513+ # ElasticIterator: one shared `process_0.json` regardless of shard count.
514+ if not isinstance (data_iterator , list ) and isinstance (data_iterator .local_iterator , ElasticIterator ):
515+ grain_restore_args = GrainCheckpointRestore (item = data_iterator .local_iterator )
516+ restored_state = checkpoint_manager .restore (step , args = Composite (items = checkpoint_args , iter = grain_restore_args ))
517+ return (restored_state , None )
518+
485519 directory = checkpoint_manager .directory / str (step ) / "iter"
486520 process_count_jax = jax .process_count ()
487521
@@ -625,7 +659,7 @@ def map_to_pspec(data):
625659 None ,
626660 )
627661 # Case 2: Matches if dataset type is "grain" and the data iterator is not a
628- # PlaceHolderDataIterator or RemoteIterator and a specific checkpoint file exists for the iterator
662+ # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
629663 case (
630664 checkpoint_manager ,
631665 dataset_type ,
@@ -634,7 +668,6 @@ def map_to_pspec(data):
634668 dataset_type == "grain"
635669 and data_iterator
636670 and not isinstance (data_iterator , PlaceHolderDataIterator )
637- and not _is_remote_iterator (data_iterator )
638671 and (checkpoint_manager .directory / str (step ) / "iter" ).exists ()
639672 ):
640673 return _restore_grain_iterator (
@@ -810,22 +843,24 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
810843 )
811844 save_args_composite = {"items" : checkpoint_args }
812845
813- if (
814- config
815- and config .dataset_type == "grain"
816- and not isinstance (data_iterator , PlaceHolderDataIterator )
817- and not _is_remote_iterator (data_iterator )
818- ):
819- if not isinstance (data_iterator , list ):
820- data_iterator = [data_iterator ]
821- grain_iters_to_save = []
822- process_count_total = jax .process_count () * len (data_iterator )
823- if config .expansion_factor_real_data > 1 :
824- process_count_total = process_count_total // config .expansion_factor_real_data
825- for i , data_iter in enumerate (data_iterator ):
826- process_index = jax .process_index () + i * jax .process_count ()
827- grain_iters_to_save .append ((data_iter .local_iterator , process_index , process_count_total ))
828- save_args_composite ["iter" ] = GrainCheckpointSave (item = grain_iters_to_save )
846+ if config and config .dataset_type == "grain" and not isinstance (data_iterator , PlaceHolderDataIterator ):
847+ if isinstance (data_iterator , RemoteIteratorWrapper ):
848+ # Pass the wrapper directly; GrainCheckpointHandler will call save_state with the step
849+ save_args_composite ["iter" ] = GrainCheckpointSave (item = data_iterator )
850+ elif not isinstance (data_iterator , list ) and isinstance (data_iterator .local_iterator , ElasticIterator ):
851+ # ElasticIterator checkpoints a single global scalar shared by all shards.
852+ save_args_composite ["iter" ] = GrainCheckpointSave (item = data_iterator .local_iterator )
853+ else :
854+ if not isinstance (data_iterator , list ):
855+ data_iterator = [data_iterator ]
856+ grain_iters_to_save = []
857+ process_count_total = jax .process_count () * len (data_iterator )
858+ if config .expansion_factor_real_data > 1 :
859+ process_count_total = process_count_total // config .expansion_factor_real_data
860+ for i , data_iter in enumerate (data_iterator ):
861+ process_index = jax .process_index () + i * jax .process_count ()
862+ grain_iters_to_save .append ((data_iter .local_iterator , process_index , process_count_total ))
863+ save_args_composite ["iter" ] = GrainCheckpointSave (item = grain_iters_to_save )
829864
830865 match (checkpoint_manager , config , data_iterator ):
831866 case (checkpoint_manager , _, _) if isinstance (
0 commit comments