Skip to content

Commit aa3c81d

Browse files
committed
modify data input pipline
1 parent dff1d0c commit aa3c81d

2 files changed

Lines changed: 38 additions & 8 deletions

File tree

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343

4444

4545
def normalize_features(x, column_name):
46+
if column_name not in x and "ids" in x:
47+
column_name = "ids"
4648
return {"inputs": x[column_name], "targets": x[column_name]}
4749

4850

@@ -85,7 +87,8 @@ def _process_string(string_tensor):
8587
return [modified_string]
8688

8789
for k in data_keys:
88-
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
90+
if features[k].dtype == tf.string:
91+
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
8992
return features
9093

9194

src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def preprocess_train_dataset(
249249
max_target_length: int,
250250
shuffle_buffer_size: int,
251251
data_shuffle_seed: int,
252+
is_tokenized_dataset: bool,
252253
) -> tf.data.Dataset:
253254
"""Preprocess the training dataset."""
254255
if sp_tokenizer.pad_id is not None:
@@ -257,10 +258,13 @@ def preprocess_train_dataset(
257258
pad_id = sp_tokenizer.unk_id
258259
else:
259260
pad_id = -1
260-
train_ds = train_ds.map(
261-
lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
262-
num_parallel_calls=AUTOTUNE,
263-
)
261+
262+
if not is_tokenized_dataset:
263+
train_ds = train_ds.map(
264+
lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
265+
num_parallel_calls=AUTOTUNE,
266+
)
267+
264268
train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
265269
train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
266270
train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed)
@@ -321,15 +325,33 @@ def make_c4_mlperf_train_iterator(
321325
process_indices,
322326
):
323327
"""Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328+
train_split = "train2"
329+
if config.dataset_name == "c4/en:3.0.5":
330+
is_tokenized_dataset = True
331+
train_split = "train"
332+
elif config.dataset_name == "c4/en:3.0.4":
333+
is_tokenized_dataset = False
334+
elif config.dataset_name in ["c4/en:3.0.1", "c4/en:3.0.8", "c4/en:3.0.9"]:
335+
is_tokenized_dataset = False
336+
else:
337+
raise ValueError(
338+
f"{config.dataset_name=} should be one of "
339+
"('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
340+
"'c4/en:3.0.8', 'c4/en:3.0.9')"
341+
)
324342
train_ds = get_dataset(
325343
dataset_name=config.dataset_name,
326-
split="train2",
344+
split=train_split,
327345
dataloading_host_index=process_indices.index(jax.process_index()),
328346
dataloading_host_count=len(process_indices),
329347
enable_data_shuffling=config.enable_data_shuffling,
330348
data_shuffle_seed=config.data_shuffle_seed,
331349
)
332-
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
350+
351+
if is_tokenized_dataset:
352+
train_ds = rekey(train_ds, {"inputs": None, "targets": "ids"})
353+
else:
354+
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
333355

334356
sp_tokenizer = get_tokenizer(
335357
config.tokenizer_path, config.tokenizer_type, config.add_bos, config.add_eos, config.hf_access_token
@@ -341,6 +363,7 @@ def make_c4_mlperf_train_iterator(
341363
max_target_length=config.max_target_length,
342364
shuffle_buffer_size=128,
343365
data_shuffle_seed=config.data_shuffle_seed,
366+
is_tokenized_dataset=is_tokenized_dataset,
344367
)
345368
train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh)
346369
return train_multihost_gen
@@ -362,7 +385,11 @@ def make_c4_mlperf_eval_iterator(
362385
is_tokenized_dataset = False
363386
eval_slit = "validation"
364387
else:
365-
raise ValueError(f"{config.eval_dataset_name=} should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')")
388+
raise ValueError(
389+
f"{config.dataset_name=} should be one of "
390+
"('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
391+
"'c4/en:3.0.8', 'c4/en:3.0.9')"
392+
)
366393

367394
if is_tokenized_dataset:
368395
eval_ds = get_dataset(

0 commit comments

Comments
 (0)