Skip to content

Commit 183ea74

Browse files
committed
Apply remove_size_one_mesh_axis to the data sharding used by the eval dataloader; otherwise eval raises an error in explicit sharding mode
1 parent c9df820 commit 183ea74

2 files changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxtext/input_pipeline/input_pipeline_interface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030
from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator
3131
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
3232
from maxtext.utils import max_logging
33+
from maxtext.utils.sharding import remove_size_one_mesh_axis
3334

3435

3536
def get_process_loading_real_data(
3637
data_sharding, global_batch_size_to_load, global_batch_size_to_train_on, max_target_length, mesh
3738
):
3839
"""Get list of processes loading data from GCS when expansion_factor_real_data != -1"""
39-
sharding = jax.sharding.NamedSharding(mesh, P(*data_sharding))
40+
data_sharding_pspec = remove_size_one_mesh_axis(P(*data_sharding), mesh)
41+
sharding = jax.sharding.NamedSharding(mesh, data_sharding_pspec)
4042
devices_indices_map = sharding.devices_indices_map((global_batch_size_to_load, max_target_length))
4143
batch_cutoff = global_batch_size_to_train_on
4244
process_loading_real_data = set()

src/maxtext/input_pipeline/synthetic_data_processing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from maxtext.input_pipeline import multihost_dataloading
2727
from maxtext.configs import pyconfig
28+
from maxtext.utils import sharding
2829

2930

3031
class SyntheticDataIterator:
@@ -35,7 +36,7 @@ class SyntheticDataIterator:
3536
def __init__(self, config, mesh):
3637
self.mesh = mesh
3738
self.config = config
38-
data_pspec = P(*config.data_sharding)
39+
data_pspec = sharding.remove_size_one_mesh_axis(P(*config.data_sharding), mesh)
3940
data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
4041
self.data_generator = jax.jit(
4142
SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0

0 commit comments

Comments
 (0)