diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d471..cdc047fb0 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -142,7 +142,7 @@ def __init__( if self.rmsnorm: assert RMSNormGated is not None self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate, - group_size=self.d_ssm // ngroups, **factory_kwargs) + group_size=self.d_ssm // self.ngroups, **factory_kwargs) if self.process_group is None: self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)