@@ -29,18 +29,14 @@ def get_scheduler(
2929
3030 if name == SchedulerType .CONSTANT :
3131
32- def lr_lambda (
33- current_step : int , num_warmup_steps : Optional [int ] = None
34- ) -> float :
32+ def lr_lambda (current_step : int , num_warmup_steps : Optional [int ] = None ) -> float :
3533 current_step = current_step + last_global_step
3634 if exists (num_warmup_steps ) and current_step < num_warmup_steps :
3735 return current_step / max (1.0 , num_warmup_steps )
3836 return 1.0
3937 elif name == SchedulerType .COSINE :
4038
41- def lr_lambda (
42- current_step : int , num_warmup_steps : Optional [int ] = None
43- ) -> float :
39+ def lr_lambda (current_step : int , num_warmup_steps : Optional [int ] = None ) -> float :
4440 current_step = current_step + last_global_step
4541 if exists (num_warmup_steps ) and current_step < num_warmup_steps :
4642 return current_step / max (1.0 , num_warmup_steps )
@@ -50,12 +46,10 @@ def lr_lambda(
5046 progress = (
5147 (current_step - num_warmup_steps ) / (num_training_steps - num_warmup_steps )
5248 )
53- return 0.5 * (1.0 - eta_min ) * (1.0 + math .cos (math .pi * progress )) + eta_min
49+ return 0.5 * (1.0 - eta_min ) * (1.0 + math .cos (math .pi * progress )) + eta_min
5450 elif name == SchedulerType .LINEAR :
5551
56- def lr_lambda (
57- current_step : int , num_warmup_steps : Optional [int ] = None
58- ) -> float :
52+ def lr_lambda (current_step : int , num_warmup_steps : Optional [int ] = None ) -> float :
5953 current_step = current_step + last_global_step
6054 if exists (num_warmup_steps ) and current_step < num_warmup_steps :
6155 return current_step / max (1.0 , num_warmup_steps )
0 commit comments