@@ -721,5 +721,113 @@ def test_bytes_from_pytree_empty_dict(self):
721721 self .assertEqual (max_utils .calculate_bytes_from_pytree ({}), 0 )
722722
723723
class TestLearningRateSchedules(unittest.TestCase):
  """Test suite for learning rate schedule functions.

  Every test builds a config from ``configs/base.yml`` through
  ``pyconfig.initialize``, obtains the schedule callable from
  ``maxtext_utils.create_learning_rate_schedule`` and samples it at the
  boundaries of each phase (warmup, optional stable plateau, decay, and the
  trailing zero phase past the end of the schedule).
  """

  # Hyperparameters shared by all schedules under test.
  LEARNING_RATE = 1e-3
  LEARNING_RATE_SCHEDULE_STEPS = 1000
  # Larger than LEARNING_RATE_SCHEDULE_STEPS so that the steps after the
  # schedule has finished (the "zero phase") can be sampled.
  STEPS = 1200
  WARMUP_STEPS_FRACTION = 0.1
  LEARNING_RATE_FINAL_FRACTION = 0.1

  def _initialize_config(self, lr_schedule_type, warmup_steps_fraction=WARMUP_STEPS_FRACTION, **extra):
    """Builds a config from base.yml using the shared hyperparameters.

    Args:
      lr_schedule_type: Value forwarded as the `lr_schedule_type` setting.
      warmup_steps_fraction: Value forwarded as `warmup_steps_fraction`;
        defaults to the class-wide WARMUP_STEPS_FRACTION.
      **extra: Additional settings forwarded verbatim to `pyconfig.initialize`
        (for example the `wsd_*` options).

    Returns:
      Whatever `pyconfig.initialize` returns for these settings.

    Raises:
      Any exception raised by `pyconfig.initialize`; it is propagated
      unchanged so tests can assert on it.
    """
    return pyconfig.initialize(
        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
        enable_checkpointing=False,
        learning_rate=self.LEARNING_RATE,
        learning_rate_schedule_steps=self.LEARNING_RATE_SCHEDULE_STEPS,
        steps=self.STEPS,
        warmup_steps_fraction=warmup_steps_fraction,
        lr_schedule_type=lr_schedule_type,
        learning_rate_final_fraction=self.LEARNING_RATE_FINAL_FRACTION,
        **extra,
    )

  def test_cosine_schedule(self):
    """Tests cosine learning rate schedule."""
    warmup_steps = int(self.LEARNING_RATE_SCHEDULE_STEPS * self.WARMUP_STEPS_FRACTION)
    expected_final = self.LEARNING_RATE * self.LEARNING_RATE_FINAL_FRACTION

    config = self._initialize_config("cosine")
    schedule_fn = maxtext_utils.create_learning_rate_schedule(config)

    # Warmup phase: 0 -> peak
    self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
    self.assertAlmostEqual(float(schedule_fn(warmup_steps)), self.LEARNING_RATE, places=6)

    # Cosine decay phase: the last scheduled step must sit at the final value.
    lr_end = schedule_fn(self.LEARNING_RATE_SCHEDULE_STEPS - 1)
    self.assertLess(float(lr_end), self.LEARNING_RATE)
    self.assertAlmostEqual(float(lr_end), expected_final, places=6)

    # Zero phase: steps past the end of the schedule
    self.assertAlmostEqual(float(schedule_fn(self.STEPS - 1)), 0.0, places=6)

  def test_wsd_schedule(self):
    """Tests WSD learning rate schedule with both linear and cosine decay styles."""
    wsd_decay_steps_fraction = 0.1

    warmup_steps = int(self.LEARNING_RATE_SCHEDULE_STEPS * self.WARMUP_STEPS_FRACTION)
    decay_steps = int(self.LEARNING_RATE_SCHEDULE_STEPS * wsd_decay_steps_fraction)
    stable_steps = self.LEARNING_RATE_SCHEDULE_STEPS - warmup_steps - decay_steps
    decay_start = warmup_steps + stable_steps
    expected_final = self.LEARNING_RATE * self.LEARNING_RATE_FINAL_FRACTION

    # Test both decay styles: linear and cosine. subTest labels each failure
    # with the decay style and lets the other style still run.
    for decay_style in ("linear", "cosine"):
      with self.subTest(wsd_decay_style=decay_style):
        config = self._initialize_config(
            "wsd",
            wsd_decay_steps_fraction=wsd_decay_steps_fraction,
            wsd_decay_style=decay_style,
        )
        schedule_fn = maxtext_utils.create_learning_rate_schedule(config)

        # Warmup phase: 0 -> peak
        self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
        self.assertAlmostEqual(float(schedule_fn(warmup_steps)), self.LEARNING_RATE, places=6)

        # Stable phase: constant at peak (start, middle and last stable step)
        for stable_step in (warmup_steps + 10, warmup_steps + stable_steps // 2, decay_start - 1):
          self.assertAlmostEqual(float(schedule_fn(stable_step)), self.LEARNING_RATE, places=6)

        # Decay phase: peak -> final, strictly between the two at the midpoint
        lr_mid_decay = schedule_fn(decay_start + decay_steps // 2)
        self.assertLess(float(lr_mid_decay), self.LEARNING_RATE)
        self.assertGreater(float(lr_mid_decay), expected_final)

        # End of decay phase: should reach expected_final
        lr_end = schedule_fn(self.LEARNING_RATE_SCHEDULE_STEPS - 1)
        self.assertAlmostEqual(float(lr_end), expected_final, places=6)

        # Zero phase: steps past the end of the schedule
        self.assertAlmostEqual(float(schedule_fn(self.STEPS - 1)), 0.0, places=6)

    # Test invalid fractions - should raise during config initialization
    with self.assertRaises(ValueError) as cm:
      self._initialize_config(
          "wsd",
          warmup_steps_fraction=0.6,
          wsd_decay_steps_fraction=0.5,  # Sum > 1.0
      )
    # The error message is expected to name both offending settings.
    error_message = str(cm.exception)
    self.assertIn("warmup_steps_fraction", error_message)
    self.assertIn("wsd_decay_steps_fraction", error_message)
830+
831+
724832if __name__ == "__main__" :
725833 unittest .main ()
0 commit comments