From 1713776b75bfe277a4cdb13222c7ad4d439399c9 Mon Sep 17 00:00:00 2001 From: sna Date: Thu, 16 Apr 2026 21:08:33 -0700 Subject: [PATCH 1/6] Add moe_grouped_gemm support and enable on MoE performance recipes Wires moe_grouped_gemm (CUTLASS grouped GEMM for MoE experts) through the MegatronConfig TypedDict and _apply_moe_config(). Enables it in every root MoE performance recipe (Qwen3-30B-A3B 4n4g/4n8g/4n8g-40K, Qwen3-235B 16n8g, DeepSeek-V3 32n8g, DAPO DeepSeek-V3 64n8g); child recipes inherit. Signed-off-by: sna --- .../recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml | 1 + .../recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml | 1 + .../configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml | 1 + .../recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml | 1 + .../recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml | 1 + .../recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml | 1 + nemo_rl/models/megatron/setup.py | 2 ++ nemo_rl/models/policy/__init__.py | 3 +++ 8 files changed, 11 insertions(+) diff --git a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml index 056c15294a..5ef1591ecd 100644 --- a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml +++ b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml @@ -44,6 +44,7 @@ policy: empty_unused_memory_level: 2 enabled: true activation_checkpointing: true + moe_grouped_gemm: true tensor_model_parallel_size: 8 expert_model_parallel_size: 32 pipeline_model_parallel_size: 8 diff --git a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml index c8a9487175..8405629c66 100644 --- a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml @@ -24,6 +24,7 @@ policy: pipeline_model_parallel_size: 16 
expert_model_parallel_size: 16 activation_checkpointing: true + moe_grouped_gemm: true num_layers_in_first_pipeline_stage: 3 num_layers_in_last_pipeline_stage: 2 apply_rope_fusion: false diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml index aecdabba73..1228db60bb 100644 --- a/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml @@ -29,6 +29,7 @@ policy: context_parallel_size: 2 expert_model_parallel_size: 16 activation_checkpointing: true + moe_grouped_gemm: true num_layers_in_first_pipeline_stage: 11 num_layers_in_last_pipeline_stage: 11 defer_fp32_logits: true diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml index 21b9746f4b..41fd83ec21 100644 --- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml @@ -21,6 +21,7 @@ policy: pipeline_model_parallel_size: 1 expert_model_parallel_size: 16 sequence_parallel: false + moe_grouped_gemm: true optimizer: lr: 3.0e-07 min_lr: 3.0e-08 diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml index 3b4f22ffbd..a8e130f853 100644 --- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml +++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml @@ -24,6 +24,7 @@ policy: expert_model_parallel_size: 8 sequence_parallel: true context_parallel_size: 8 + moe_grouped_gemm: true optimizer: lr: 3.0e-07 min_lr: 3.0e-08 diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml index 795764d3ee..7a03394402 100644 --- 
a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml +++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml @@ -21,6 +21,7 @@ policy: pipeline_model_parallel_size: 1 expert_model_parallel_size: 8 sequence_parallel: false + moe_grouped_gemm: true optimizer: lr: 3.0e-07 min_lr: 3.0e-08 diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index fc5c6c44fa..140378adde 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -432,6 +432,8 @@ def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None: model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"] + model_cfg.moe_grouped_gemm = config["megatron_cfg"].get("moe_grouped_gemm", False) + def _apply_mtp_config(model_cfg: Any, config: PolicyConfig) -> None: if "mtp_num_layers" in config["megatron_cfg"]: diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index ec4c9e66bb..b22f4015b2 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -236,6 +236,9 @@ class MegatronConfig(TypedDict): moe_token_dispatcher_type: str # Can be used only with 'alltoall' token dispatcher moe_shared_expert_overlap: bool + # Enable grouped GEMM for MoE experts via CUTLASS. Significant throughput + # gain for large EP configs. Requires TE >= 1.11.0 for FP8. 
+ moe_grouped_gemm: NotRequired[bool] peft: NotRequired[MegatronPeftConfig | MegatronPeftConfigDisabled] optimizer: MegatronOptimizerConfig scheduler: MegatronSchedulerConfig From bd94eaff015f8800406bd1e3f8c7ecbcf1f68b80 Mon Sep 17 00:00:00 2001 From: Seonjin Date: Mon, 20 Apr 2026 00:15:06 -0700 Subject: [PATCH 2/6] Update nemo_rl/models/policy/__init__.py Co-authored-by: Terry Kong Signed-off-by: Seonjin --- nemo_rl/models/policy/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index b22f4015b2..63af536819 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -237,7 +237,8 @@ class MegatronConfig(TypedDict): # Can be used only with 'alltoall' token dispatcher moe_shared_expert_overlap: bool # Enable grouped GEMM for MoE experts via CUTLASS. Significant throughput - # gain for large EP configs. Requires TE >= 1.11.0 for FP8. + # gain when multiple experts are assigned per rank (num_local_experts > 1). + # Requires Ampere (sm_80) or newer; FP8 additionally requires TE >= 1.11.0. 
moe_grouped_gemm: NotRequired[bool] peft: NotRequired[MegatronPeftConfig | MegatronPeftConfigDisabled] optimizer: MegatronOptimizerConfig From 6cd727863a6d15bf7741308dcf72e0850966ff0e Mon Sep 17 00:00:00 2001 From: Seonjin Date: Mon, 20 Apr 2026 00:15:18 -0700 Subject: [PATCH 3/6] Update nemo_rl/models/megatron/setup.py Co-authored-by: Terry Kong Signed-off-by: Seonjin --- nemo_rl/models/megatron/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 140378adde..4c99bacb0b 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -432,7 +432,8 @@ def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None: model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"] - model_cfg.moe_grouped_gemm = config["megatron_cfg"].get("moe_grouped_gemm", False) + if "moe_grouped_gemm" in config["megatron_cfg"]: + model_cfg.moe_grouped_gemm = config["megatron_cfg"]["moe_grouped_gemm"] def _apply_mtp_config(model_cfg: Any, config: PolicyConfig) -> None: From 65e5c918a598e05dd6b009767e391c02d13e0fde Mon Sep 17 00:00:00 2001 From: sna Date: Fri, 15 May 2026 11:52:42 -0700 Subject: [PATCH 4/6] fix: pin NeMo Gym docs URL to v0.2.1 (latest 404) Signed-off-by: sna --- docs/design-docs/nemo-gym-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/design-docs/nemo-gym-integration.md b/docs/design-docs/nemo-gym-integration.md index 33e324547b..0263d36fef 100644 --- a/docs/design-docs/nemo-gym-integration.md +++ b/docs/design-docs/nemo-gym-integration.md @@ -181,7 +181,7 @@ sequenceDiagram GRPO->>Policy: Compute loss and train ``` -> **NeMo Gym server types** (see [Core Components](https://docs.nvidia.com/nemo/gym/latest/about/concepts/core-components.html)): +> **NeMo Gym server types** (see [Core Components](https://docs.nvidia.com/nemo/gym/v0.2.1/about/concepts/core-components/)): > - **Agent Server**: 
Orchestrates the rollout loop > - **Model Server**: HTTP proxy to vLLM; translates Responses API ↔ Chat Completions > - **Resource Server**: Provides tools and rewards From 26639a25716f49eb57e32dde14a1b84473a1d594 Mon Sep 17 00:00:00 2001 From: sna Date: Sat, 16 May 2026 15:22:51 -0700 Subject: [PATCH 5/6] test: add unit tests for moe_grouped_gemm config branch Adds three cases covering _apply_moe_config moe_grouped_gemm: - True passes through - False passes through - Absent key leaves attr untouched Lifts patch coverage above codecov target. Signed-off-by: sna --- .../models/megatron/test_megatron_setup.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/unit/models/megatron/test_megatron_setup.py b/tests/unit/models/megatron/test_megatron_setup.py index 1eaa3a1247..6a9f8cdda0 100644 --- a/tests/unit/models/megatron/test_megatron_setup.py +++ b/tests/unit/models/megatron/test_megatron_setup.py @@ -619,6 +619,29 @@ def test_moe_configuration(self): assert model_cfg.moe_token_dispatcher_type == "alltoall" assert model_cfg.moe_shared_expert_overlap is True + @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) + def test_moe_grouped_gemm_explicit(self, moe_grouped_gemm): + """Test that moe_grouped_gemm is applied when present in config.""" + from nemo_rl.models.megatron.setup import _apply_moe_config + + model_cfg = MagicMock() + config = {"megatron_cfg": {"moe_grouped_gemm": moe_grouped_gemm}} + + _apply_moe_config(model_cfg, config) + + assert model_cfg.moe_grouped_gemm is moe_grouped_gemm + + def test_moe_grouped_gemm_absent_keeps_default(self): + """Test that moe_grouped_gemm attribute is not touched when key absent.""" + from nemo_rl.models.megatron.setup import _apply_moe_config + + model_cfg = MagicMock(spec=[]) + config = {"megatron_cfg": {}} + + _apply_moe_config(model_cfg, config) + + assert not hasattr(model_cfg, "moe_grouped_gemm") + @pytest.mark.mcore class TestApplyPrecisionConfig: From 
e71839a5aa345ce2dea810c6ffed969bfb8657be Mon Sep 17 00:00:00 2001 From: sna Date: Sat, 16 May 2026 16:28:57 -0700 Subject: [PATCH 6/6] fix(test): supply base moe config keys before exercising moe_grouped_gemm _apply_moe_config requires expert_tensor_parallel_size and other keys via direct dict indexing. Minimal config triggered KeyError before reaching the moe_grouped_gemm branch. Base the new tests on the same megatron_cfg shape used by test_moe_configuration. Signed-off-by: sna --- .../models/megatron/test_megatron_setup.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/tests/unit/models/megatron/test_megatron_setup.py b/tests/unit/models/megatron/test_megatron_setup.py index 6a9f8cdda0..f515e43c75 100644 --- a/tests/unit/models/megatron/test_megatron_setup.py +++ b/tests/unit/models/megatron/test_megatron_setup.py @@ -619,24 +619,54 @@ def test_moe_configuration(self): assert model_cfg.moe_token_dispatcher_type == "alltoall" assert model_cfg.moe_shared_expert_overlap is True + @staticmethod + def _base_moe_megatron_cfg() -> dict: + return { + "expert_tensor_parallel_size": 2, + "expert_model_parallel_size": 4, + "moe_router_dtype": "float32", + "moe_router_load_balancing_type": "none", + "moe_router_bias_update_rate": 0.0, + "moe_permute_fusion": True, + "moe_enable_deepep": False, + "moe_token_dispatcher_type": "alltoall", + "moe_shared_expert_overlap": True, + } + @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) def test_moe_grouped_gemm_explicit(self, moe_grouped_gemm): - """Test that moe_grouped_gemm is applied when present in config.""" + """moe_grouped_gemm is applied when present in config.""" from nemo_rl.models.megatron.setup import _apply_moe_config model_cfg = MagicMock() - config = {"megatron_cfg": {"moe_grouped_gemm": moe_grouped_gemm}} + megatron_cfg = self._base_moe_megatron_cfg() + megatron_cfg["moe_grouped_gemm"] = moe_grouped_gemm + config = {"megatron_cfg": megatron_cfg} 
_apply_moe_config(model_cfg, config) assert model_cfg.moe_grouped_gemm is moe_grouped_gemm def test_moe_grouped_gemm_absent_keeps_default(self): - """Test that moe_grouped_gemm attribute is not touched when key absent.""" + """Absent key leaves the attr unset on the model cfg.""" from nemo_rl.models.megatron.setup import _apply_moe_config - model_cfg = MagicMock(spec=[]) - config = {"megatron_cfg": {}} + # spec omits moe_grouped_gemm, so hasattr() on the mock is False unless + # _apply_moe_config assigns it (spec restricts reads, not writes). + model_cfg = MagicMock( + spec=[ + "expert_tensor_parallel_size", + "expert_model_parallel_size", + "moe_router_dtype", + "moe_router_load_balancing_type", + "moe_router_bias_update_rate", + "moe_permute_fusion", + "moe_enable_deepep", + "moe_token_dispatcher_type", + "moe_shared_expert_overlap", + ] + ) + config = {"megatron_cfg": self._base_moe_megatron_cfg()} _apply_moe_config(model_cfg, config)