@@ -748,6 +748,11 @@ def test_cosine_schedule(self):
748748 # Warmup phase: 0 -> peak
749749 self .assertAlmostEqual (float (schedule_fn (0 )), 0.0 , places = 6 )
750750 self .assertAlmostEqual (float (schedule_fn (warmup_steps )), learning_rate , places = 6 )
751+ # Ensure delta is constant
752+ expected_slope = learning_rate / warmup_steps
753+ for i in range (1 , warmup_steps + 1 ):
754+ current_lr = float (schedule_fn (i ))
755+ self .assertAlmostEqual (current_lr - float (schedule_fn (i - 1 )), expected_slope , places = 6 )
751756
752757 # Cosine decay phase
753758 lr_end = schedule_fn (learning_rate_schedule_steps - 1 )
@@ -791,6 +796,11 @@ def test_wsd_schedule(self):
791796 # Warmup phase: 0 -> peak
792797 self .assertAlmostEqual (float (schedule_fn (0 )), 0.0 , places = 6 )
793798 self .assertAlmostEqual (float (schedule_fn (warmup_steps )), learning_rate , places = 6 )
799+ # Ensure delta is constant
800+ expected_slope = learning_rate / warmup_steps
801+ for i in range (1 , warmup_steps + 1 ):
802+ current_lr = float (schedule_fn (i ))
803+ self .assertAlmostEqual (current_lr - float (schedule_fn (i - 1 )), expected_slope , places = 6 )
794804
795805 # Stable phase: constant at peak
796806 self .assertAlmostEqual (float (schedule_fn (warmup_steps + 10 )), learning_rate , places = 6 )
0 commit comments