fix(mamba2): use self.ngroups for RMSNormGated group_size under tensor parallel #951

Open
Chessing234 wants to merge 1 commit into state-spaces:main from Chessing234:fix/mamba2-norm-group-size-tensor-parallel

@Chessing234
Contributor

Bug

In Mamba2.__init__ (mamba_ssm/modules/mamba2.py line 145), the gated RMSNorm is constructed with group_size=self.d_ssm // ngroups:

if self.rmsnorm:
    assert RMSNormGated is not None
    self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
                             group_size=self.d_ssm // ngroups, **factory_kwargs)

Here ngroups is the constructor argument (the global number of groups), while self.d_ssm has already been sharded across ranks. With world_size > 1 the per-group size is too small by a factor of world_size, so the gated norm splits the local d_ssm into world_size× more groups than the chunk-scan kernel actually uses for B/C.
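To make the effect of a wrong group_size concrete, here is a minimal pure-Python sketch of group-wise RMS normalization. It is an illustration only: grouped_rms_norm is a hypothetical helper, not the RMSNormGated implementation, and it omits the gate and the learned weight. It assumes each group is a contiguous block of group_size features normalized by its own root mean square.

```python
import math


def grouped_rms_norm(x, group_size, eps=1e-5):
    """Normalize each contiguous block of `group_size` features by its own RMS.

    Hypothetical reference helper for illustration; no gating, no weight.
    """
    assert len(x) % group_size == 0, "feature count must be divisible by group_size"
    out = []
    for start in range(0, len(x), group_size):
        group = x[start:start + group_size]
        # RMS of this group only; features in other groups do not contribute.
        rms = math.sqrt(sum(v * v for v in group) / len(group) + eps)
        out.extend(v / rms for v in group)
    return out


x = [1.0, 1.0, 3.0, 3.0]
one_group = grouped_rms_norm(x, group_size=4)   # statistics shared over all 4 features
two_groups = grouped_rms_norm(x, group_size=2)  # statistics computed per pair
print(one_group)
print(two_groups)
```

The same input produces different outputs under the two groupings: with one group the small and large features keep their relative scale, while with two groups every feature is rescaled to roughly 1. A group_size that is too small therefore changes the layer's output rather than just its layout.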

Root cause

Line 83 shards groups across ranks: self.ngroups = ngroups // self.world_size. Every other site in this module uses self.ngroups for the per-rank value: d_in_proj and conv_dim (lines 96, 104), the xBC split widths (lines 213, 243), and the ngroups argument to mamba_chunk_scan_combined / mamba_split_conv1d_scan_combined (lines 200, 248). The norm constructor on line 145 was the only site still referring to the global ngroups.

RMSNormGated interprets group_size as the number of features per normalization group (see mamba_ssm/ops/triton/layernorm_gated.py LayerNorm/RMSNorm: ngroups_in_norm = hidden_size // group_size). With self.d_ssm (local) and ngroups (global), the resulting group_size = local_d_ssm / global_ngroups is world_size× smaller than the intended local_d_ssm / self.ngroups. Consequently ngroups_in_norm comes out equal to the global ngroups, which is world_size× larger than self.ngroups, so the norm and the kernel disagree on grouping.
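The arithmetic can be checked with a few lines of plain Python. This is a sketch under stated assumptions: norm_grouping is a hypothetical helper, the sizes (d_ssm of 4096, 8 groups, 4 ranks) are made-up example values, and both d_ssm and ngroups are assumed to divide evenly by world_size.

```python
def norm_grouping(d_ssm_global, ngroups_global, world_size):
    """Return (buggy, fixed), each a (group_size, ngroups_in_norm) pair for one rank."""
    # Per-rank values, mirroring the sharding described above.
    d_ssm_local = d_ssm_global // world_size       # plays the role of self.d_ssm
    ngroups_local = ngroups_global // world_size   # plays the role of self.ngroups

    # Before the fix: local width divided by the GLOBAL group count.
    buggy_group_size = d_ssm_local // ngroups_global
    # After the fix: local width divided by the PER-RANK group count.
    fixed_group_size = d_ssm_local // ngroups_local

    # The norm derives its group count as hidden_size // group_size.
    buggy = (buggy_group_size, d_ssm_local // buggy_group_size)
    fixed = (fixed_group_size, d_ssm_local // fixed_group_size)
    return buggy, fixed


buggy, fixed = norm_grouping(d_ssm_global=4096, ngroups_global=8, world_size=4)
print(buggy)  # (128, 8)
print(fixed)  # (512, 2)
```

With these example values the pre-fix norm uses 8 groups of 128 features on each rank, while the kernels on that rank work with 2 groups, matching the fixed result of 2 groups of 512. Calling norm_grouping with world_size=1 returns identical pairs for both variants.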

Fix

Use self.ngroups so the per-rank group size matches the per-rank ngroups the kernels consume:

group_size=self.d_ssm // self.ngroups,

For the non-distributed path (process_group is None, world_size == 1), self.ngroups == ngroups, so behavior is unchanged.

…r parallel

When process_group is set, the ngroups parameter is sharded across ranks
via 'self.ngroups = ngroups // self.world_size' (line 83), and every other
call site in this module uses self.ngroups for the per-rank value (the
in_proj/conv_dim widths on lines 96 and 104, the mamba_split_conv1d_scan_combined
ngroups arg on line 200, the xBC split widths, etc.). The RMSNormGated
constructor on line 145 was the lone holdout: passing the global ngroups
made the per-group size too small by a factor of world_size, so the gated
norm split self.d_ssm into world_size x more groups than the chunk-scan
kernel does. Behavior is unchanged for the non-distributed path
(world_size == 1, where self.ngroups == ngroups).