diff --git a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml
index 056c15294a..5ef1591ecd 100644
--- a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml
+++ b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml
@@ -44,6 +44,7 @@ policy:
     empty_unused_memory_level: 2
     enabled: true
     activation_checkpointing: true
+    moe_grouped_gemm: true
     tensor_model_parallel_size: 8
     expert_model_parallel_size: 32
     pipeline_model_parallel_size: 8
diff --git a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml
index 35903949ad..a3911db41a 100644
--- a/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml
+++ b/examples/configs/recipes/llm/performance/grpo-deepseek-v3-32n8g.yaml
@@ -26,6 +26,7 @@ policy:
     pipeline_model_parallel_size: 16
     expert_model_parallel_size: 16
     activation_checkpointing: true
+    moe_grouped_gemm: true
     num_layers_in_first_pipeline_stage: 3
     num_layers_in_last_pipeline_stage: 2
     apply_rope_fusion: false
diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml
index aecdabba73..1228db60bb 100644
--- a/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml
+++ b/examples/configs/recipes/llm/performance/grpo-qwen3-235b-16n8g.yaml
@@ -29,6 +29,7 @@ policy:
     context_parallel_size: 2
     expert_model_parallel_size: 16
     activation_checkpointing: true
+    moe_grouped_gemm: true
     num_layers_in_first_pipeline_stage: 11
     num_layers_in_last_pipeline_stage: 11
     defer_fp32_logits: true
diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml
index 9292a72439..947b2d1b1c 100644
--- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml
+++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n4g.yaml
@@ -22,6 +22,7 @@ policy:
     pipeline_model_parallel_size: 1
     expert_model_parallel_size: 16
     sequence_parallel: false
+    moe_grouped_gemm: true
     optimizer:
       lr: 3.0e-07
       min_lr: 3.0e-08
diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml
index 3b4f22ffbd..a8e130f853 100644
--- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml
+++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g-40K.yaml
@@ -24,6 +24,7 @@ policy:
     expert_model_parallel_size: 8
     sequence_parallel: true
     context_parallel_size: 8
+    moe_grouped_gemm: true
    optimizer:
       lr: 3.0e-07
       min_lr: 3.0e-08
diff --git a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml
index 21ddcc6bd3..6eda477ba1 100644
--- a/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml
+++ b/examples/configs/recipes/llm/performance/grpo-qwen3-30ba3b-4n8g.yaml
@@ -22,6 +22,7 @@ policy:
     pipeline_model_parallel_size: 1
     expert_model_parallel_size: 8
     sequence_parallel: false
+    moe_grouped_gemm: true
    optimizer:
       lr: 3.0e-07
       min_lr: 3.0e-08
diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index e8bed08fa9..540cfbc58a 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -587,6 +587,9 @@ def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None:
         model_cfg.moe_permute_fusion = config["megatron_cfg"]["moe_permute_fusion"]
 
+    if "moe_grouped_gemm" in config["megatron_cfg"]:
+        model_cfg.moe_grouped_gemm = config["megatron_cfg"]["moe_grouped_gemm"]
+
 
 def _apply_mtp_config(model_cfg: Any, config: PolicyConfig) -> None:
     if "mtp_num_layers" in config["megatron_cfg"]:
diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index 04e1c8cbe9..a865132758 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -237,6 +237,10 @@ class MegatronConfig(TypedDict):
     moe_token_dispatcher_type: str
     # Can be used only with 'alltoall' token dispatcher
     moe_shared_expert_overlap: bool
+    # Enable grouped GEMM for MoE experts via CUTLASS. Significant throughput
+    # gain when multiple experts are assigned per rank (num_local_experts > 1).
+    # Requires TE >= 1.11.0 for FP8 and Ampere (sm_80) or newer.
+    moe_grouped_gemm: NotRequired[bool]
     peft: NotRequired[MegatronPeftConfig | MegatronPeftConfigDisabled]
     optimizer: MegatronOptimizerConfig
     scheduler: MegatronSchedulerConfig
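For runs outside these recipes, the new key can be enabled with a one-line override. A minimal sketch, assuming the recipe nests Megatron settings under policy.megatron_cfg as the config["megatron_cfg"] lookups in setup.py suggest (the enabled key is shown only for context):

    policy:
      megatron_cfg:
        enabled: true
        # turn on grouped GEMM for MoE expert layers
        moe_grouped_gemm: true

Because the field is declared NotRequired and _apply_moe_config only forwards it when present, configs that omit the key keep Megatron's default behavior.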