@@ -37,6 +37,8 @@ def train(
     save_total_limit: int = 3,
     fp16: bool = False,
     bf16: bool = False,
+    eval_batch_size: int | None = None,
+    eval_accumulation_steps: int = 1,
 ) -> None:
     """Train the encoder line classifier."""
     import torch
@@ -81,6 +83,7 @@ def train(
 
     # Resize embeddings for the new [LINE_SEP] token
     model.encoder.resize_token_embeddings(len(tokenizer))
+    model.gradient_checkpointing_enable()
 
     # Load datasets
     logger.info(f"Loading train data from {train_file}")
@@ -102,7 +105,7 @@ def train(
         output_dir=output_dir,
         num_train_epochs=num_epochs,
         per_device_train_batch_size=batch_size,
-        per_device_eval_batch_size=max(1, batch_size // 2),
+        per_device_eval_batch_size=eval_batch_size or max(1, batch_size // 2),
         gradient_accumulation_steps=gradient_accumulation_steps,
         learning_rate=learning_rate,
         weight_decay=weight_decay,
@@ -120,6 +123,8 @@ def train(
         metric_for_best_model="eval_loss" if eval_dataset else None,
         report_to="none",
         dataloader_num_workers=0,
+        eval_accumulation_steps=eval_accumulation_steps,
+        gradient_checkpointing=True,
         remove_unused_columns=False,
     )
 
@@ -158,15 +163,17 @@ def build_parser(parser: argparse.ArgumentParser | None = None) -> argparse.Argu
 
     parser.add_argument("--train-file", required=True, help="Path to encoder_train.jsonl")
     parser.add_argument("--eval-file", default=None, help="Path to encoder_dev.jsonl")
-    parser.add_argument("--base-model", default="jhu-clsp/mmBERT-base", help="Base encoder model")
-    parser.add_argument("--output-dir", default="output/squeez_encoder")
-    parser.add_argument("--max-length", type=int, default=8192)
-    parser.add_argument("--batch-size", type=int, default=16)
-    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
-    parser.add_argument("--learning-rate", type=float, default=2e-5)
-    parser.add_argument("--num-epochs", type=int, default=5)
-    parser.add_argument("--warmup-ratio", type=float, default=0.1)
-    parser.add_argument("--weight-decay", type=float, default=0.01)
+    parser.add_argument("--base-model", default=None, help="Base encoder model")
+    parser.add_argument("--output-dir", default=None)
+    parser.add_argument("--max-length", type=int, default=None)
+    parser.add_argument("--batch-size", type=int, default=None)
+    parser.add_argument("--eval-batch-size", type=int, default=None)
+    parser.add_argument("--gradient-accumulation-steps", type=int, default=None)
+    parser.add_argument("--eval-accumulation-steps", type=int, default=None)
+    parser.add_argument("--learning-rate", type=float, default=None)
+    parser.add_argument("--num-epochs", type=int, default=None)
+    parser.add_argument("--warmup-ratio", type=float, default=None)
+    parser.add_argument("--weight-decay", type=float, default=None)
     parser.add_argument("--eval-steps", type=int, default=200)
     parser.add_argument("--save-steps", type=int, default=200)
     parser.add_argument("--logging-steps", type=int, default=25)
@@ -184,23 +191,37 @@ def main(argv: list[str] | None = None) -> int:
         format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
     )
 
+    import yaml
+
+    default_config_path = Path(__file__).parent.parent.parent / "configs" / "default.yaml"
+    config = {}
+    if default_config_path.exists():
+        with open(default_config_path) as f:
+            config = yaml.safe_load(f) or {}
+
     train(
         train_file=args.train_file,
         eval_file=args.eval_file,
-        base_model=args.base_model,
-        output_dir=args.output_dir,
-        max_length=args.max_length,
-        batch_size=args.batch_size,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        learning_rate=args.learning_rate,
-        num_epochs=args.num_epochs,
-        warmup_ratio=args.warmup_ratio,
-        weight_decay=args.weight_decay,
+        base_model=args.base_model or config.get("encoder_base_model", "jhu-clsp/mmBERT-base"),
+        output_dir=args.output_dir or config.get("encoder_output_dir", "output/squeez_encoder"),
+        max_length=args.max_length or config.get("encoder_max_length", 8192),
+        batch_size=args.batch_size or config.get("encoder_batch_size", 2),
+        gradient_accumulation_steps=(
+            args.gradient_accumulation_steps or config.get("encoder_gradient_accumulation_steps", 8)
+        ),
+        learning_rate=args.learning_rate or config.get("encoder_learning_rate", 2e-5),
+        num_epochs=args.num_epochs or config.get("encoder_num_epochs", 5),
+        warmup_ratio=args.warmup_ratio or config.get("encoder_warmup_ratio", 0.1),
+        weight_decay=args.weight_decay or config.get("weight_decay", 0.01),
         eval_steps=args.eval_steps,
         save_steps=args.save_steps,
         logging_steps=args.logging_steps,
         fp16=args.fp16,
         bf16=args.bf16,
+        eval_batch_size=args.eval_batch_size or config.get("encoder_eval_batch_size", 1),
+        eval_accumulation_steps=(
+            args.eval_accumulation_steps or config.get("encoder_eval_accumulation_steps", 1)
+        ),
     )
     return 0
 
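For reference, a minimal sketch of how the new CLI-over-config fallback resolves. The key names mirror the lookups in the diff, but the values and the resolve helper are illustrative assumptions; configs/default.yaml itself is not part of this change:

import yaml

# Illustrative config: key names match the code above, values are assumed.
config = yaml.safe_load("""
encoder_base_model: jhu-clsp/mmBERT-base
encoder_batch_size: 2
encoder_gradient_accumulation_steps: 8
""")

def resolve(cli_value, key, fallback):
    # An explicit CLI flag wins, then the YAML key, then the hard-coded
    # fallback. Because `or` treats falsy values as missing, a flag like
    # `--warmup-ratio 0` would silently fall through to the config value.
    return cli_value or config.get(key, fallback)

assert resolve(None, "encoder_batch_size", 16) == 2  # flag omitted: YAML value
assert resolve(4, "encoder_batch_size", 16) == 4     # --batch-size 4: CLI wins
assert resolve(None, "encoder_num_epochs", 5) == 5   # key absent: fallback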