@@ -249,6 +249,7 @@ def preprocess_train_dataset(
249249 max_target_length : int ,
250250 shuffle_buffer_size : int ,
251251 data_shuffle_seed : int ,
252+ is_tokenized_dataset : bool ,
252253) -> tf .data .Dataset :
253254 """Preprocess the training dataset."""
254255 if sp_tokenizer .pad_id is not None :
@@ -257,10 +258,13 @@ def preprocess_train_dataset(
257258 pad_id = sp_tokenizer .unk_id
258259 else :
259260 pad_id = - 1
260- train_ds = train_ds .map (
261- lambda x : TokenizeOp (tokenizer_model = sp_tokenizer , features = x , data_keys = ("targets" ,)),
262- num_parallel_calls = AUTOTUNE ,
263- )
261+
262+ if not is_tokenized_dataset :
263+ train_ds = train_ds .map (
264+ lambda x : TokenizeOp (tokenizer_model = sp_tokenizer , features = x , data_keys = ("targets" ,)),
265+ num_parallel_calls = AUTOTUNE ,
266+ )
267+
264268 train_ds = reduce_concat_tokens (train_ds , feature_key = "targets" , batch_size = 4096 )
265269 train_ds = split_tokens_to_targets_length (train_ds , max_target_length )
266270 train_ds = train_ds .shuffle (shuffle_buffer_size , seed = data_shuffle_seed )
@@ -321,15 +325,33 @@ def make_c4_mlperf_train_iterator(
321325 process_indices ,
322326):
323327 """Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328+ train_split = "train2"
329+ if config .dataset_name == "c4/en:3.0.5" :
330+ is_tokenized_dataset = True
331+ train_split = "train"
332+ elif config .dataset_name == "c4/en:3.0.4" :
333+ is_tokenized_dataset = False
334+ elif config .dataset_name in ["c4/en:3.0.1" , "c4/en:3.0.8" , "c4/en:3.0.9" ]:
335+ is_tokenized_dataset = False
336+ else :
337+ raise ValueError (
338+ f"{ config .dataset_name = } should be one of "
339+ "('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
340+ "'c4/en:3.0.8', 'c4/en:3.0.9')"
341+ )
324342 train_ds = get_dataset (
325343 dataset_name = config .dataset_name ,
326- split = "train2" ,
344+ split = train_split ,
327345 dataloading_host_index = process_indices .index (jax .process_index ()),
328346 dataloading_host_count = len (process_indices ),
329347 enable_data_shuffling = config .enable_data_shuffling ,
330348 data_shuffle_seed = config .data_shuffle_seed ,
331349 )
332- train_ds = rekey (train_ds , {"inputs" : None , "targets" : "text" })
350+
351+ if is_tokenized_dataset :
352+ train_ds = rekey (train_ds , {"inputs" : None , "targets" : "ids" })
353+ else :
354+ train_ds = rekey (train_ds , {"inputs" : None , "targets" : "text" })
333355
334356 sp_tokenizer = get_tokenizer (
335357 config .tokenizer_path , config .tokenizer_type , config .add_bos , config .add_eos , config .hf_access_token
@@ -341,6 +363,7 @@ def make_c4_mlperf_train_iterator(
341363 max_target_length = config .max_target_length ,
342364 shuffle_buffer_size = 128 ,
343365 data_shuffle_seed = config .data_shuffle_seed ,
366+ is_tokenized_dataset = is_tokenized_dataset ,
344367 )
345368 train_multihost_gen = multihost_dataloading .MultiHostDataLoadIterator (train_ds , global_mesh )
346369 return train_multihost_gen
@@ -362,7 +385,11 @@ def make_c4_mlperf_eval_iterator(
362385 is_tokenized_dataset = False
363386 eval_slit = "validation"
364387 else :
365- raise ValueError (f"{ config .eval_dataset_name = } should be one of ('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5')" )
388+ raise ValueError (
389+ f"{ config .dataset_name = } should be one of "
390+ "('c4/en:3.0.1', 'c4/en:3.0.4', 'c4/en:3.0.5', "
391+ "'c4/en:3.0.8', 'c4/en:3.0.9')"
392+ )
366393
367394 if is_tokenized_dataset :
368395 eval_ds = get_dataset (
0 commit comments