Skip to content

Commit 947c587

Browse files
committed
support elastic data checkpoint
1 parent 3616eb3 commit 947c587

8 files changed

Lines changed: 364 additions & 335 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
import grain
4545
from grain.python import PyGrainCheckpointHandler
46+
from grain.experimental import ElasticIterator
4647

4748
CheckpointManager = ocp.CheckpointManager
4849
CheckpointManagerOptions = ocp.CheckpointManagerOptions
@@ -68,6 +69,22 @@ def save(
6869
"""Saves the given iterator to the checkpoint in `directory`."""
6970
item = item or args.item # pytype:disable=attribute-error
7071

72+
# RemoteIteratorWrapper handles checkpointing via colocated python
73+
if isinstance(item, RemoteIteratorWrapper):
74+
step = int(directory.parent.name)
75+
item.save_state(step)
76+
return
77+
78+
# ElasticIterator state is a single global scalar shared by all shards,
79+
# so we write one fixed `process_0.json` from process 0 only. This file
80+
# layout survives changes in `jax.process_count()`.
81+
if isinstance(item, ElasticIterator):
82+
if jax.process_index() == 0:
83+
directory.mkdir(parents=True, exist_ok=True)
84+
filename = directory / "process_0.json"
85+
filename.write_text(json.dumps(item.get_state(), indent=4))
86+
return
87+
7188
def save_single_process(item, process_index, process_count):
7289
filename = directory / f"process_{process_index}-of-{process_count}.json"
7390
if isinstance(item, grain.DatasetIterator):
@@ -94,6 +111,20 @@ def restore(
94111
process_index = getattr(args, "process_index", None)
95112
process_count = getattr(args, "process_count", None)
96113

114+
# RemoteIteratorWrapper handles checkpointing via colocated python
115+
if isinstance(item, RemoteIteratorWrapper):
116+
step = int(directory.parent.name)
117+
item.restore_state(step)
118+
return item
119+
120+
# ElasticIterator: every process reads the same shared `process_0.json`.
121+
if isinstance(item, ElasticIterator):
122+
filename = directory / "process_0.json"
123+
if not filename.exists():
124+
raise ValueError(f"File {filename} does not exist.")
125+
item.set_state(json.loads(filename.read_text()))
126+
return item
127+
97128
def restore_single_process(item, process_index, process_count):
98129
filename = directory / f"process_{process_index}-of-{process_count}.json"
99130
if not filename.exists():
@@ -131,15 +162,6 @@ class GrainCheckpointRestore(ocp.args.CheckpointArgs):
131162
process_count: Optional[int] = None
132163

133164

134-
def _is_remote_iterator(data_iterator):
135-
"""Check if data_iterator is a RemoteIteratorWrapper or contains RemoteIteratorWrapper instances."""
136-
if isinstance(data_iterator, RemoteIteratorWrapper):
137-
return True
138-
if isinstance(data_iterator, list):
139-
return any(isinstance(item, RemoteIteratorWrapper) for item in data_iterator)
140-
return False
141-
142-
143165
def _load_full_state_from_path(
144166
path,
145167
abstract_unboxed_pre_state,
@@ -481,6 +503,17 @@ def _restore_grain_iterator(
481503
This function dispatches to the correct restore strategy based on
482504
the number of stored checkpoint files vs. current JAX processes.
483505
"""
506+
if isinstance(data_iterator, RemoteIteratorWrapper):
507+
grain_restore_args = GrainCheckpointRestore(item=data_iterator)
508+
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
509+
return (restored_state, None)
510+
511+
# ElasticIterator: one shared `process_0.json` regardless of shard count.
512+
if not isinstance(data_iterator, list) and isinstance(data_iterator.local_iterator, ElasticIterator):
513+
grain_restore_args = GrainCheckpointRestore(item=data_iterator.local_iterator)
514+
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
515+
return (restored_state, None)
516+
484517
directory = checkpoint_manager.directory / str(step) / "iter"
485518
process_count_jax = jax.process_count()
486519

@@ -619,7 +652,7 @@ def map_to_pspec(data):
619652
None,
620653
)
621654
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
622-
# PlaceHolderDataIterator or RemoteIteratorWrapper and a specific checkpoint file exists for the iterator
655+
# PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
623656
case (
624657
checkpoint_manager,
625658
dataset_type,
@@ -628,7 +661,6 @@ def map_to_pspec(data):
628661
dataset_type == "grain"
629662
and data_iterator
630663
and not isinstance(data_iterator, PlaceHolderDataIterator)
631-
and not _is_remote_iterator(data_iterator)
632664
and (checkpoint_manager.directory / str(step) / "iter").exists()
633665
):
634666
return _restore_grain_iterator(
@@ -790,22 +822,24 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
790822
)
791823
save_args_composite = {"items": checkpoint_args}
792824

793-
if (
794-
config
795-
and config.dataset_type == "grain"
796-
and not isinstance(data_iterator, PlaceHolderDataIterator)
797-
and not _is_remote_iterator(data_iterator)
798-
):
825+
if config and config.dataset_type == "grain" and not isinstance(data_iterator, PlaceHolderDataIterator):
799826
if not isinstance(data_iterator, list):
800827
data_iterator = [data_iterator]
801-
grain_iters_to_save = []
802-
process_count_total = jax.process_count() * len(data_iterator)
803-
if config.expansion_factor_real_data > 1:
804-
process_count_total = process_count_total // config.expansion_factor_real_data
805-
for i, data_iter in enumerate(data_iterator):
806-
process_index = jax.process_index() + i * jax.process_count()
807-
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
808-
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
828+
if isinstance(data_iterator[0], RemoteIteratorWrapper):
829+
# Pass the wrapper directly; GrainCheckpointHandler will call save_state with the step
830+
save_args_composite["iter"] = GrainCheckpointSave(item=data_iterator[0])
831+
elif isinstance(data_iterator[0].local_iterator, ElasticIterator):
832+
# ElasticIterator checkpoints a single global scalar shared by all shards.
833+
save_args_composite["iter"] = GrainCheckpointSave(item=data_iterator[0].local_iterator)
834+
else:
835+
grain_iters_to_save = []
836+
process_count_total = jax.process_count() * len(data_iterator)
837+
if config.expansion_factor_real_data > 1:
838+
process_count_total = process_count_total // config.expansion_factor_real_data
839+
for i, data_iter in enumerate(data_iterator):
840+
process_index = jax.process_index() + i * jax.process_count()
841+
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
842+
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)
809843

810844
match (checkpoint_manager, config, data_iterator):
811845
case (checkpoint_manager, _, _) if isinstance(

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ grain_num_threads_eval: 16
713713
grain_prefetch_buffer_size_eval: 500
714714
grain_data_source_max_workers: 16 # Max workers for ThreadPoolExecutor when mixing multiple Grain data sources.
715715
grain_shuffle_buffer_size: 100 # shuffle buffer when using sequential access formats such as Parquet, TFRecord.
716+
grain_use_elastic_iterator: False # For elastic training, set this to True and packing=False
716717
# for using pathways
717718
colocated_python_data_input: False # experimental feature, under testing
718719

src/maxtext/configs/post_train/dpo.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
base_config: "base.yml"
22

33
use_dpo: true
4+
packing: false
45
train_data_columns: ['chosen', 'rejected']
56
eval_data_columns: ['chosen', 'rejected']
67
base_output_directory: 'gs://maxtext-external/logs'

src/maxtext/configs/types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,13 @@ class GrainDataset(BaseModel):
10961096
grain_file_type: str = Field(
10971097
"arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
10981098
)
1099+
grain_use_elastic_iterator: bool = Field(
1100+
False,
1101+
description=(
1102+
"Whether to use grain's `ElasticIterator` for data loading. When True, the iterator "
1103+
"checkpoint can be restored after a change in the number of data-loading shards."
1104+
),
1105+
)
10991106
grain_worker_count: int = Field(1, description="Number of workers for Grain data loading.")
11001107
grain_per_worker_buffer_size: int = Field(1, description="Per-worker buffer size for Grain train data loading.")
11011108
grain_worker_count_eval: int = Field(1, description="Number of workers for Grain eval data loading.")
@@ -2493,6 +2500,31 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24932500
raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
24942501
if self.elastic_enabled and not self.enable_single_controller:
24952502
raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).")
2503+
if self.grain_use_elastic_iterator and self.grain_file_type != "arrayrecord":
2504+
raise ValueError(
2505+
"`grain_use_elastic_iterator=True` only supports `grain_file_type=arrayrecord`. "
2506+
"tfrecord and parquet pipelines use `InterleaveIterDataset` (a many-to-one "
2507+
"IterDataset transform), which `ElasticIterator` forbids. "
2508+
f"Got grain_file_type={self.grain_file_type}."
2509+
)
2510+
if self.grain_use_elastic_iterator and self.packing:
2511+
raise ValueError("`grain_use_elastic_iterator=True` requires `packing=False`.")
2512+
if self.use_dpo and self.packing:
2513+
raise ValueError("DPO does not support packing. Set `packing=False`.")
2514+
if self.grain_use_elastic_iterator and not self.use_truncation:
2515+
raise ValueError(
2516+
"`grain_use_elastic_iterator=True` requires `use_truncation=True`. "
2517+
"`TokenizeAndChunk` uses `apply`, which produces a many-to-one "
2518+
"IterDataset transform that `ElasticIterator` forbids."
2519+
)
2520+
if self.grain_use_elastic_iterator and (
2521+
self.grain_train_mixture_config_path or ";" in (self.grain_train_files or "")
2522+
):
2523+
raise ValueError(
2524+
"`grain_use_elastic_iterator=True` does not support dataset mixtures. "
2525+
"Set `grain_train_mixture_config_path` to empty and use a single "
2526+
"`grain_train_files` pattern (no ';' separator)."
2527+
)
24962528
if (self.load_parameters_path or self.load_full_state_path) and not self.enable_checkpointing:
24972529
raise ValueError("You must set enable_checkpointing=True to load a checkpoint.")
24982530
if self.enable_multi_tier_checkpointing:

src/maxtext/input_pipeline/data_processing_utils.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,16 @@ def get_local_batch_size(config):
7474
return batch_size
7575

7676

77-
def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model):
78-
"""Packs or pads the dataset according to config and batches it."""
77+
def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenizer_model, shift=True):
78+
"""Packs or pads the dataset, batches it, and optionally shifts tokens for next-token prediction.
79+
80+
When `config.grain_use_elastic_iterator` is True, batching is skipped
81+
(ElasticIterator performs it internally) and, if `shift=True`, the shift is
82+
applied pre-batch on axis 0, which is equivalent to a post-batch axis=1 shift.
83+
84+
`shift` should be False for pipelines that don't do next-token prediction
85+
(e.g. DPO, which scores full sequences).
86+
"""
7987
if config.packing:
8088
length_struct = {col: config.max_target_length for col in data_columns}
8189
max_segments = config.max_segments_per_seq
@@ -113,23 +121,24 @@ def format_and_batch(dataset, config, batch_size, pad_id, data_columns, tokenize
113121
else:
114122
dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
115123

124+
if config.grain_use_elastic_iterator:
125+
# ElasticIterator batches internally, so return the pre-batch dataset.
126+
if shift:
127+
dataset = dataset.map(input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=0))
128+
return dataset
129+
116130
batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
117131
dataset = dataset.batch(batch_size, batch_fn=batch_fn)
132+
if shift:
133+
dataset = dataset.map(input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1))
118134
return dataset
119135

120136

121-
def shift_dataset(dataset, pad_id):
122-
"""Shift tokens to create inputs and targets for standard next-token prediction."""
123-
return dataset.map(
124-
input_pipeline_utils.ShiftData(
125-
ignored_ids=[pad_id],
126-
axis=1,
127-
)
128-
)
129-
130-
131137
def apply_multiprocessing_and_prefetch(dataset, config, grain_worker_count, grain_per_worker_buffer_size):
132138
"""Applies multiprocessing and prefetching configurations to the dataset."""
139+
if config.grain_use_elastic_iterator:
140+
# ElasticIterator applies multiprocessing itself.
141+
return dataset
133142
multiprocessing_options = (
134143
pick_performance_config(
135144
ds=dataset,

0 commit comments

Comments
 (0)