@@ -91,7 +91,6 @@ def preprocessing_pipeline(
9191 shift : bool = True ,
9292 drop_remainder : bool = True ,
9393 prefetch_size = tf .data .experimental .AUTOTUNE ,
94- use_dpo : bool = False ,
9594 hf_access_token : str = "" ,
9695):
9796 """pipeline for preprocessing TFDS dataset."""
@@ -115,15 +114,11 @@ def preprocessing_pipeline(
115114 "Set tokenize_train_data or tokenize_eval_data to True if your dataset needs tokenization."
116115 )
117116
118- if not use_dpo :
119- assert len (data_column_names ) == 1
120- dataset = dataset .map (
121- lambda x : input_pipeline_utils .normalize_features (x , data_column_names [0 ]), num_parallel_calls = AUTOTUNE
122- )
123- else :
124- dataset = dataset .map (lambda x : {col : x [col ] for col in data_column_names }, num_parallel_calls = AUTOTUNE )
125-
126- data_column_names = data_column_names if use_dpo else ("inputs" , "targets" )
117+ assert len (data_column_names ) == 1
118+ dataset = dataset .map (
119+ lambda x : input_pipeline_utils .normalize_features (x , data_column_names [0 ]), num_parallel_calls = AUTOTUNE
120+ )
121+ data_column_names = ("inputs" , "targets" )
127122
128123 tokenizer_model = input_pipeline_utils .get_tokenizer (tokenizer_path , tokenizer_type , add_bos , add_eos , hf_access_token )
129124 if tokenizer_model .pad_id is not None :
@@ -144,7 +139,7 @@ def preprocessing_pipeline(
144139 if max_target_length > 0 :
145140 # in pre-training we can take up to max_length+1 because there would be truncation by
146141 # 1 token for both inputs and targets
147- extra_tokens = 1 if not use_dpo else 0
142+ extra_tokens = 1
148143 dataset = dataset .map (
149144 lambda x : input_pipeline_utils .truncate_to_max_allowable_length (x , max_target_length + extra_tokens ),
150145 num_parallel_calls = AUTOTUNE ,
@@ -157,13 +152,13 @@ def preprocessing_pipeline(
157152 dataset = dataset .repeat (num_epochs )
158153
159154 # Shift inputs for teacher-forced training
160- if shift and not use_dpo :
155+ if shift :
161156 dataset = dataset .map (
162157 input_pipeline_utils .shift_data_by_truncation , num_parallel_calls = tf .data .AUTOTUNE , deterministic = True
163158 )
164159
165160 # Perform greedy sequence packing and batching
166- if pack_examples and not use_dpo :
161+ if pack_examples :
167162 dataset = sequence_packing .pack_dataset (dataset , max_target_length , pad_id )
168163 dataset = dataset .batch (global_batch_size // jax .process_count (), drop_remainder = drop_remainder )
169164 else :
@@ -223,7 +218,6 @@ def make_tfds_train_iterator(
223218 add_eos = config .add_eos ,
224219 num_epochs = config .num_epoch ,
225220 pack_examples = config .packing ,
226- use_dpo = config .use_dpo ,
227221 hf_access_token = config .hf_access_token ,
228222 )
229223 return multihost_dataloading .MultiHostDataLoadIterator (
@@ -248,7 +242,6 @@ def make_tfds_train_iterator(
248242 add_eos = config .add_eos ,
249243 num_epochs = config .num_epoch ,
250244 pack_examples = config .packing ,
251- use_dpo = config .use_dpo ,
252245 hf_access_token = config .hf_access_token ,
253246 )
254247 global_shape = (config .global_batch_size_to_load , config .max_target_length )
@@ -289,7 +282,6 @@ def make_tfds_eval_iterator(
289282 add_bos = config .add_bos ,
290283 add_eos = config .add_eos ,
291284 pack_examples = config .packing ,
292- use_dpo = config .use_dpo ,
293285 hf_access_token = config .hf_access_token ,
294286 )
295287 return multihost_dataloading .MultiHostDataLoadIterator (
@@ -317,7 +309,6 @@ def make_tfds_eval_iterator(
317309 add_bos = config .add_bos ,
318310 add_eos = config .add_eos ,
319311 pack_examples = config .packing ,
320- use_dpo = config .use_dpo ,
321312 hf_access_token = config .hf_access_token ,
322313 )
323314 global_shape = (config .global_batch_size_to_load_eval , config .max_target_length )
0 commit comments