Skip to content

Commit a8ecbbf

Browse files
Merge pull request #3673 from AI-Hypercomputer:aireen/elastic_data
PiperOrigin-RevId: 907833470
2 parents 58ffd43 + e48f295 commit a8ecbbf

8 files changed

Lines changed: 404 additions & 306 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import jax
2525
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
2626
from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator
27-
from maxtext.input_pipeline.multihost_dataloading import RemoteIterator
27+
from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper
2828
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2929
from maxtext.utils import exceptions
3030
from maxtext.utils import max_logging
@@ -44,6 +44,7 @@
4444

4545
import grain
4646
from grain.python import PyGrainCheckpointHandler
47+
from grain.experimental import ElasticIterator
4748

4849
CheckpointManager = ocp.CheckpointManager
4950
CheckpointManagerOptions = ocp.CheckpointManagerOptions
@@ -69,6 +70,22 @@ def save(
6970
"""Saves the given iterator to the checkpoint in `directory`."""
7071
item = item or args.item # pytype:disable=attribute-error
7172

73+
# RemoteIteratorWrapper handles checkpointing via colocated python
74+
if isinstance(item, RemoteIteratorWrapper):
75+
step = int(directory.parent.name)
76+
item.save_state(step)
77+
return
78+
79+
# ElasticIterator state is a single global scalar shared by all shards,
80+
# so we write one fixed `process_0.json` from process 0 only. This file
81+
# layout survives changes in `jax.process_count()`.
82+
if isinstance(item, ElasticIterator):
83+
if jax.process_index() == 0:
84+
directory.mkdir(parents=True, exist_ok=True)
85+
filename = directory / "process_0.json"
86+
filename.write_text(json.dumps(item.get_state(), indent=4))
87+
return
88+
7289
def save_single_process(item, process_index, process_count):
7390
filename = directory / f"process_{process_index}-of-{process_count}.json"
7491
if isinstance(item, grain.DatasetIterator):
@@ -95,6 +112,21 @@ def restore(
95112
process_index = getattr(args, "process_index", None)
96113
process_count = getattr(args, "process_count", None)
97114

115+
# In Pathways + colocated_python environment, RemoteIteratorWrapper handles checkpointing
116+
if isinstance(item, RemoteIteratorWrapper):
117+
step = int(directory.parent.name)
118+
item.restore_state(step)
119+
return item
120+
121+
# McJax and Pathways through controller cases
122+
# ElasticIterator: every process reads the same shared `process_0.json`.
123+
if isinstance(item, ElasticIterator):
124+
filename = directory / "process_0.json"
125+
if not filename.exists():
126+
raise ValueError(f"File {filename} does not exist.")
127+
item.set_state(json.loads(filename.read_text()))
128+
return item
129+
98130
def restore_single_process(item, process_index, process_count):
99131
filename = directory / f"process_{process_index}-of-{process_count}.json"
100132
if not filename.exists():
@@ -132,15 +164,6 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
132164
process_count: Optional[int] = None
133165

134166

135-
def _is_remote_iterator(data_iterator):
136-
"""Check if data_iterator is a RemoteIterator or contains RemoteIterator instances."""
137-
if isinstance(data_iterator, RemoteIterator):
138-
return True
139-
if isinstance(data_iterator, list):
140-
return any(isinstance(item, RemoteIterator) for item in data_iterator)
141-
return False
142-
143-
144167
def _load_full_state_from_path(
145168
path,
146169
abstract_unboxed_pre_state,
@@ -482,6 +505,17 @@ def _restore_grain_iterator(
482505
This function dispatches to the correct restore strategy based on
483506
the number of stored checkpoint files vs. current JAX processes.
484507
"""
508+
if isinstance(data_iterator, RemoteIteratorWrapper):
509+
grain_restore_args = GrainCheckpointRestore(item=data_iterator)
510+
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
511+
return (restored_state, None)
512+
513+
# ElasticIterator: one shared `process_0.json` regardless of shard count.
514+
if not isinstance(data_iterator, list) and isinstance(data_iterator.local_iterator, ElasticIterator):
515+
grain_restore_args = GrainCheckpointRestore(item=data_iterator.local_iterator)
516+
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
517+
return (restored_state, None)
518+
485519
directory = checkpoint_manager.directory / str(step) / "iter"
486520
process_count_jax = jax.process_count()
487521

@@ -625,7 +659,7 @@ def map_to_pspec(data):
625659
None,
626660
)
627661
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
628-
# PlaceHolderDataIterator or RemoteIterator and a specific checkpoint file exists for the iterator
662+
# PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
629663
case (
630664
checkpoint_manager,
631665
dataset_type,
@@ -634,7 +668,6 @@ def map_to_pspec(data):
634668
dataset_type == "grain"
635669
and data_iterator
636670
and not isinstance(data_iterator, PlaceHolderDataIterator)
637-
and not _is_remote_iterator(data_iterator)
638671
and (checkpoint_manager.directory / str(step) / "iter").exists()
639672
):
640673
return _restore_grain_iterator(
@@ -810,22 +843,24 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
810843
)
811844
save_args_composite = {"items": checkpoint_args}
812845

813-
if (
814-
config
815-
and config.dataset_type == "grain"
816-
and not isinstance(data_iterator, PlaceHolderDataIterator)
817-
and not _is_remote_iterator(data_iterator)
818-
):
819-
if not isinstance(data_iterator, list):
820-
data_iterator = [data_iterator]
821-
grain_iters_to_save = []
822-
process_count_total = jax.process_count() * len(data_iterator)
823-
if config.expansion_factor_real_data > 1:
824-
process_count_total = process_count_total // config.expansion_factor_real_data
825-
for i, data_iter in enumerate(data_iterator):
826-
process_index = jax.process_index() + i * jax.process_count()
827-
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
828-
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
846+
if config and config.dataset_type == "grain" and not isinstance(data_iterator, PlaceHolderDataIterator):
847+
if isinstance(data_iterator, RemoteIteratorWrapper):
848+
# Pass the wrapper directly; GrainCheckpointHandler will call save_state with the step
849+
save_args_composite["iter"] = GrainCheckpointSave(item=data_iterator)
850+
elif not isinstance(data_iterator, list) and isinstance(data_iterator.local_iterator, ElasticIterator):
851+
# ElasticIterator checkpoints a single global scalar shared by all shards.
852+
save_args_composite["iter"] = GrainCheckpointSave(item=data_iterator.local_iterator)
853+
else:
854+
if not isinstance(data_iterator, list):
855+
data_iterator = [data_iterator]
856+
grain_iters_to_save = []
857+
process_count_total = jax.process_count() * len(data_iterator)
858+
if config.expansion_factor_real_data > 1:
859+
process_count_total = process_count_total // config.expansion_factor_real_data
860+
for i, data_iter in enumerate(data_iterator):
861+
process_index = jax.process_index() + i * jax.process_count()
862+
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
863+
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
829864

830865
match (checkpoint_manager, config, data_iterator):
831866
case (checkpoint_manager, _, _) if isinstance(

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,7 @@ grain_num_threads_eval: 16
729729
grain_prefetch_buffer_size_eval: 500
730730
grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.
731731
grain_shuffle_buffer_size: 100 # shuffle buffer when using sequential access formats such as Parquet, TFRecord.
732+
grain_use_elastic_iterator: False # For elastic training, set this to True and set packing=False
732733
# for using pathways
733734
colocated_python_data_input: False # experimental feature, under testing
734735

src/maxtext/configs/post_train/dpo.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
base_config: "base.yml"
22

33
use_dpo: true
4+
packing: false
45
train_data_columns: ['chosen', 'rejected']
56
eval_data_columns: ['chosen', 'rejected']
67
base_output_directory: 'gs://maxtext-external/logs'

src/maxtext/configs/types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,13 @@ class GrainDataset(BaseModel):
11241124
grain_file_type: str = Field(
11251125
"arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
11261126
)
1127+
grain_use_elastic_iterator: bool = Field(
1128+
False,
1129+
description=(
1130+
"Whether to use grain's `ElasticIterator` for data loading. When True, the iterator "
1131+
"checkpoint can be restored after a change in the number of data-loading shards."
1132+
),
1133+
)
11271134
grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
11281135
grain_per_worker_buffer_size: int = Field(1, description="Per-worker buffer size for Grain train data loading.")
11291136
grain_worker_count_eval: int = Field(1, description="Number of workers for Grain eval data loading.")
@@ -2630,6 +2637,31 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
26302637
raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
26312638
if self.elastic_enabled and not self.enable_single_controller:
26322639
raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).")
2640+
if self.grain_use_elastic_iterator and self.grain_file_type != "arrayrecord":
2641+
raise ValueError(
2642+
"`grain_use_elastic_iterator=True` only supports `grain_file_type=arrayrecord`. "
2643+
"tfrecord and parquet pipelines use `InterleaveIterDataset` (a many-to-one "
2644+
"IterDataset transform), which `ElasticIterator` forbids. "
2645+
f"Got grain_file_type={self.grain_file_type}."
2646+
)
2647+
if self.grain_use_elastic_iterator and self.packing:
2648+
raise ValueError("`grain_use_elastic_iterator=True` requires `packing=False`.")
2649+
if self.use_dpo and self.packing:
2650+
raise ValueError("DPO does not support packing. Set `packing=False`.")
2651+
if self.grain_use_elastic_iterator and not self.use_truncation:
2652+
raise ValueError(
2653+
"`grain_use_elastic_iterator=True` requires `use_truncation=True`. "
2654+
"`TokenizeAndChunk` uses `apply`, which produces a many-to-one "
2655+
"IterDataset transform that `ElasticIterator` forbids."
2656+
)
2657+
if self.grain_use_elastic_iterator and (
2658+
self.grain_train_mixture_config_path or ";" in (self.grain_train_files or "")
2659+
):
2660+
raise ValueError(
2661+
"`grain_use_elastic_iterator=True` does not support dataset mixtures. "
2662+
"Set `grain_train_mixture_config_path` to empty and use a single "
2663+
"`grain_train_files` pattern (no ';' separator)."
2664+
)
26332665
if (self.load_parameters_path or self.load_full_state_path) and not self.enable_checkpointing:
26342666
raise ValueError("You must set enable_checkpointing=True to load a checkpoint.")
26352667
if self.enable_multi_tier_checkpointing:

src/maxtext/input_pipeline/data_processing_utils.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,16 @@ def get_local_batch_size(config):
7878
return batch_size
7979

8080

81-
def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model):
82-
"""Packs or pads the dataset according to config and batches it."""
81+
def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model, shift=True):
82+
"""Packs or pads the dataset, batches it, and optionally shifts tokens for next-token prediction.
83+
84+
When `config.grain_use_elastic_iterator` is True, batching is skipped
85+
(ElasticIterator performs it internally) and, if `shift=True`, the shift is
86+
applied pre-batch on axis 0, which is equivalent to a post-batch axis=1 shift.
87+
88+
`shift` should be False for pipelines that don't do next-token prediction
89+
(e.g. DPO, which scores full sequences).
90+
"""
8391
if config.packing:
8492
length_struct = {col: config.max_target_length for col in data_columns}
8593
max_segments = config.max_segments_per_seq
@@ -117,23 +125,24 @@ def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenize
117125
else:
118126
dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
119127

128+
if config.grain_use_elastic_iterator:
129+
# ElasticIterator batches internally, so return the pre-batch dataset.
130+
if shift:
131+
dataset = dataset.map(input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=0))
132+
return dataset
133+
120134
batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
121135
dataset = dataset.batch(batch_size, batch_fn=batch_fn)
136+
if shift:
137+
dataset = dataset.map(input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
122138
return dataset
123139

124140

125-
def shift_dataset(dataset, pad_id):
126-
"""Shift tokens to create inputs and targets for standard next-token prediction."""
127-
return dataset.map(
128-
input_pipeline_utils.ShiftData(
129-
ignored_ids=[pad_id],
130-
axis=1,
131-
)
132-
)
133-
134-
135141
def apply_multiprocessing_and_prefetch(dataset, config, grain_worker_count, grain_per_worker_buffer_size):
136142
"""Applies multiprocessing and prefetching configurations to the dataset."""
143+
if config.grain_use_elastic_iterator:
144+
# ElasticIterator applies multiprocessing itself.
145+
return dataset
137146
multiprocessing_options = (
138147
pick_performance_config(
139148
ds=dataset,

0 commit comments

Comments
 (0)