fix(mamba2): use self.ngroups for RMSNormGated group_size under tensor parallel #951
Open
Chessing234 wants to merge 1 commit into
…r parallel When process_group is set, the ngroups parameter is sharded across ranks via 'self.ngroups = ngroups // self.world_size' (line 83), and every other call site in this module uses self.ngroups for the per-rank value (the in_proj/conv_dim widths on lines 96 and 104, the mamba_split_conv1d_scan_combined ngroups arg on line 200, the xBC split widths, etc.). The RMSNormGated constructor on line 145 was the lone holdout: passing the global ngroups made the per-group size too small by a factor of world_size, so the gated norm split self.d_ssm into world_size x more groups than the chunk-scan kernel does. Behavior is unchanged for the non-distributed path (world_size == 1, where self.ngroups == ngroups).
Bug
In Mamba2.__init__ (mamba_ssm/modules/mamba2.py, line 145), the gated RMSNorm is constructed with group_size=self.d_ssm // ngroups.
Here ngroups is the constructor argument (the global number of groups), while self.d_ssm has already been sharded across ranks. With world_size > 1 the per-group size is too small by a factor of world_size, so the gated norm splits the local d_ssm into world_size× more groups than the chunk-scan kernel actually uses for B/C.

Root cause
Line 83 shards groups across ranks: self.ngroups = ngroups // self.world_size. Every other site in this module uses self.ngroups for the per-rank value: d_in_proj and conv_dim (lines 96, 104), the xBC split widths (lines 213, 243), and the mamba_chunk_scan_combined / mamba_split_conv1d_scan_combined ngroups argument (lines 200, 248). The norm constructor on line 145 was the lone site still referring to the global ngroups.
RMSNormGated interprets group_size as the number of features per normalization group (see mamba_ssm/ops/triton/layernorm_gated.py, LayerNorm/RMSNorm: ngroups_in_norm = hidden_size // group_size). With self.d_ssm (local) and ngroups (global), the resulting group_size = local_d_ssm / global_ngroups is world_size× smaller than intended, so ngroups_in_norm comes out equal to the global ngroups, which is world_size× larger than self.ngroups, and the norm and the kernel disagree on grouping.

Fix
Use self.ngroups so the per-rank group size matches the per-rank ngroups the kernels consume: group_size=self.d_ssm // self.ngroups.
For the non-distributed path (process_group is None, world_size == 1), self.ngroups == ngroups, so behavior is unchanged.
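The grouping mismatch can be checked with plain integer arithmetic. The sketch below is illustrative only: the helper norm_group_count and the sizes (d_ssm 4096, 8 groups, 4 ranks) are hypothetical and not taken from the module, and it assumes d_ssm and ngroups divide evenly by world_size. It mirrors the relations described above (self.ngroups = ngroups // world_size, ngroups_in_norm = hidden_size // group_size) without importing mamba_ssm.

```python
def norm_group_count(d_ssm_global, ngroups_global, world_size, use_sharded_ngroups):
    """Return (group_size, ngroups_in_norm, kernel_ngroups) as seen by one rank.

    Hypothetical helper for illustration; it only reproduces the arithmetic
    described in this PR, not the actual Mamba2 constructor.
    """
    d_ssm_local = d_ssm_global // world_size       # self.d_ssm after sharding (assumed even split)
    ngroups_local = ngroups_global // world_size   # self.ngroups, as on line 83
    # Before the fix the divisor is the global ngroups; after it, self.ngroups.
    divisor = ngroups_local if use_sharded_ngroups else ngroups_global
    group_size = d_ssm_local // divisor            # value handed to RMSNormGated
    ngroups_in_norm = d_ssm_local // group_size    # group count the norm derives
    return group_size, ngroups_in_norm, ngroups_local


# Hypothetical configuration: d_ssm=4096, ngroups=8, 4 tensor-parallel ranks.
before = norm_group_count(4096, 8, 4, use_sharded_ngroups=False)
after = norm_group_count(4096, 8, 4, use_sharded_ngroups=True)

# Non-distributed path: both variants must agree.
single_before = norm_group_count(4096, 8, 1, use_sharded_ngroups=False)
single_after = norm_group_count(4096, 8, 1, use_sharded_ngroups=True)

print("before fix:", before)  # (128, 8, 2): norm uses 8 groups, kernel uses 2
print("after fix: ", after)   # (512, 2, 2): norm and kernel agree
print("world_size == 1 unchanged:", single_before == single_after)
```

With these numbers the pre-fix norm uses 8 groups of 128 features on each rank while the kernel uses 2 groups, a factor of world_size = 4 apart; after the fix both use 2 groups of 512 features.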