@@ -249,7 +249,6 @@ def preprocess_train_dataset(
249249 max_target_length : int ,
250250 shuffle_buffer_size : int ,
251251 data_shuffle_seed : int ,
252- is_tokenized_dataset : bool ,
253252) -> tf .data .Dataset :
254253 """Preprocess the training dataset."""
255254 if sp_tokenizer .pad_id is not None :
@@ -259,11 +258,10 @@ def preprocess_train_dataset(
259258 else :
260259 pad_id = - 1
261260
262- if not is_tokenized_dataset :
263- train_ds = train_ds .map (
264- lambda x : TokenizeOp (tokenizer_model = sp_tokenizer , features = x , data_keys = ("targets" ,)),
265- num_parallel_calls = AUTOTUNE ,
266- )
261+ train_ds = train_ds .map (
262+ lambda x : TokenizeOp (tokenizer_model = sp_tokenizer , features = x , data_keys = ("targets" ,)),
263+ num_parallel_calls = AUTOTUNE ,
264+ )
267265
268266 train_ds = reduce_concat_tokens (train_ds , feature_key = "targets" , batch_size = 4096 )
269267 train_ds = split_tokens_to_targets_length (train_ds , max_target_length )
@@ -325,33 +323,16 @@ def make_c4_mlperf_train_iterator(
325323 process_indices ,
326324):
327325 """Make train iterator of customized C4 dataset for mlperf gpt3 training."""
328- train_split = "train"
329- if config .dataset_name == "c4/en:3.0.1" :
330- # gs://max-datasets-rogue/c4/en/3.0.1
331- is_tokenized_dataset = False
332- elif config .dataset_name == "c4/en:3.0.5" :
333- # gs://mlperf-6-submission/tfds/c4/en/3.0.5
334- is_tokenized_dataset = True
335- elif config .dataset_name == "c4/en:3.0.7" :
336- # gs://max-datasets-rogue/c4/en/3.0.7
337- is_tokenized_dataset = False
338- train_split = "train2"
339- else :
340- raise ValueError (f"{ config .dataset_name = } should be one of " "('c4/en:3.0.1', 'c4/en:3.0.5', 'c4/en:3.0.7')" )
341326 train_ds = get_dataset (
342327 dataset_name = config .dataset_name ,
343- split = train_split ,
328+ split = "train2" ,
344329 dataloading_host_index = process_indices .index (jax .process_index ()),
345330 dataloading_host_count = len (process_indices ),
346331 enable_data_shuffling = config .enable_data_shuffling ,
347332 data_shuffle_seed = config .data_shuffle_seed ,
348333 )
349334
350- if is_tokenized_dataset :
351- train_ds = rekey (train_ds , {"inputs" : None , "targets" : "ids" })
352- else :
353- train_ds = rekey (train_ds , {"inputs" : None , "targets" : "text" })
354-
335+ train_ds = rekey (train_ds , {"inputs" : None , "targets" : "text" })
355336 sp_tokenizer = get_tokenizer (
356337 config .tokenizer_path , config .tokenizer_type , config .add_bos , config .add_eos , config .hf_access_token
357338 )
@@ -362,7 +343,6 @@ def make_c4_mlperf_train_iterator(
362343 max_target_length = config .max_target_length ,
363344 shuffle_buffer_size = 128 ,
364345 data_shuffle_seed = config .data_shuffle_seed ,
365- is_tokenized_dataset = is_tokenized_dataset ,
366346 )
367347 train_multihost_gen = multihost_dataloading .MultiHostDataLoadIterator (train_ds , global_mesh )
368348 return train_multihost_gen
0 commit comments