@@ -172,6 +172,15 @@ def main(**kwargs):
172172 elif cfg .training_stage == "constant" :
173173 warmup_interval = 2000
174174 schedule = lambda x : (min (x , warmup_interval ) / warmup_interval )
175+ elif cfg .training_stage == "linear_to_constant" :
176+ linear_steps = 25000
177+ start_lr = 2e-4
178+ end_lr = 2e-4
179+ schedule = lambda x : (start_lr + (end_lr - start_lr ) * min (x - start_step , linear_steps ) / linear_steps ) / cfg .learning_rate
180+ elif cfg .training_stage == "annealing_with_specified_decay_steps" :
181+ warmup_interval = 2000
182+ total_decay_steps = 25000
183+ schedule = lambda x : (x - start_step ) / warmup_interval if x - start_step < warmup_interval else max (0.0 , 1 - (x - start_step - warmup_interval ) / total_decay_steps )
175184 else :
176185 schedule = lambda x : 1.0 + (0.75 - 1.0 ) * (x / 32000 ) if x <= 32000 else 0.75
177186
You can’t perform that action at this time.
0 commit comments