diff --git a/src/maxtext/input_pipeline/input_pipeline_interface.py b/src/maxtext/input_pipeline/input_pipeline_interface.py index ac37a7bdda..a3d2a781af 100644 --- a/src/maxtext/input_pipeline/input_pipeline_interface.py +++ b/src/maxtext/input_pipeline/input_pipeline_interface.py @@ -30,13 +30,15 @@ from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator from maxtext.utils import max_logging +from maxtext.utils.sharding import remove_size_one_mesh_axis def get_process_loading_real_data( data_sharding, global_batch_size_to_load, global_batch_size_to_train_on, max_target_length, mesh ): """Get list of processes loading data from GCS when expansion_factor_real_data != -1""" - sharding = jax.sharding.NamedSharding(mesh, P(*data_sharding)) + data_sharding_pspec = remove_size_one_mesh_axis(P(*data_sharding), mesh) + sharding = jax.sharding.NamedSharding(mesh, data_sharding_pspec) devices_indices_map = sharding.devices_indices_map((global_batch_size_to_load, max_target_length)) batch_cutoff = global_batch_size_to_train_on process_loading_real_data = set() diff --git a/src/maxtext/input_pipeline/synthetic_data_processing.py b/src/maxtext/input_pipeline/synthetic_data_processing.py index bc739ded69..31d3cfc814 100644 --- a/src/maxtext/input_pipeline/synthetic_data_processing.py +++ b/src/maxtext/input_pipeline/synthetic_data_processing.py @@ -25,6 +25,7 @@ from maxtext.input_pipeline import multihost_dataloading from maxtext.configs import pyconfig +from maxtext.utils import sharding class SyntheticDataIterator: @@ -35,7 +36,7 @@ class SyntheticDataIterator: def __init__(self, config, mesh): self.mesh = mesh self.config = config - data_pspec = P(*config.data_sharding) + data_pspec = sharding.remove_size_one_mesh_axis(P(*config.data_sharding), mesh) data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) self.data_generator = jax.jit( SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0