@@ -249,6 +249,7 @@ def preprocess_train_dataset(
249249 max_target_length : int ,
250250 shuffle_buffer_size : int ,
251251 data_shuffle_seed : int ,
252+ is_tokenized_dataset : bool ,
252253) -> tf .data .Dataset :
253254 """Preprocess the training dataset."""
254255 if sp_tokenizer .pad_id is not None :
@@ -257,10 +258,13 @@ def preprocess_train_dataset(
257258 pad_id = sp_tokenizer .unk_id
258259 else :
259260 pad_id = - 1
260- train_ds = train_ds .map (
261- lambda x : TokenizeOp (tokenizer_model = sp_tokenizer , features = x , data_keys = ("targets" ,)),
262- num_parallel_calls = AUTOTUNE ,
263- )
261+
262+ if not is_tokenized_dataset :
263+ train_ds = train_ds .map (
264+ lambda x : TokenizeOp (tokenizer_model = sp_tokenizer , features = x , data_keys = ("targets" ,)),
265+ num_parallel_calls = AUTOTUNE ,
266+ )
267+
264268 train_ds = reduce_concat_tokens (train_ds , feature_key = "targets" , batch_size = 4096 )
265269 train_ds = split_tokens_to_targets_length (train_ds , max_target_length )
266270 train_ds = train_ds .shuffle (shuffle_buffer_size , seed = data_shuffle_seed )
@@ -321,15 +325,32 @@ def make_c4_mlperf_train_iterator(
321325 process_indices ,
322326):
323327 """Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328+ train_split = "train"
329+ if config .dataset_name == "c4/en:3.0.1" :
330+ # gs://max-datasets-rogue/c4/en/3.0.1
331+ is_tokenized_dataset = False
332+ elif config .dataset_name == "c4/en:3.0.5" :
333+ # gs://mlperf-6-submission/tfds/c4/en/3.0.5
334+ is_tokenized_dataset = True
335+ elif config .dataset_name == "c4/en:3.0.7" :
336+ # gs://max-datasets-rogue/c4/en/3.0.7
337+ is_tokenized_dataset = False
338+ train_split = "train2"
339+ else :
340+ raise ValueError (f"{ config .dataset_name = } should be one of " "('c4/en:3.0.1', 'c4/en:3.0.5', 'c4/en:3.0.7')" )
324341 train_ds = get_dataset (
325342 dataset_name = config .dataset_name ,
326- split = "train2" ,
343+ split = train_split ,
327344 dataloading_host_index = process_indices .index (jax .process_index ()),
328345 dataloading_host_count = len (process_indices ),
329346 enable_data_shuffling = config .enable_data_shuffling ,
330347 data_shuffle_seed = config .data_shuffle_seed ,
331348 )
332- train_ds = rekey (train_ds , {"inputs" : None , "targets" : "text" })
349+
350+ if is_tokenized_dataset :
351+ train_ds = rekey (train_ds , {"inputs" : None , "targets" : "ids" })
352+ else :
353+ train_ds = rekey (train_ds , {"inputs" : None , "targets" : "text" })
333354
334355 sp_tokenizer = get_tokenizer (
335356 config .tokenizer_path , config .tokenizer_type , config .add_bos , config .add_eos , config .hf_access_token
@@ -341,6 +362,7 @@ def make_c4_mlperf_train_iterator(
341362 max_target_length = config .max_target_length ,
342363 shuffle_buffer_size = 128 ,
343364 data_shuffle_seed = config .data_shuffle_seed ,
365+ is_tokenized_dataset = is_tokenized_dataset ,
344366 )
345367 train_multihost_gen = multihost_dataloading .MultiHostDataLoadIterator (train_ds , global_mesh )
346368 return train_multihost_gen
@@ -352,17 +374,21 @@ def make_c4_mlperf_eval_iterator(
352374 process_indices ,
353375):
354376 """Make eval iterator of customized C4 dataset for mlperf gpt3 training."""
355- eval_slit = "None"
377+ eval_split = "None"
356378 if config .eval_dataset_name == "c4/en:3.0.5" :
357379 is_tokenized_dataset = True
358380 elif config .eval_dataset_name == "c4/en:3.0.4" :
359381 is_tokenized_dataset = False
360- eval_slit = "validation_24567exp"
382+ eval_split = "validation_24567exp"
361383 elif config .eval_dataset_name in ["c4/en:3.0.1" , "c4/en:3.0.8" , "c4/en:3.0.9" ]:
362384 is_tokenized_dataset = False
363- eval_slit = "validation"
385+ eval_split = "validation"
364386 else :
365- raise ValueError (f"{ config .eval_dataset_name = } should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')" )
387+ raise ValueError (
388+ f"{ config .eval_dataset_name = } should be one of "
389+ "('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
390+ "'c4/en:3.0.8', 'c4/en:3.0.9')"
391+ )
366392
367393 if is_tokenized_dataset :
368394 eval_ds = get_dataset (
@@ -378,7 +404,7 @@ def make_c4_mlperf_eval_iterator(
378404 else :
379405 eval_ds = get_dataset (
380406 dataset_name = config .eval_dataset_name ,
381- split = eval_slit ,
407+ split = eval_split ,
382408 dataloading_host_index = process_indices .index (jax .process_index ()),
383409 dataloading_host_count = len (process_indices ),
384410 enable_data_shuffling = False ,
0 commit comments