Skip to content

Commit 5225217

Browse files
committed
test: add unit tests for fine_grained_activation_offloading branch
Covers _apply_performance_config offload-modules dispatch: - happy path: True + non-empty list sets both attrs - disabled: defaults skip the branch (no attrs touched) - invalid offload_modules ([], None, str, int) all raise ValueError - missing offload_modules key raises ValueError Lifts patch coverage above codecov target. Signed-off-by: sna <sna@nvidia.com>
1 parent 2b33ad3 commit 5225217

1 file changed

Lines changed: 93 additions & 0 deletions

File tree

tests/unit/models/megatron/test_megatron_setup.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,99 @@ def test_fp8_param_warning(self):
794794
with pytest.warns(UserWarning, match="fp8_param=True sometimes causes NaN"):
795795
_apply_performance_config(model_cfg, config)
796796

797+
def test_fine_grained_activation_offloading_enabled(self):
798+
"""Test happy path: enabled with non-empty offload_modules list."""
799+
from nemo_rl.models.megatron.setup import _apply_performance_config
800+
801+
model_cfg = MagicMock()
802+
model_cfg.gated_linear_unit = True
803+
offload_modules = ["mlp", "moe_act"]
804+
config = {
805+
"megatron_cfg": {
806+
"activation_checkpointing": False,
807+
"apply_rope_fusion": False,
808+
"bias_activation_fusion": False,
809+
"gradient_accumulation_fusion": False,
810+
"fine_grained_activation_offloading": True,
811+
"offload_modules": offload_modules,
812+
}
813+
}
814+
815+
_apply_performance_config(model_cfg, config)
816+
817+
assert model_cfg.fine_grained_activation_offloading is True
818+
assert model_cfg.offload_modules == offload_modules
819+
820+
def test_fine_grained_activation_offloading_disabled_skips(self):
821+
"""When flag is False (default), no offload attrs should be set."""
822+
from nemo_rl.models.megatron.setup import _apply_performance_config
823+
824+
model_cfg = MagicMock(spec=["gated_linear_unit"])
825+
model_cfg.gated_linear_unit = True
826+
config = {
827+
"megatron_cfg": {
828+
"activation_checkpointing": False,
829+
"apply_rope_fusion": False,
830+
"bias_activation_fusion": False,
831+
"gradient_accumulation_fusion": False,
832+
}
833+
}
834+
835+
_apply_performance_config(model_cfg, config)
836+
837+
assert not hasattr(model_cfg, "fine_grained_activation_offloading")
838+
assert not hasattr(model_cfg, "offload_modules")
839+
840+
@pytest.mark.parametrize(
841+
"offload_modules",
842+
[[], None, "mlp", 42],
843+
ids=["empty_list", "none", "string", "int"],
844+
)
845+
def test_fine_grained_activation_offloading_invalid_modules_raises(
846+
self, offload_modules
847+
):
848+
"""offload_modules must be a non-empty list when feature is enabled."""
849+
from nemo_rl.models.megatron.setup import _apply_performance_config
850+
851+
model_cfg = MagicMock()
852+
model_cfg.gated_linear_unit = True
853+
config = {
854+
"megatron_cfg": {
855+
"activation_checkpointing": False,
856+
"apply_rope_fusion": False,
857+
"bias_activation_fusion": False,
858+
"gradient_accumulation_fusion": False,
859+
"fine_grained_activation_offloading": True,
860+
"offload_modules": offload_modules,
861+
}
862+
}
863+
864+
with pytest.raises(
865+
ValueError, match="offload_modules must be a non-empty list"
866+
):
867+
_apply_performance_config(model_cfg, config)
868+
869+
def test_fine_grained_activation_offloading_missing_modules_raises(self):
870+
"""When enabled but offload_modules key is absent, defaults to [] → raises."""
871+
from nemo_rl.models.megatron.setup import _apply_performance_config
872+
873+
model_cfg = MagicMock()
874+
model_cfg.gated_linear_unit = True
875+
config = {
876+
"megatron_cfg": {
877+
"activation_checkpointing": False,
878+
"apply_rope_fusion": False,
879+
"bias_activation_fusion": False,
880+
"gradient_accumulation_fusion": False,
881+
"fine_grained_activation_offloading": True,
882+
}
883+
}
884+
885+
with pytest.raises(
886+
ValueError, match="offload_modules must be a non-empty list"
887+
):
888+
_apply_performance_config(model_cfg, config)
889+
797890

798891
@pytest.mark.mcore
799892
class TestValidateOptimizerConfig:

0 commit comments

Comments
 (0)