Skip to content

Commit 0b1a14c

Browse files
lukebaumann authored and Google-ML-Automation committed
Fix elastic data loading process count divisibility error in MaxText.
PiperOrigin-RevId: 927504355
1 parent b2153a3 commit 0b1a14c

3 files changed

Lines changed: 468 additions & 6 deletions

File tree

src/maxtext/input_pipeline/grain_data_processing.py

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -397,18 +397,37 @@ def _get_pipeline_fn(config):
397397
return pretrain_preprocessing_pipeline
398398

399399

400-
def _make_elastic_iterator(dataset, config, preprocessing_fn, shard_index=None, shard_count=None, mp_opts=None):
400+
def _make_elastic_iterator(
401+
dataset,
402+
config,
403+
preprocessing_fn,
404+
shard_index=None,
405+
shard_count=None,
406+
process_indices=None,
407+
mp_opts=None,
408+
):
401409
"""Applies preprocessing_fn then wraps the result with ElasticIterator.
402410
403411
When shard_index/shard_count are None, defaults to jax.process_index()/jax.process_count().
404412
"""
405413
ds = preprocessing_fn(dataset=dataset)
414+
if shard_index is None:
415+
if process_indices is not None:
416+
shard_index = process_indices.index(jax.process_index())
417+
else:
418+
shard_index = jax.process_index()
419+
if shard_count is None:
420+
if process_indices is not None:
421+
shard_count = len(process_indices)
422+
else:
423+
shard_count = jax.process_count()
424+
406425
return ElasticIterator(
407426
ds,
408427
global_batch_size=config.global_batch_size_to_load,
409428
shard_options=grain.ShardOptions(
410-
shard_index=shard_index if shard_index is not None else jax.process_index(),
411-
shard_count=shard_count if shard_count is not None else jax.process_count(),
429+
shard_index=shard_index,
430+
shard_count=shard_count,
412431
),
413432
read_options=grain.ReadOptions(
414433
num_threads=config.grain_num_threads,
@@ -463,7 +482,12 @@ def make_grain_train_iterator(
463482
# pass to MultiHostDataLoadIterator
464483
if config.colocated_python_data_input:
465484
if config.grain_use_elastic_iterator:
466-
preprocessing_fn = functools.partial(_make_elastic_iterator, config=config, preprocessing_fn=preprocessing_fn)
485+
preprocessing_fn = functools.partial(
486+
_make_elastic_iterator,
487+
config=config,
488+
preprocessing_fn=preprocessing_fn,
489+
process_indices=process_indices,
490+
)
467491

468492
global_shape = (config.global_batch_size_to_load, config.max_target_length)
469493
return multihost_dataloading.RemoteIteratorWrapper(

src/maxtext/input_pipeline/input_pipeline_interface.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,8 +100,12 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
100100
mesh,
101101
)
102102
output_train_iterator = create_process_specific_iterator(config, mesh, process_indices_train, train_iterator)
103+
active_process_count = len(set(d.process_index for d in mesh.devices.flat))
103104
if config.expansion_factor_real_data > 1: # assert number of hosts loading real data
104-
assert len(process_indices_train) == jax.process_count() // config.expansion_factor_real_data
105+
assert (
106+
len(process_indices_train)
107+
== active_process_count // config.expansion_factor_real_data
108+
)
105109

106110
# Generate output eval iterator
107111
output_eval_iterator = None
@@ -115,6 +119,10 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
115119
)
116120

117121
if config.expansion_factor_real_data > 1:
118-
assert len(process_indices_eval) == jax.process_count() // config.expansion_factor_real_data
122+
assert (
123+
len(process_indices_eval)
124+
== active_process_count // config.expansion_factor_real_data
125+
)
126+
119127
output_eval_iterator = create_process_specific_iterator(config, mesh, process_indices_eval, eval_iterator)
120128
return output_train_iterator, output_eval_iterator

0 commit comments

Comments (0)