Skip to content

Commit 6739e05

Browse files
lukebaumann authored and Google-ML-Automation committed
Fix elastic data loading process count divisibility error in MaxText.
PiperOrigin-RevId: 927504355
1 parent b2153a3 commit 6739e05

3 files changed

Lines changed: 473 additions & 7 deletions

File tree

src/maxtext/input_pipeline/grain_data_processing.py

Lines changed: 40 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -397,18 +397,48 @@ def _get_pipeline_fn(config):
397397
return pretrain_preprocessing_pipeline
398398

399399

400-
def _make_elastic_iterator(dataset, config, preprocessing_fn, shard_index=None, shard_count=None, mp_opts=None):
400+
def _make_elastic_iterator(
401+
dataset,
402+
config,
403+
preprocessing_fn,
404+
shard_index=None,
405+
shard_count=None,
406+
process_indices=None,
407+
mp_opts=None,
408+
):
401409
"""Applies preprocessing_fn then wraps the result with ElasticIterator.
402410
403-
When shard_index/shard_count are None, defaults to jax.process_index()/jax.process_count().
411+
When shard_index/shard_count are None, defaults to
412+
jax.process_index()/jax.process_count() (or calculated from
413+
process_indices if provided).
414+
415+
Args:
416+
dataset: The input dataset.
417+
config: The hyperparameter configuration.
418+
preprocessing_fn: The function to apply before wrapping.
419+
shard_index: The shard index. Defaults to None.
420+
shard_count: The shard count. Defaults to None.
421+
process_indices: The active process indices. Defaults to None.
422+
mp_opts: Multiprocessing options. Defaults to None.
404423
"""
405424
ds = preprocessing_fn(dataset=dataset)
425+
if shard_index is None:
426+
if process_indices is not None:
427+
shard_index = process_indices.index(jax.process_index())
428+
else:
429+
shard_index = jax.process_index()
430+
if shard_count is None:
431+
if process_indices is not None:
432+
shard_count = len(process_indices)
433+
else:
434+
shard_count = jax.process_count()
435+
406436
return ElasticIterator(
407437
ds,
408438
global_batch_size=config.global_batch_size_to_load,
409439
shard_options=grain.ShardOptions(
410-
shard_index=shard_index if shard_index is not None else jax.process_index(),
411-
shard_count=shard_count if shard_count is not None else jax.process_count(),
440+
shard_index=shard_index,
441+
shard_count=shard_count,
412442
),
413443
read_options=grain.ReadOptions(
414444
num_threads=config.grain_num_threads,
@@ -463,7 +493,12 @@ def make_grain_train_iterator(
463493
# pass to MultiHostDataLoadIterator
464494
if config.colocated_python_data_input:
465495
if config.grain_use_elastic_iterator:
466-
preprocessing_fn = functools.partial(_make_elastic_iterator, config=config, preprocessing_fn=preprocessing_fn)
496+
preprocessing_fn = functools.partial(
497+
_make_elastic_iterator,
498+
config=config,
499+
preprocessing_fn=preprocessing_fn,
500+
process_indices=process_indices,
501+
)
467502

468503
global_shape = (config.global_batch_size_to_load, config.max_target_length)
469504
return multihost_dataloading.RemoteIteratorWrapper(

src/maxtext/input_pipeline/input_pipeline_interface.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,8 +100,12 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
100100
mesh,
101101
)
102102
output_train_iterator = create_process_specific_iterator(config, mesh, process_indices_train, train_iterator)
103+
active_process_count = len(set(d.process_index for d in mesh.devices.flat))
103104
if config.expansion_factor_real_data > 1: # assert number of hosts loading real data
104-
assert len(process_indices_train) == jax.process_count() // config.expansion_factor_real_data
105+
assert (
106+
len(process_indices_train)
107+
== active_process_count // config.expansion_factor_real_data
108+
)
105109

106110
# Generate output eval iterator
107111
output_eval_iterator = None
@@ -115,6 +119,10 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
115119
)
116120

117121
if config.expansion_factor_real_data > 1:
118-
assert len(process_indices_eval) == jax.process_count() // config.expansion_factor_real_data
122+
assert (
123+
len(process_indices_eval)
124+
== active_process_count // config.expansion_factor_real_data
125+
)
126+
119127
output_eval_iterator = create_process_specific_iterator(config, mesh, process_indices_eval, eval_iterator)
120128
return output_train_iterator, output_eval_iterator

0 commit comments

Comments
 (0)