@@ -249,7 +249,7 @@ def preprocess_train_dataset(
249249 max_target_length : int ,
250250 shuffle_buffer_size : int ,
251251 data_shuffle_seed : int ,
252- is_tokenized_dataset : bool = True ,
252+ is_tokenized_dataset : bool ,
253253) -> tf .data .Dataset :
254254 """Preprocess the training dataset."""
255255 if sp_tokenizer .pad_id is not None :
@@ -325,17 +325,19 @@ def make_c4_mlperf_train_iterator(
325325 process_indices ,
326326):
327327 """Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328- train_split = "train2"
329- if config .dataset_name == "c4/en:3.0.5" :
330- is_tokenized_dataset = True
331- train_split = "train"
332- elif config .dataset_name == "c4/en:3.0.4" :
328+ train_split = "train"
329+ if config .dataset_name == "c4/en:3.0.1" :
330+ # gs://max-datasets-rogue/c4/en/3.0.1
333331 is_tokenized_dataset = False
334- elif config .dataset_name in ["c4/en:3.0.1" , "c4/en:3.0.8" , "c4/en:3.0.9" ]:
332+ elif config .dataset_name == "c4/en:3.0.5" :
333+ # gs://mlperf-6-submission/tfds/c4/en/3.0.5
334+ is_tokenized_dataset = True
335+ elif config .dataset_name == "c4/en:3.0.7" :
336+ # gs://max-datasets-rogue/c4/en/3.0.7
335337 is_tokenized_dataset = False
338+ train_split = "train2"
336339 else :
337- raise ValueError (f"{ config .dataset_name = } should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')" )
338-
340+ raise ValueError (f"{ config .dataset_name = } should be one of " "('c4/en:3.0.1', 'c4/en:3.0.5', 'c4/en:3.0.7')" )
339341 train_ds = get_dataset (
340342 dataset_name = config .dataset_name ,
341343 split = train_split ,
@@ -372,17 +374,21 @@ def make_c4_mlperf_eval_iterator(
372374 process_indices ,
373375):
374376 """Make eval iterator of customized C4 dataset for mlperf gpt3 training."""
375- eval_slit = "None"
377+ eval_split = "None"
376378 if config .eval_dataset_name == "c4/en:3.0.5" :
377379 is_tokenized_dataset = True
378380 elif config .eval_dataset_name == "c4/en:3.0.4" :
379381 is_tokenized_dataset = False
380- eval_slit = "validation_24567exp"
382+ eval_split = "validation_24567exp"
381383 elif config .eval_dataset_name in ["c4/en:3.0.1" , "c4/en:3.0.8" , "c4/en:3.0.9" ]:
382384 is_tokenized_dataset = False
383- eval_slit = "validation"
385+ eval_split = "validation"
384386 else :
385- raise ValueError (f"{ config .eval_dataset_name = } should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')" )
387+ raise ValueError (
388+ f"{ config .eval_dataset_name = } should be one of "
389+ "('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
390+ "'c4/en:3.0.8', 'c4/en:3.0.9')"
391+ )
386392
387393 if is_tokenized_dataset :
388394 eval_ds = get_dataset (
@@ -398,7 +404,7 @@ def make_c4_mlperf_eval_iterator(
398404 else :
399405 eval_ds = get_dataset (
400406 dataset_name = config .eval_dataset_name ,
401- split = eval_slit ,
407+ split = eval_split ,
402408 dataloading_host_index = process_indices .index (jax .process_index ()),
403409 dataloading_host_count = len (process_indices ),
404410 enable_data_shuffling = False ,
0 commit comments