@@ -52,10 +52,13 @@ class TrainConfig:
     weight_decay: float
     lr_scheduler: str
     lr_warmup_steps: int
+    lr_warmup_ratio: float
+    max_grad_norm: float
 
     max_length: int
     block_size: int
     mask_token: str
+    loss_decay_gamma: float
 
 
 def parse_args() -> TrainConfig:
@@ -75,18 +78,32 @@ def parse_args() -> TrainConfig:
 
     parser.add_argument("--per_device_train_batch_size", type=int, default=2)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--learning_rate", type=float, default=2e-5)
+    parser.add_argument("--learning_rate", type=float, default=6e-4)
     parser.add_argument("--weight_decay", type=float, default=0.0)
     parser.add_argument(
         "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
     )
-    parser.add_argument("--lr_warmup_steps", type=int, default=100)
+    parser.add_argument(
+        "--lr_warmup_steps",
+        type=int,
+        default=0,
+        help="Absolute number of warmup steps. Ignored when --lr_warmup_ratio > 0 (the default).",
+    )
+    parser.add_argument("--lr_warmup_ratio", type=float, default=0.04)
+    parser.add_argument("--max_grad_norm", type=float, default=1.0)
 
-    parser.add_argument("--max_length", type=int, default=512)
+    parser.add_argument("--max_length", type=int, default=3072)
     parser.add_argument(
         "--block_size", type=int, default=0, help="Override draft block size (0 uses the model config)."
     )
     parser.add_argument("--mask_token", type=str, default="<|MASK|>")
+    parser.add_argument(
+        "--loss_decay_gamma",
+        type=float,
+        default=0.0,
+        help="Per-position loss decay γ for w_k = exp(-(k-1)/γ). 0 selects the paper default for the "
+        "draft block size (γ=7 for block 16, γ=5 for block 10, γ=4 for block 8, else block_size/2).",
+    )
 
     args = parser.parse_args()
     return TrainConfig(**vars(args))
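For reference, a hypothetical invocation exercising the new flags (the script name and the explicit values are placeholders, not taken from this diff):

python train.py \
    --learning_rate 6e-4 \
    --lr_warmup_ratio 0.04 \
    --max_grad_norm 1.0 \
    --loss_decay_gamma 7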
@@ -177,6 +194,14 @@ def main():
     if block_size < 2:
         raise ValueError("`block_size` must be at least 2 for DFlash training.")
 
+    # Eq. 4 in the DFlash paper: w_k = exp(-(k-1)/γ) over predicted positions k = 1..block_size-1.
+    # Defaults from Appendix A.1.
+    if cfg.loss_decay_gamma > 0.0:
+        loss_gamma = float(cfg.loss_decay_gamma)
+    else:
+        loss_gamma = {16: 7.0, 10: 5.0, 8: 4.0}.get(block_size, max(2.0, block_size / 2.0))
+    pos_weights = torch.exp(-torch.arange(block_size - 1, dtype=torch.float32) / loss_gamma)
+
     layer_ids = getattr(draft_model, "target_layer_ids", None)
     if layer_ids is None:
         cfg_draft = getattr(draft_model, "config", None)
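For intuition about the weighting this hunk introduces, a minimal standalone sketch (γ and the block size are the paper defaults quoted above; nothing else comes from the training script):

import torch

block_size, gamma = 16, 7.0  # paper default: γ=7 for a block of 16
# w_k = exp(-(k-1)/γ) for predicted positions k = 1..block_size-1
w = torch.exp(-torch.arange(block_size - 1, dtype=torch.float32) / gamma)
print(round(w[0].item(), 4))   # 1.0    — the first predicted position keeps full weight
print(round(w[-1].item(), 4))  # 0.1353 — position 15 decays to exp(-14/7) = e^-2

The effect is that positions deeper in the draft block, which are inherently harder to predict, contribute progressively less to the gradient.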
@@ -208,10 +233,14 @@ def main():
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
     num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
 
+    if cfg.lr_warmup_ratio > 0.0:
+        num_warmup_steps = int(cfg.lr_warmup_ratio * cfg.max_train_steps)
+    else:
+        num_warmup_steps = cfg.lr_warmup_steps
     lr_scheduler = get_scheduler(
         name=cfg.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=cfg.lr_warmup_steps,
+        num_warmup_steps=num_warmup_steps,
         num_training_steps=cfg.max_train_steps,
     )
 
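A quick worked example of the warmup resolution (the total step count is assumed for illustration; only the 0.04 ratio comes from the diff):

# With the default ratio and, say, 20_000 total optimizer steps:
num_warmup_steps = int(0.04 * 20_000)  # 800 warmup steps
# Passing --lr_warmup_ratio 0 falls back to the absolute --lr_warmup_steps value.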
@@ -220,6 +249,7 @@ def main():
     )
     input_embeddings = get_target_input_embeddings(target_model)
     output_embeddings = get_target_output_embeddings(target_model)
+    pos_weights = pos_weights.to(accelerator.device)
 
     global_step = 0
     draft_model.train()
@@ -279,9 +309,12 @@ def main():
             vocab_size = logits.shape[-1]
             loss = F.cross_entropy(logits.view(-1, vocab_size), block_targets.reshape(-1), reduction="none")
             loss = loss.view(block_targets.shape[0], -1)
-            loss = (loss * block_mask.to(loss.dtype)).sum() / block_mask.sum().clamp_min(1)
+            weights = pos_weights.to(loss.dtype)[None, :].expand_as(loss) * block_mask.to(loss.dtype)
+            loss = (loss * weights).sum() / weights.sum().clamp_min(1e-8)
 
             accelerator.backward(loss)
+            if accelerator.sync_gradients and cfg.max_grad_norm > 0:
+                accelerator.clip_grad_norm_(draft_model.parameters(), cfg.max_grad_norm)
             optimizer.step()
             lr_scheduler.step()
             optimizer.zero_grad(set_to_none=True)
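To check that the new reduction behaves as a weighted mean, a self-contained toy example (shapes and values are illustrative, not from the script):

import torch

batch, positions = 2, 3  # positions corresponds to block_size - 1
per_token_ce = torch.full((batch, positions), 2.0)           # pretend every token's CE is 2.0
pos_weights = torch.exp(-torch.arange(positions, dtype=torch.float32) / 4.0)
block_mask = torch.tensor([[1.0, 1.0, 0.0],                  # one padded position drops out
                           [1.0, 1.0, 1.0]])
weights = pos_weights[None, :].expand_as(per_token_ce) * block_mask
loss = (per_token_ce * weights).sum() / weights.sum().clamp_min(1e-8)
print(loss)  # tensor(2.) — a weighted mean over the unmasked positions

Because every surviving token carries the same loss, the result is exactly 2.0 for any γ, which makes a convenient invariant for unit-testing the reduction. Note the denominator clamps to a small epsilon rather than 1: the weight sum is fractional, so clamping it to 1 would silently shrink the loss whenever little total weight is active.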