Skip to content

Commit 76a7ecb

Browse files
Merge pull request #3747 from AI-Hypercomputer:perf_input_pipeline
PiperOrigin-RevId: 906506791
2 parents cbb413f + 6c0351f commit 76a7ecb

2 files changed

Lines changed: 41 additions & 12 deletions

File tree

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,8 @@
4343

4444

4545
def normalize_features(x, column_name):
46+
if column_name not in x and "ids" in x:
47+
column_name = "ids"
4648
return {"inputs": x[column_name], "targets": x[column_name]}
4749

4850

@@ -85,7 +87,8 @@ def _process_string(string_tensor):
8587
return [modified_string]
8688

8789
for k in data_keys:
88-
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
90+
if features[k].dtype == tf.string:
91+
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
8992
return features
9093

9194

src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py

Lines changed: 37 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -249,6 +249,7 @@ def preprocess_train_dataset(
249249
max_target_length: int,
250250
shuffle_buffer_size: int,
251251
data_shuffle_seed: int,
252+
is_tokenized_dataset: bool,
252253
) -> tf.data.Dataset:
253254
"""Preprocess the training dataset."""
254255
if sp_tokenizer.pad_id is not None:
@@ -257,10 +258,13 @@ def preprocess_train_dataset(
257258
pad_id = sp_tokenizer.unk_id
258259
else:
259260
pad_id = -1
260-
train_ds = train_ds.map(
261-
lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
262-
num_parallel_calls=AUTOTUNE,
263-
)
261+
262+
if not is_tokenized_dataset:
263+
train_ds = train_ds.map(
264+
lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
265+
num_parallel_calls=AUTOTUNE,
266+
)
267+
264268
train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
265269
train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
266270
train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
@@ -321,15 +325,32 @@ def make_c4_mlperf_train_iterator(
321325
process_indices,
322326
):
323327
"""Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328+
train_split = "train"
329+
if config.dataset_name == "c4/en:3.0.1":
330+
# gs://max-datasets-rogue/c4/en/3.0.1
331+
is_tokenized_dataset = False
332+
elif config.dataset_name == "c4/en:3.0.5":
333+
# gs://mlperf-6-submission/tfds/c4/en/3.0.5
334+
is_tokenized_dataset = True
335+
elif config.dataset_name == "c4/en:3.0.7":
336+
# gs://max-datasets-rogue/c4/en/3.0.7
337+
is_tokenized_dataset = False
338+
train_split = "train2"
339+
else:
340+
raise ValueError(f"{config.dataset_name=} should be one of " "('c4/en:3.0.1', 'c4/en:3.0.5', 'c4/en:3.0.7')")
324341
train_ds = get_dataset(
325342
dataset_name=config.dataset_name,
326-
split="train2",
343+
split=train_split,
327344
dataloading_host_index=process_indices.index(jax.process_index()),
328345
dataloading_host_count=len(process_indices),
329346
enable_data_shuffling=config.enable_data_shuffling,
330347
data_shuffle_seed=config.data_shuffle_seed,
331348
)
332-
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
349+
350+
if is_tokenized_dataset:
351+
train_ds = rekey(train_ds, {"inputs": None, "targets": "ids"})
352+
else:
353+
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
333354

334355
sp_tokenizer = get_tokenizer(
335356
config.tokenizer_path, config.tokenizer_type, config.add_bos, config.add_eos, config.hf_access_token
@@ -341,6 +362,7 @@ def make_c4_mlperf_train_iterator(
341362
max_target_length=config.max_target_length,
342363
shuffle_buffer_size=128,
343364
data_shuffle_seed=config.data_shuffle_seed,
365+
is_tokenized_dataset=is_tokenized_dataset,
344366
)
345367
train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh)
346368
return train_multihost_gen
@@ -352,17 +374,21 @@ def make_c4_mlperf_eval_iterator(
352374
process_indices,
353375
):
354376
"""Make eval iterator of customized C4 dataset for mlperf gpt3 training."""
355-
eval_slit = "None"
377+
eval_split = "None"
356378
if config.eval_dataset_name == "c4/en:3.0.5":
357379
is_tokenized_dataset = True
358380
elif config.eval_dataset_name == "c4/en:3.0.4":
359381
is_tokenized_dataset = False
360-
eval_slit = "validation_24567exp"
382+
eval_split = "validation_24567exp"
361383
elif config.eval_dataset_name in ["c4/en:3.0.1", "c4/en:3.0.8", "c4/en:3.0.9"]:
362384
is_tokenized_dataset = False
363-
eval_slit = "validation"
385+
eval_split = "validation"
364386
else:
365-
raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')")
387+
raise ValueError(
388+
f"{config.eval_dataset_name=} should be one of "
389+
"('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
390+
"'c4/en:3.0.8', 'c4/en:3.0.9')"
391+
)
366392

367393
if is_tokenized_dataset:
368394
eval_ds = get_dataset(
@@ -378,7 +404,7 @@ def make_c4_mlperf_eval_iterator(
378404
else:
379405
eval_ds = get_dataset(
380406
dataset_name=config.eval_dataset_name,
381-
split=eval_slit,
407+
split=eval_split,
382408
dataloading_host_index=process_indices.index(jax.process_index()),
383409
dataloading_host_count=len(process_indices),
384410
enable_data_shuffling=False,

0 commit comments

Comments
 (0)