Skip to content

Commit 6c0351f

Browse files
committed
modify data input pipline
1 parent 00068fa commit 6c0351f

1 file changed

Lines changed: 20 additions & 14 deletions

File tree

src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def preprocess_train_dataset(
249249
max_target_length: int,
250250
shuffle_buffer_size: int,
251251
data_shuffle_seed: int,
252-
is_tokenized_dataset: bool = True,
252+
is_tokenized_dataset: bool,
253253
) -> tf.data.Dataset:
254254
"""Preprocess the training dataset."""
255255
if sp_tokenizer.pad_id is not None:
@@ -325,17 +325,19 @@ def make_c4_mlperf_train_iterator(
325325
process_indices,
326326
):
327327
"""Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328-
train_split = "train2"
329-
if config.dataset_name == "c4/en:3.0.5":
330-
is_tokenized_dataset = True
331-
train_split = "train"
332-
elif config.dataset_name == "c4/en:3.0.4":
328+
train_split = "train"
329+
if config.dataset_name == "c4/en:3.0.1":
330+
# gs://max-datasets-rogue/c4/en/3.0.1
333331
is_tokenized_dataset = False
334-
elif config.dataset_name in ["c4/en:3.0.1", "c4/en:3.0.8", "c4/en:3.0.9"]:
332+
elif config.dataset_name == "c4/en:3.0.5":
333+
# gs://mlperf-6-submission/tfds/c4/en/3.0.5
334+
is_tokenized_dataset = True
335+
elif config.dataset_name == "c4/en:3.0.7":
336+
# gs://max-datasets-rogue/c4/en/3.0.7
335337
is_tokenized_dataset = False
338+
train_split = "train2"
336339
else:
337-
raise ValueError(f"{config.dataset_name=} should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')")
338-
340+
raise ValueError(f"{config.dataset_name=} should be one of " "('c4/en:3.0.1', 'c4/en:3.0.5', 'c4/en:3.0.7')")
339341
train_ds = get_dataset(
340342
dataset_name=config.dataset_name,
341343
split=train_split,
@@ -372,17 +374,21 @@ def make_c4_mlperf_eval_iterator(
372374
process_indices,
373375
):
374376
"""Make eval iterator of customized C4 dataset for mlperf gpt3 training."""
375-
eval_slit = "None"
377+
eval_split = "None"
376378
if config.eval_dataset_name == "c4/en:3.0.5":
377379
is_tokenized_dataset = True
378380
elif config.eval_dataset_name == "c4/en:3.0.4":
379381
is_tokenized_dataset = False
380-
eval_slit = "validation_24567exp"
382+
eval_split = "validation_24567exp"
381383
elif config.eval_dataset_name in ["c4/en:3.0.1", "c4/en:3.0.8", "c4/en:3.0.9"]:
382384
is_tokenized_dataset = False
383-
eval_slit = "validation"
385+
eval_split = "validation"
384386
else:
385-
raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')")
387+
raise ValueError(
388+
f"{config.eval_dataset_name=} should be one of "
389+
"('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
390+
"'c4/en:3.0.8', 'c4/en:3.0.9')"
391+
)
386392

387393
if is_tokenized_dataset:
388394
eval_ds = get_dataset(
@@ -398,7 +404,7 @@ def make_c4_mlperf_eval_iterator(
398404
else:
399405
eval_ds = get_dataset(
400406
dataset_name=config.eval_dataset_name,
401-
split=eval_slit,
407+
split=eval_split,
402408
dataloading_host_index=process_indices.index(jax.process_index()),
403409
dataloading_host_count=len(process_indices),
404410
enable_data_shuffling=False,

0 commit comments

Comments
 (0)