Skip to content

Commit 72e0297

Browse files
committed
New LR schedules
1 parent 7d02d2c commit 72e0297

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

main_training_mamba.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,15 @@ def main(**kwargs):
172172
elif cfg.training_stage == "constant":
173173
warmup_interval = 2000
174174
schedule = lambda x: (min(x, warmup_interval) / warmup_interval)
175+
elif cfg.training_stage == "linear_to_constant":
176+
linear_steps = 25000
177+
start_lr = 2e-4
178+
end_lr = 2e-4
179+
schedule = lambda x: (start_lr + (end_lr - start_lr) * min(x - start_step, linear_steps) / linear_steps) / cfg.learning_rate
180+
elif cfg.training_stage == "annealing_with_specified_decay_steps":
181+
warmup_interval = 2000
182+
total_decay_steps = 25000
183+
schedule = lambda x: (x - start_step) / warmup_interval if x - start_step < warmup_interval else max(0.0, 1 - (x - start_step - warmup_interval) / total_decay_steps)
175184
else:
176185
schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75
177186

0 commit comments

Comments (0)