From db7b93fcf7faabb698951854360bfdeca4fdbaef Mon Sep 17 00:00:00 2001 From: Taksh Date: Sun, 17 May 2026 16:38:18 +0530 Subject: [PATCH] fix(mamba2): use self.ngroups for RMSNormGated group_size under tensor parallel When process_group is set, the ngroups parameter is sharded across ranks via 'self.ngroups = ngroups // self.world_size' (line 83), and every other call site in this module uses self.ngroups for the per-rank value (the in_proj/conv_dim widths on lines 96 and 104, the mamba_split_conv1d_scan_combined ngroups arg on line 200, the xBC split widths, etc.). The RMSNormGated constructor on line 145 was the lone holdout: passing the global ngroups made the per-group size too small by a factor of world_size, so the gated norm split self.d_ssm into world_size x more groups than the chunk-scan kernel does. Behavior is unchanged for the non-distributed path (world_size == 1, where self.ngroups == ngroups). --- mamba_ssm/modules/mamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d471..cdc047fb0 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -142,7 +142,7 @@ def __init__( if self.rmsnorm: assert RMSNormGated is not None self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate, - group_size=self.d_ssm // ngroups, **factory_kwargs) + group_size=self.d_ssm // self.ngroups, **factory_kwargs) if self.process_group is None: self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)