@@ -794,6 +794,99 @@ def test_fp8_param_warning(self):
794794 with pytest .warns (UserWarning , match = "fp8_param=True sometimes causes NaN" ):
795795 _apply_performance_config (model_cfg , config )
796796
def test_fine_grained_activation_offloading_enabled(self):
    """Happy path: the flag is on and ``offload_modules`` is a non-empty list.

    After ``_apply_performance_config`` runs, both values are expected to be
    present on the model config object.
    """
    from nemo_rl.models.megatron.setup import _apply_performance_config

    expected_modules = ["mlp", "moe_act"]

    # Baseline performance switches, all turned off, followed by the
    # offloading settings this test is about.
    megatron_cfg = dict.fromkeys(
        (
            "activation_checkpointing",
            "apply_rope_fusion",
            "bias_activation_fusion",
            "gradient_accumulation_fusion",
        ),
        False,
    )
    megatron_cfg["fine_grained_activation_offloading"] = True
    megatron_cfg["offload_modules"] = expected_modules

    cfg_mock = MagicMock()
    cfg_mock.gated_linear_unit = True

    _apply_performance_config(cfg_mock, {"megatron_cfg": megatron_cfg})

    # An attribute that was never assigned on a MagicMock would be another
    # MagicMock, so these checks only pass if real values were written.
    assert cfg_mock.fine_grained_activation_offloading is True
    assert cfg_mock.offload_modules == expected_modules
819+
def test_fine_grained_activation_offloading_disabled_skips(self):
    """With the flag absent from the config, no offload attributes get set."""
    from nemo_rl.models.megatron.setup import _apply_performance_config

    # ``spec`` limits which attributes exist by default, so ``hasattr`` is
    # False for anything the function under test did not assign.
    cfg_mock = MagicMock(spec=["gated_linear_unit"])
    cfg_mock.gated_linear_unit = True

    megatron_cfg = dict.fromkeys(
        (
            "activation_checkpointing",
            "apply_rope_fusion",
            "bias_activation_fusion",
            "gradient_accumulation_fusion",
        ),
        False,
    )

    _apply_performance_config(cfg_mock, {"megatron_cfg": megatron_cfg})

    for untouched_attr in ("fine_grained_activation_offloading", "offload_modules"):
        assert not hasattr(cfg_mock, untouched_attr)
839+
@pytest.mark.parametrize(
    "offload_modules",
    [
        pytest.param([], id="empty_list"),
        pytest.param(None, id="none"),
        pytest.param("mlp", id="string"),
        pytest.param(42, id="int"),
    ],
)
def test_fine_grained_activation_offloading_invalid_modules_raises(
    self, offload_modules
):
    """Enabling the feature with anything but a non-empty list must raise.

    Each parametrized value (empty list, ``None``, a bare string, an int) is
    expected to be rejected with a ``ValueError``.
    """
    from nemo_rl.models.megatron.setup import _apply_performance_config

    cfg_mock = MagicMock()
    cfg_mock.gated_linear_unit = True

    megatron_cfg = dict.fromkeys(
        (
            "activation_checkpointing",
            "apply_rope_fusion",
            "bias_activation_fusion",
            "gradient_accumulation_fusion",
        ),
        False,
    )
    megatron_cfg["fine_grained_activation_offloading"] = True
    megatron_cfg["offload_modules"] = offload_modules

    expected_message = "offload_modules must be a non-empty list"
    with pytest.raises(ValueError, match=expected_message):
        _apply_performance_config(cfg_mock, {"megatron_cfg": megatron_cfg})
868+
def test_fine_grained_activation_offloading_missing_modules_raises(self):
    """Enabling the feature without an ``offload_modules`` key must raise.

    The config below turns the flag on but omits the module list entirely;
    the call is expected to fail with the same ``ValueError`` as an empty list.
    """
    from nemo_rl.models.megatron.setup import _apply_performance_config

    cfg_mock = MagicMock()
    cfg_mock.gated_linear_unit = True

    megatron_cfg = dict.fromkeys(
        (
            "activation_checkpointing",
            "apply_rope_fusion",
            "bias_activation_fusion",
            "gradient_accumulation_fusion",
        ),
        False,
    )
    # Deliberately no "offload_modules" entry.
    megatron_cfg["fine_grained_activation_offloading"] = True

    expected_message = "offload_modules must be a non-empty list"
    with pytest.raises(ValueError, match=expected_message):
        _apply_performance_config(cfg_mock, {"megatron_cfg": megatron_cfg})
889+
797890
798891@pytest .mark .mcore
799892class TestValidateOptimizerConfig :
0 commit comments