Skip to content

Commit abfdb42

Browse files
Merge pull request #3811 from AI-Hypercomputer:input_fix
PiperOrigin-RevId: 910356301
2 parents e1b0d6e + 7a488f9 commit abfdb42

2 files changed

Lines changed: 7 additions & 30 deletions

File tree

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,8 +43,6 @@
4343

4444

4545
def normalize_features(x, column_name):
46-
if column_name not in x and "ids" in x:
47-
column_name = "ids"
4846
return {"inputs": x[column_name], "targets": x[column_name]}
4947

5048

@@ -87,8 +85,7 @@ def _process_string(string_tensor):
8785
return [modified_string]
8886

8987
for k in data_keys:
90-
if features[k].dtype == tf.string:
91-
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
88+
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
9289
return features
9390

9491

src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py

Lines changed: 6 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,6 @@ def preprocess_train_dataset(
249249
max_target_length: int,
250250
shuffle_buffer_size: int,
251251
data_shuffle_seed: int,
252-
is_tokenized_dataset: bool,
253252
) -> tf.data.Dataset:
254253
"""Preprocess the training dataset."""
255254
if sp_tokenizer.pad_id is not None:
@@ -259,11 +258,10 @@ def preprocess_train_dataset(
259258
else:
260259
pad_id = -1
261260

262-
if not is_tokenized_dataset:
263-
train_ds = train_ds.map(
264-
lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
265-
num_parallel_calls=AUTOTUNE,
266-
)
261+
train_ds = train_ds.map(
262+
lambda x: TokenizeOp(tokenizer_model=sp_tokenizer, features=x, data_keys=("targets",)),
263+
num_parallel_calls=AUTOTUNE,
264+
)
267265

268266
train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096)
269267
train_ds = split_tokens_to_targets_length(train_ds, max_target_length)
@@ -325,33 +323,16 @@ def make_c4_mlperf_train_iterator(
325323
process_indices,
326324
):
327325
"""Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328-
train_split = "train"
329-
if config.dataset_name == "c4/en:3.0.1":
330-
# gs://max-datasets-rogue/c4/en/3.0.1
331-
is_tokenized_dataset = False
332-
elif config.dataset_name == "c4/en:3.0.5":
333-
# gs://mlperf-6-submission/tfds/c4/en/3.0.5
334-
is_tokenized_dataset = True
335-
elif config.dataset_name == "c4/en:3.0.7":
336-
# gs://max-datasets-rogue/c4/en/3.0.7
337-
is_tokenized_dataset = False
338-
train_split = "train2"
339-
else:
340-
raise ValueError(f"{config.dataset_name=} should be one of " "('c4/en:3.0.1', 'c4/en:3.0.5', 'c4/en:3.0.7')")
341326
train_ds = get_dataset(
342327
dataset_name=config.dataset_name,
343-
split=train_split,
328+
split="train2",
344329
dataloading_host_index=process_indices.index(jax.process_index()),
345330
dataloading_host_count=len(process_indices),
346331
enable_data_shuffling=config.enable_data_shuffling,
347332
data_shuffle_seed=config.data_shuffle_seed,
348333
)
349334

350-
if is_tokenized_dataset:
351-
train_ds = rekey(train_ds, {"inputs": None, "targets": "ids"})
352-
else:
353-
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
354-
335+
train_ds = rekey(train_ds, {"inputs": None, "targets": "text"})
355336
sp_tokenizer = get_tokenizer(
356337
config.tokenizer_path, config.tokenizer_type, config.add_bos, config.add_eos, config.hf_access_token
357338
)
@@ -362,7 +343,6 @@ def make_c4_mlperf_train_iterator(
362343
max_target_length=config.max_target_length,
363344
shuffle_buffer_size=128,
364345
data_shuffle_seed=config.data_shuffle_seed,
365-
is_tokenized_dataset=is_tokenized_dataset,
366346
)
367347
train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh)
368348
return train_multihost_gen

0 commit comments

Comments
 (0)