Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/maxtext/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@
from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
from maxtext.utils import max_logging
from maxtext.utils.sharding import remove_size_one_mesh_axis


def get_process_loading_real_data(
data_sharding, global_batch_size_to_load, global_batch_size_to_train_on, max_target_length, mesh
):
"""Get list of processes loading data from GCS when expansion_factor_real_data != -1"""
sharding = jax.sharding.NamedSharding(mesh, P(*data_sharding))
data_sharding_pspec = remove_size_one_mesh_axis(P(*data_sharding), mesh)
sharding = jax.sharding.NamedSharding(mesh, data_sharding_pspec)
devices_indices_map = sharding.devices_indices_map((global_batch_size_to_load, max_target_length))
batch_cutoff = global_batch_size_to_train_on
process_loading_real_data = set()
Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/input_pipeline/synthetic_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from maxtext.input_pipeline import multihost_dataloading
from maxtext.configs import pyconfig
from maxtext.utils import sharding


class SyntheticDataIterator:
Expand All @@ -35,7 +36,7 @@ class SyntheticDataIterator:
def __init__(self, config, mesh):
self.mesh = mesh
self.config = config
data_pspec = P(*config.data_sharding)
data_pspec = sharding.remove_size_one_mesh_axis(P(*config.data_sharding), mesh)
data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
self.data_generator = jax.jit(
SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
Expand Down
Loading