Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions src/maxtext/input_pipeline/grain_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,18 +397,48 @@ def _get_pipeline_fn(config):
return pretrain_preprocessing_pipeline


def _make_elastic_iterator(dataset, config, preprocessing_fn, shard_index=None, shard_count=None, mp_opts=None):
def _make_elastic_iterator(
dataset,
config,
preprocessing_fn,
shard_index=None,
shard_count=None,
process_indices=None,
mp_opts=None,
):
"""Applies preprocessing_fn then wraps the result with ElasticIterator.

When shard_index/shard_count are None, defaults to jax.process_index()/jax.process_count().
When shard_index/shard_count are None, defaults to
jax.process_index()/jax.process_count() (or calculated from
process_indices if provided).

Args:
dataset: The input dataset.
config: The hyperparameter configuration.
preprocessing_fn: The function to apply before wrapping.
shard_index: The shard index. Defaults to None.
shard_count: The shard count. Defaults to None.
process_indices: The active process indices. Defaults to None.
mp_opts: Multiprocessing options. Defaults to None.
"""
ds = preprocessing_fn(dataset=dataset)
if shard_index is None:
if process_indices is not None:
shard_index = process_indices.index(jax.process_index())
else:
shard_index = jax.process_index()
if shard_count is None:
if process_indices is not None:
shard_count = len(process_indices)
else:
shard_count = jax.process_count()

return ElasticIterator(
ds,
global_batch_size=config.global_batch_size_to_load,
shard_options=grain.ShardOptions(
shard_index=shard_index if shard_index is not None else jax.process_index(),
shard_count=shard_count if shard_count is not None else jax.process_count(),
shard_index=shard_index,
shard_count=shard_count,
),
read_options=grain.ReadOptions(
num_threads=config.grain_num_threads,
Expand Down Expand Up @@ -463,7 +493,12 @@ def make_grain_train_iterator(
# pass to MultiHostDataLoadIterator
if config.colocated_python_data_input:
if config.grain_use_elastic_iterator:
preprocessing_fn = functools.partial(_make_elastic_iterator, config=config, preprocessing_fn=preprocessing_fn)
preprocessing_fn = functools.partial(
_make_elastic_iterator,
config=config,
preprocessing_fn=preprocessing_fn,
process_indices=process_indices,
)

global_shape = (config.global_batch_size_to_load, config.max_target_length)
return multihost_dataloading.RemoteIteratorWrapper(
Expand Down
12 changes: 10 additions & 2 deletions src/maxtext/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,12 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
mesh,
)
output_train_iterator = create_process_specific_iterator(config, mesh, process_indices_train, train_iterator)
active_process_count = len(set(d.process_index for d in mesh.devices.flat))
if config.expansion_factor_real_data > 1: # assert number of hosts loading real data
assert len(process_indices_train) == jax.process_count() // config.expansion_factor_real_data
assert (
len(process_indices_train)
== active_process_count // config.expansion_factor_real_data
)

# Generate output eval iterator
output_eval_iterator = None
Expand All @@ -115,6 +119,10 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
)

if config.expansion_factor_real_data > 1:
assert len(process_indices_eval) == jax.process_count() // config.expansion_factor_real_data
assert (
len(process_indices_eval)
== active_process_count // config.expansion_factor_real_data
)

output_eval_iterator = create_process_specific_iterator(config, mesh, process_indices_eval, eval_iterator)
return output_train_iterator, output_eval_iterator
Loading
Loading