diff --git a/README.md b/README.md index bf3b76a93..5a6805693 100755 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ > Tri Dao*, Albert Gu*\ > Paper: https://arxiv.org/abs/2405.21060 +> **Mamba-3: Improved Sequence Modeling with Structured State Spaces**\ +> Paper: https://openreview.net/pdf?id=HwCvaJOiCj + ## About Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. @@ -95,6 +98,28 @@ assert y.shape == x.shape A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py). +### Mamba-3 + +The Mamba-3 block is implemented at [modules/mamba3.py](mamba_ssm/modules/mamba3.py). + +A simpler version is at [modules/mamba3_simple.py](mamba_ssm/modules/mamba3_simple.py) + +Usage: +``` python +from mamba_ssm import Mamba3 +model = Mamba3( + d_model=dim, # Model dimension d_model + d_state=64, # SSM state expansion factor + d_conv=4, # Local convolution width + expand=2, # Block expansion factor +).to("cuda") +y = model(x) +assert y.shape == x.shape +``` + +Mamba-3 adds RoPE, BCNorm, and MIMO (multi-input multi-output) support on top of the SSD framework. +To use Mamba-3 in a full language model, set `"layer": "Mamba3"` in `ssm_cfg`. + ### Mamba Language Model Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. @@ -240,4 +265,10 @@ If you use this codebase, or otherwise find our work valuable, please cite Mamba year={2024} } +@inproceedings{mamba3, + title={Mamba-3: Improved Sequence Modeling with Structured State Spaces}, + booktitle={International Conference on Learning Representations (ICLR)}, + year={2026} +} + ``` diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index d9f24a80f..67099cec9 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -3,4 +3,5 @@ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba from mamba_ssm.modules.mamba2 import Mamba2 +from mamba_ssm.modules.mamba3 import Mamba3 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index fae2257a9..b2b582d16 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -14,6 +14,7 @@ from mamba_ssm.models.config_mamba import MambaConfig from mamba_ssm.modules.mamba_simple import Mamba from mamba_ssm.modules.mamba2 import Mamba2 +from mamba_ssm.modules.mamba3 import Mamba3 from mamba_ssm.modules.mha import MHA from mamba_ssm.modules.mlp import GatedMLP from mamba_ssm.modules.block import Block @@ -51,10 +52,11 @@ def create_block( # Create a copy of the config to modify ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {} ssm_layer = ssm_cfg.pop("layer", "Mamba1") - if ssm_layer not in ["Mamba1", "Mamba2"]: - raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2") + if ssm_layer not in ["Mamba1", "Mamba2", "Mamba3"]: + raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1, Mamba2, and Mamba3") + layer_cls = {"Mamba1": Mamba, "Mamba2": Mamba2, "Mamba3": Mamba3}[ssm_layer] mixer_cls = partial( - Mamba2 if ssm_layer == "Mamba2" else Mamba, + layer_cls, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs diff --git a/mamba_ssm/modules/mamba3.py b/mamba_ssm/modules/mamba3.py new file mode 100644 index 000000000..97f6b21f9 --- /dev/null +++ b/mamba_ssm/modules/mamba3.py @@ -0,0 +1,916 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 implementation based on "Mamba-3: Improved Sequence Modeling Using State Space Principles" +# Key changes from Mamba-2: +# 1. Exponential-trapezoidal discretization (lookback recurrence) +# 2. Complex-valued SSM via data-dependent RoPE on B, C +# 3. MIMO (multi-input, multi-output) SSM option +# 4. BCNorm (RMSNorm on B, C projections) +# 5. Learnable B, C biases (head-specific, channel-wise) +# 6. No short causal convolution + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +try: + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +try: + from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated +except ImportError: + RMSNormGated = None + +try: + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +except ImportError: + mamba_chunk_scan_combined = None + +try: + from mamba_ssm.ops.triton.mamba3_ssd import ( + mamba3_chunk_scan_combined, + mamba3_state_update, + ) +except ImportError: + mamba3_chunk_scan_combined = None + mamba3_state_update = None + +try: + from mamba_ssm.ops.triton.mamba3_combined import mamba3_chunk_scan_combined_triton + _has_triton_combined = True +except ImportError: + mamba3_chunk_scan_combined_triton = None + _has_triton_combined = False + +from torch.utils.checkpoint import checkpoint as gradient_checkpoint + +from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear +from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter + +from huggingface_hub import PyTorchModelHubMixin + + +def apply_rotary_emb(x, cos, sin): + """Apply rotary embedding to x. x: (..., N), cos/sin: (..., N/2).""" + d = x.shape[-1] + x1, x2 = x[..., :d // 2], x[..., d // 2:] + out1 = x1 * cos - x2 * sin + out2 = x1 * sin + x2 * cos + return torch.cat([out1, out2], dim=-1) + + +def compute_cumulative_rotary(theta_cumsum, dstate): + """Compute cumulative rotation angles for data-dependent RoPE. + + Args: + theta_cumsum: (batch, seqlen, nheads, dstate//2) cumulative sum of theta angles + Returns: + cos, sin: (batch, seqlen, nheads, dstate//2) cumulative cos/sin + """ + return torch.cos(theta_cumsum), torch.sin(theta_cumsum) + + +class Mamba3(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + d_model, + d_state=64, + expand=2, + headdim=64, + d_ssm=None, + ngroups=1, + A_init_range=(1, 16), + D_has_hdim=False, + rmsnorm=True, + norm_before_gate=False, + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + dt_limit=(0.0, float("inf")), + bias=False, + # Mamba-3 specific + use_rope=True, + use_trapezoidal=True, + use_bc_norm=True, + use_bc_bias=True, + mimo_rank=0, # 0 = SISO, >0 = MIMO with this rank + # Fused kernel and sharding options + chunk_size=256, + use_mem_eff_path=True, # gradient checkpointing for memory-efficient training + use_triton_fwd=True, # use Triton-accelerated forward when available + layer_idx=None, + process_group=None, + sequence_parallel=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.expand = expand + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.world_size = 1 if process_group is None else process_group.size() + self.local_rank = 0 if process_group is None else process_group.rank() + self.d_inner = (self.expand * self.d_model) // self.world_size + assert self.d_inner * self.world_size == self.expand * self.d_model + self.headdim = headdim + self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size + assert self.d_ssm == self.d_inner, ( + f"Mamba3 requires d_ssm == d_inner (got d_ssm={self.d_ssm}, d_inner={self.d_inner}). " + f"Unlike Mamba-2, Mamba-3 does not support partial SSM dimension." + ) + assert ngroups % self.world_size == 0 + self.ngroups = ngroups // self.world_size + assert self.d_ssm % self.headdim == 0 + self.nheads = self.d_ssm // self.headdim + self.D_has_hdim = D_has_hdim + self.rmsnorm = rmsnorm + self.norm_before_gate = norm_before_gate + self.dt_limit = dt_limit + self.chunk_size = chunk_size + self.use_mem_eff_path = use_mem_eff_path + self.use_triton_fwd = use_triton_fwd + self.layer_idx = layer_idx + + # Mamba-3 specific + self.use_rope = use_rope + self.use_trapezoidal = use_trapezoidal + self.use_bc_norm = use_bc_norm + self.use_bc_bias = use_bc_bias + self.mimo_rank = mimo_rank + self.is_mimo = mimo_rank > 0 + + # Projection sizes + # For MIMO: B and C project to (ngroups * d_state * mimo_rank) instead of (ngroups * d_state) + bc_dim = self.ngroups * self.d_state + if self.is_mimo: + bc_proj_dim = bc_dim * self.mimo_rank + else: + bc_proj_dim = bc_dim + + # dt: nheads + # theta (for RoPE): nheads * (d_state // 2) if use_rope else 0 + theta_dim = self.nheads * (self.d_state // 2) if self.use_rope else 0 + # lambda (for trapezoidal): nheads if use_trapezoidal else 0 + lambda_dim = self.nheads if self.use_trapezoidal else 0 + + # Order: [z, x, B, C, dt, theta, lambda] + d_in_proj = ( + self.d_inner # z (gate) + + self.d_ssm # x (SSM input) + + bc_proj_dim # B + + bc_proj_dim # C + + self.nheads # dt + + theta_dim # theta for RoPE + + lambda_dim # lambda for trapezoidal + ) + + if self.process_group is None: + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) + else: + self.in_proj = ColumnParallelLinear( + self.d_model, d_in_proj * self.world_size, bias=bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, + ) + + # For MIMO, x also needs rank R projection + if self.is_mimo: + # x goes from (batch, seqlen, d_ssm) to (batch, seqlen, nheads, headdim, mimo_rank) + # We project headdim -> headdim * mimo_rank per head + self.x_mimo_proj = nn.Linear(self.headdim, self.headdim * self.mimo_rank, bias=False, **factory_kwargs) + # Learned MIMO output projection: PR -> P per head (paper Section D, W_{O'}) + self.mimo_out_proj = nn.Linear(self.headdim * self.mimo_rank, self.headdim, bias=False, **factory_kwargs) + + # dt bias + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + self.dt_bias._no_weight_decay = True + + # A parameter (data-dependent in Mamba-3, but we keep log-space init) + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) + A_log = torch.log(A).to(dtype=dtype) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) + self.D._no_weight_decay = True + + # BC Norm (Mamba-3: RMSNorm on B and C after projection) + if self.use_bc_norm: + self.B_norm = nn.RMSNorm(self.d_state, eps=1e-5, **factory_kwargs) + self.C_norm = nn.RMSNorm(self.d_state, eps=1e-5, **factory_kwargs) + + # BC Bias (Mamba-3: learnable head-specific channel-wise biases, init=1.0 per paper Table 9a) + if self.use_bc_bias: + self.B_bias = nn.Parameter(torch.ones(self.nheads, self.d_state, **factory_kwargs)) + self.C_bias = nn.Parameter(torch.ones(self.nheads, self.d_state, **factory_kwargs)) + + # Output norm (gated RMSNorm) + if self.rmsnorm: + assert RMSNormGated is not None + self.norm = RMSNormGated( + self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate, + group_size=self.d_ssm // self.ngroups, **factory_kwargs, + ) + + # Output projection + if self.process_group is None: + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + else: + self.out_proj = RowParallelLinear( + self.d_inner * self.world_size, self.d_model, bias=bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, + ) + + # Store split sizes for forward + self._split_sizes = [ + self.d_inner, # z + self.d_ssm, # x + bc_proj_dim, # B + bc_proj_dim, # C + self.nheads, # dt + ] + if self.use_rope: + self._split_sizes.append(theta_dim) + if self.use_trapezoidal: + self._split_sizes.append(lambda_dim) + + def _process_bc(self, B_raw, C_raw): + """Apply BCNorm to B and C projections (at group level, before expansion). + + BC bias is applied separately after group→head expansion to preserve + head-specificity (paper Section 3.4: "head-specific, channel-wise biases"). + + Args: + B_raw: (batch, seqlen, ngroups, d_state [, mimo_rank]) + C_raw: same shape + Returns: + B, C with norm applied (bias applied later after group→head expansion) + """ + if self.use_bc_norm: + orig_shape = B_raw.shape + if B_raw.dim() == 5: + # MIMO: (b, l, g, d_state, mimo_rank) — move rank before d_state to normalize correctly + B_raw = self.B_norm(B_raw.movedim(-1, -2).reshape(-1, self.d_state)).reshape( + *orig_shape[:-2], orig_shape[-1], orig_shape[-2] + ).movedim(-1, -2) + C_raw = self.C_norm(C_raw.movedim(-1, -2).reshape(-1, self.d_state)).reshape( + *orig_shape[:-2], orig_shape[-1], orig_shape[-2] + ).movedim(-1, -2) + else: + B_raw = self.B_norm(B_raw.reshape(-1, self.d_state)).reshape(orig_shape) + C_raw = self.C_norm(C_raw.reshape(-1, self.d_state)).reshape(orig_shape) + + return B_raw, C_raw + + def _apply_bc_bias(self, B, C): + """Apply head-specific BC bias after group→head expansion. + + Args: + B: (batch, seqlen, nheads, d_state [, mimo_rank]) + C: same shape + Returns: + B, C with per-head bias applied + """ + if not self.use_bc_bias: + return B, C + if self.is_mimo: + # B_bias: (nheads, d_state) -> broadcast over (batch, seqlen, nheads, d_state, mimo_rank) + B = B + self.B_bias.view(1, 1, self.nheads, self.d_state, 1) + C = C + self.C_bias.view(1, 1, self.nheads, self.d_state, 1) + else: + # B_bias: (nheads, d_state) -> broadcast over (batch, seqlen, nheads, d_state) + B = B + self.B_bias + C = C + self.C_bias + return B, C + + def _ssd_trapezoidal(self, x, dt, A, B, C, theta=None, lam=None, + initial_states=None, return_final_states=False, + initial_prev_Bx=None, seq_idx=None): + """Reference implementation of Mamba-3 SSD with trapezoidal discretization. + + This is a step-by-step recurrence (not chunked). For production, this should + be replaced with optimized Triton kernels. + + Args: + x: (batch, seqlen, nheads, headdim) or (batch, seqlen, nheads, headdim, mimo_rank) for MIMO + dt: (batch, seqlen, nheads) - already processed (softplus applied) + A: (nheads,) - negative values + B: (batch, seqlen, nheads, d_state) or (..., d_state, mimo_rank) — already expanded + C: same as B + theta: (batch, seqlen, nheads, d_state//2) or None + lam: (batch, seqlen, nheads) or None - trapezoidal lambda in [0, 1] + initial_states: (batch, nheads, headdim, d_state) or None + return_final_states: bool + Returns: + y: (batch, seqlen, nheads, headdim) + final_states: (batch, nheads, headdim, d_state) if return_final_states + """ + batch, seqlen, nheads, headdim = x.shape[:4] + dstate = B.shape[-2] if self.is_mimo else B.shape[-1] + # B, C are already at head level (expanded + biased in forward()) + + # Apply RoPE if enabled + if self.use_rope and theta is not None: + # theta: (batch, seqlen, nheads, d_state//2) + # Cumulative sum of theta for data-dependent RoPE + theta_cumsum = torch.cumsum(theta, dim=1) # (batch, seqlen, nheads, d_state//2) + cos_t, sin_t = compute_cumulative_rotary(theta_cumsum, dstate) + + if self.is_mimo: + # Apply RoPE to each rank slice (out-of-place for autograd safety) + B_parts = [apply_rotary_emb(B[:, :, :, :, r], cos_t, sin_t) for r in range(self.mimo_rank)] + C_parts = [apply_rotary_emb(C[:, :, :, :, r], cos_t, sin_t) for r in range(self.mimo_rank)] + B = torch.stack(B_parts, dim=-1) + C = torch.stack(C_parts, dim=-1) + else: + B = apply_rotary_emb(B, cos_t, sin_t) + C = apply_rotary_emb(C, cos_t, sin_t) + + # Discretize + alpha = torch.exp(dt.unsqueeze(-1) * A.view(1, 1, nheads, 1)) # (batch, seqlen, nheads, 1) + + if self.use_trapezoidal and lam is not None: + gamma = lam * dt # (batch, seqlen, nheads) — λ * Δt + beta = (1 - lam) * dt * torch.exp(dt * A.view(1, 1, nheads)) # (1-λ) * Δt * α + else: + gamma = dt # Euler fallback: γ = Δt + beta = None + + # Initialize state + if self.is_mimo: + h = torch.zeros(batch, nheads, headdim, dstate, device=x.device, dtype=torch.float32) + else: + h = torch.zeros(batch, nheads, headdim, dstate, device=x.device, dtype=torch.float32) + + if initial_states is not None: + h = initial_states.float() + + ys = [] + prev_Bx = initial_prev_Bx # For trapezoidal: B_{t-1} * x_{t-1} + + for t in range(seqlen): + # Reset state at document boundaries + if seq_idx is not None and t > 0: + boundary = (seq_idx[:, t] != seq_idx[:, t - 1]) # (batch,) + if boundary.any(): + mask = boundary.view(-1, 1, 1, 1).float() + h = h * (1 - mask) # zero state at boundary + if initial_states is not None: + h = h + mask * initial_states.float() + prev_Bx = None if boundary.all() else ( + prev_Bx * (1 - mask) if prev_Bx is not None else None + ) + + # State transition: h_t = alpha_t * h_{t-1} + beta_t * B_{t-1} * x_{t-1} + gamma_t * B_t * x_t + alpha_t = alpha[:, t] # (batch, nheads, 1) + + if self.is_mimo: + x_t = x[:, t] + B_t = B[:, t] + Bx_t = torch.einsum("bhpr,bhnr->bhpn", x_t.float(), B_t.float()) + else: + x_t = x[:, t] + B_t = B[:, t] + Bx_t = torch.einsum("bhp,bhn->bhpn", x_t.float(), B_t.float()) + + gamma_t = gamma[:, t].unsqueeze(-1).unsqueeze(-1) # (batch, nheads, 1, 1) + + h = alpha_t.unsqueeze(-1) * h + gamma_t * Bx_t + + if beta is not None and prev_Bx is not None: + beta_t = beta[:, t].unsqueeze(-1).unsqueeze(-1) # (batch, nheads, 1, 1) + h = h + beta_t * prev_Bx + + prev_Bx = Bx_t + + # Output: y_t = C_t^T @ h_t + if self.is_mimo: + C_t = C[:, t] # (batch, nheads, d_state, mimo_rank) + y_t = torch.einsum("bhpn,bhnr->bhpr", h.to(C_t.dtype), C_t) + # Per-rank output: (batch, nheads, headdim, mimo_rank) + else: + C_t = C[:, t] # (batch, nheads, d_state) + y_t = torch.einsum("bhpn,bhn->bhp", h.to(C_t.dtype), C_t) + + ys.append(y_t) + + y = torch.stack(ys, dim=1) # (batch, seqlen, nheads, headdim) + + if return_final_states: + return y, h + return y + + def _ssd_chunked(self, x, dt, A, B, C, theta=None, lam=None, + initial_states=None, seq_idx=None, return_final_states=False, + initial_prev_Bx=None): + """Chunked parallel implementation. + + Uses mamba3_chunk_scan_combined (matmul-based parallel within chunks) when available. + Falls back to step-by-step recurrence otherwise. + """ + nheads = self.nheads + + # Compute trapezoidal weights + gamma = None + beta = None + if self.use_trapezoidal and lam is not None: + gamma = lam * dt # (batch, seqlen, nheads) + beta = (1 - lam) * dt * torch.exp(dt * A.view(1, 1, nheads)) + else: + gamma = dt # Euler fallback + + # Choose between Triton-accelerated and PyTorch-only chunked forward + _scan_fn = None + if _has_triton_combined and self.use_triton_fwd and mamba3_chunk_scan_combined_triton is not None: + _scan_fn = mamba3_chunk_scan_combined_triton + elif mamba3_chunk_scan_combined is not None: + _scan_fn = mamba3_chunk_scan_combined + + if _scan_fn is not None: + # B, C are already expanded to head level, so ngroups=nheads + return _scan_fn( + x, dt, A, B, C, + chunk_size=self.chunk_size, + gamma=gamma, + beta=beta if self.use_trapezoidal else None, + theta=theta, + D=None, # D is applied outside + initial_states=initial_states, + initial_prev_Bx=initial_prev_Bx, + return_final_states=return_final_states, + ngroups=nheads, + seq_idx=seq_idx, + ) + else: + # Fallback to step-by-step (if chunked kernels unavailable) + return self._ssd_trapezoidal( + x, dt, A, B, C, theta=theta, lam=lam, + initial_states=initial_states, + return_final_states=return_final_states, + initial_prev_Bx=initial_prev_Bx, + seq_idx=seq_idx, + ) + + def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None): + """ + u: (batch, seqlen, hidden_dim) if seqlen=None. + Returns: same shape as u + """ + if seqlen is None: + batch = u.shape[0] + else: + batch = u.shape[0] // seqlen + + conv_state, ssm_state = None, None + if inference_params is not None: + inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch + conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch) + if inference_params.seqlen_offset > 0: + out, _, _ = self.step(u, conv_state, ssm_state) + return out + + # Memory-efficient path: gradient checkpointing to avoid storing intermediates + if self.use_mem_eff_path and inference_params is None and torch.is_grad_enabled(): + return gradient_checkpoint( + self._forward_inner, u, seqlen, seq_idx, conv_state, ssm_state, + use_reentrant=False, + ) + return self._forward_inner(u, seqlen, seq_idx, conv_state, ssm_state) + + def _forward_inner(self, u, seqlen, seq_idx, conv_state, ssm_state): + """Core forward computation, factored out for gradient checkpointing.""" + seqlen_og = seqlen + if seqlen is None: + batch, seqlen, dim = u.shape + else: + batch_seqlen, dim = u.shape + batch = batch_seqlen // seqlen + + proj = self.in_proj(u) # (B, L, d_in_proj) + if seqlen_og is not None: + proj = rearrange(proj, "(b l) d -> b l d", l=seqlen) + + # Split projection + splits = torch.split(proj, self._split_sizes, dim=-1) + idx = 0 + z = splits[idx]; idx += 1 + x = splits[idx]; idx += 1 + B_raw = splits[idx]; idx += 1 + C_raw = splits[idx]; idx += 1 + dt_raw = splits[idx]; idx += 1 + theta_raw = splits[idx] if self.use_rope else None; idx += (1 if self.use_rope else 0) + lam_raw = splits[idx] if self.use_trapezoidal else None + + A = -torch.exp(self.A_log.float()) + + # Process dt + dt = F.softplus(dt_raw + self.dt_bias) # (batch, seqlen, nheads) + if self.dt_limit != (0.0, float("inf")): + dt = dt.clamp(min=self.dt_limit[0], max=self.dt_limit[1]) + + # Process B, C + if self.is_mimo: + B = rearrange(B_raw, "b l (g n r) -> b l g n r", g=self.ngroups, r=self.mimo_rank) + C = rearrange(C_raw, "b l (g n r) -> b l g n r", g=self.ngroups, r=self.mimo_rank) + else: + B = rearrange(B_raw, "b l (g n) -> b l g n", g=self.ngroups) + C = rearrange(C_raw, "b l (g n) -> b l g n", g=self.ngroups) + + B, C = self._process_bc(B, C) + + # Expand B, C from groups to heads and apply per-head bias + nheads_per_group = self.nheads // self.ngroups + if self.is_mimo: + B = repeat(B, "b l g n r -> b l (g h) n r", h=nheads_per_group) + C = repeat(C, "b l g n r -> b l (g h) n r", h=nheads_per_group) + else: + B = repeat(B, "b l g n -> b l (g h) n", h=nheads_per_group) + C = repeat(C, "b l g n -> b l (g h) n", h=nheads_per_group) + B, C = self._apply_bc_bias(B, C) + + # Process theta (RoPE angles) + theta = None + if self.use_rope and theta_raw is not None: + theta = rearrange(theta_raw, "b l (h d) -> b l h d", h=self.nheads) + + # Process lambda (trapezoidal parameter) + lam = None + if self.use_trapezoidal and lam_raw is not None: + lam = torch.sigmoid(lam_raw) # (batch, seqlen, nheads) in [0, 1] + + # Process x + x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim) + if self.is_mimo: + # Project x to (batch, seqlen, nheads, headdim, mimo_rank) + x = self.x_mimo_proj(x) # (batch, seqlen, nheads, headdim * mimo_rank) + x = rearrange(x, "b l h (p r) -> b l h p r", r=self.mimo_rank) + + # Extract initial_prev_Bx from conv_state for segmented prefill continuity + initial_prev_Bx = None + if conv_state is not None and self.use_trapezoidal: + prev_Bx_flat_size = self.nheads * self.headdim * self.d_state + prev_Bx_data = conv_state[:, :prev_Bx_flat_size] + if prev_Bx_data.abs().sum() > 0: # non-zero means state was populated by a previous segment + initial_prev_Bx = prev_Bx_data.view(-1, self.nheads, self.headdim, self.d_state) + + # Run SSM (B, C already at head level with bias applied) + result = self._ssd_chunked( + x, dt, A, B, C, + theta=theta, lam=lam, + seq_idx=seq_idx, + return_final_states=ssm_state is not None, + initial_prev_Bx=initial_prev_Bx, + ) + + if ssm_state is not None: + y, last_state = result + ssm_state.copy_(last_state) + else: + y = result if not isinstance(result, tuple) else result[0] + + # Store inference state for decode continuity after prefill + if conv_state is not None: + prev_Bx_flat_size = self.nheads * self.headdim * self.d_state + half_d = self.d_state // 2 + + # Compute cumulative theta first (needed for both RoPE on B_last and storage) + theta_cumsum = None + if self.use_rope and theta is not None: + theta_cumsum = torch.cumsum(theta, dim=1) # (batch, seqlen, nheads, d_state//2) + + # Store last step's B*x for trapezoidal lookback (with RoPE applied) + # B is already at head level with bias applied; just need RoPE + B_last = B[:, -1].clone() # (batch, nheads, d_state[, R]) + if self.is_mimo: + # Apply RoPE to B_last so prev_Bx matches decode step's convention + if theta_cumsum is not None: + theta_last = theta_cumsum[:, -1] # (batch, nheads, d_state//2) + cos_last = torch.cos(theta_last) + sin_last = torch.sin(theta_last) + for r in range(self.mimo_rank): + B_last[:, :, :, r] = apply_rotary_emb(B_last[:, :, :, r], cos_last, sin_last) + x_last = x[:, -1] # (batch, nheads, headdim, mimo_rank) + Bx_last = torch.einsum("bhpr,bhnr->bhpn", x_last.float(), B_last.float()) + else: + # Apply RoPE to B_last so prev_Bx matches decode step's convention + if theta_cumsum is not None: + theta_last = theta_cumsum[:, -1] # (batch, nheads, d_state//2) + cos_last = torch.cos(theta_last) + sin_last = torch.sin(theta_last) + B_last = apply_rotary_emb(B_last, cos_last, sin_last) + x_last = x[:, -1] # (batch, nheads, headdim) + Bx_last = torch.einsum("bhp,bhn->bhpn", x_last.float(), B_last.float()) + conv_state[:, :prev_Bx_flat_size] = Bx_last.reshape(batch, -1) + + # Store cumulative theta for RoPE continuity + if theta_cumsum is not None: + theta_total = theta_cumsum[:, -1] # (batch, nheads, d_state//2) + cum_theta_offset = prev_Bx_flat_size + conv_state[:, cum_theta_offset:cum_theta_offset + self.nheads * half_d] = \ + theta_total.reshape(batch, -1).float() + + # D skip connection + flatten (cast D to input dtype to avoid float32 promotion) + D = self.D.to(dtype=y.dtype) + if self.is_mimo: + # y: (B, L, H, P, R), x: (B, L, H, P, R) + if self.D_has_hdim: + y = y + x * rearrange(D, "(h p) -> 1 1 h p 1", p=self.headdim) + else: + y = y + x * repeat(D, "h -> 1 1 h 1 1") + # Learned MIMO output projection: (P*R) -> P per head + y = self.mimo_out_proj(rearrange(y, "b l h p r -> b l h (p r)")) + y = rearrange(y, "b l h p -> b l (h p)") + else: + y = rearrange(y, "b l h p -> b l (h p)") + x_flat = rearrange(x, "b l h p -> b l (h p)") + if self.D_has_hdim: + y = y + x_flat * rearrange(D, "(h p) -> h p", p=self.headdim).reshape(1, 1, -1) + else: + y = y + x_flat * repeat(D, "h -> (h p)", p=self.headdim).reshape(1, 1, -1) + + # Gated output norm + if self.rmsnorm: + y = self.norm(y, z) + else: + y = y * F.silu(z) + + if seqlen_og is not None: + y = rearrange(y, "b l d -> (b l) d") + + out = self.out_proj(y) + + if self.process_group is not None: + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + out = reduce_fn(out, self.process_group) + + return out + + def step(self, hidden_states, conv_state, ssm_state): + """Single-token decoding step. + + conv_state is a dict-like object (or tuple) with: + - prev_Bx: (batch, nheads, headdim, d_state) for trapezoidal lookback + - cum_theta: (batch, nheads, d_state//2) for cumulative RoPE angles + For simplicity we pack them into a single tensor: + conv_state: (batch, nheads, headdim * d_state + d_state // 2) + """ + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time" + proj = self.in_proj(hidden_states.squeeze(1)) # (B, d_in_proj) + + # Split + splits = torch.split(proj, self._split_sizes, dim=-1) + idx = 0 + z = splits[idx]; idx += 1 + x = splits[idx]; idx += 1 + B_raw = splits[idx]; idx += 1 + C_raw = splits[idx]; idx += 1 + dt_raw = splits[idx]; idx += 1 + theta_raw = splits[idx] if self.use_rope else None; idx += (1 if self.use_rope else 0) + lam_raw = splits[idx] if self.use_trapezoidal else None + + A = -torch.exp(self.A_log.float()) + + # Process dt + dt = F.softplus(dt_raw + self.dt_bias.to(dtype=dt_raw.dtype)) # (batch, nheads) + if self.dt_limit != (0.0, float("inf")): + dt = dt.clamp(min=self.dt_limit[0], max=self.dt_limit[1]) + + dA = torch.exp(dt * A) # (batch, nheads) + + # Process B, C + if self.is_mimo: + B = rearrange(B_raw, "b (g n r) -> b g n r", g=self.ngroups, r=self.mimo_rank) + C = rearrange(C_raw, "b (g n r) -> b g n r", g=self.ngroups, r=self.mimo_rank) + else: + B = rearrange(B_raw, "b (g n) -> b g n", g=self.ngroups) + C = rearrange(C_raw, "b (g n) -> b g n", g=self.ngroups) + + # BC Norm + if self.use_bc_norm: + orig = B.shape + if self.is_mimo: + # MIMO: (b, g, d_state, mimo_rank) — move rank before d_state + B = self.B_norm(B.movedim(-1, -2).reshape(-1, self.d_state)).reshape( + *orig[:-2], orig[-1], orig[-2] + ).movedim(-1, -2) + C = self.C_norm(C.movedim(-1, -2).reshape(-1, self.d_state)).reshape( + *orig[:-2], orig[-1], orig[-2] + ).movedim(-1, -2) + else: + B = self.B_norm(B.reshape(-1, self.d_state)).reshape(orig) + C = self.C_norm(C.reshape(-1, self.d_state)).reshape(orig) + + # Expand B, C from groups to heads + nheads_per_group = self.nheads // self.ngroups + if self.is_mimo: + B = repeat(B, "b g n r -> b (g h) n r", h=nheads_per_group) + C = repeat(C, "b g n r -> b (g h) n r", h=nheads_per_group) + else: + B = repeat(B, "b g n -> b (g h) n", h=nheads_per_group) + C = repeat(C, "b g n -> b (g h) n", h=nheads_per_group) + + # Apply head-specific BC bias (after expansion for true per-head bias) + if self.use_bc_bias: + if self.is_mimo: + B = B + self.B_bias.view(1, self.nheads, self.d_state, 1) + C = C + self.C_bias.view(1, self.nheads, self.d_state, 1) + else: + B = B + self.B_bias + C = C + self.C_bias + + # Unpack conv_state -> prev_Bx and cum_theta + prev_Bx_flat_size = self.nheads * self.headdim * self.d_state + half_d = self.d_state // 2 + prev_Bx = conv_state[:, :prev_Bx_flat_size].view( + -1, self.nheads, self.headdim, self.d_state + ) + + # Apply RoPE to B, C + if self.use_rope and theta_raw is not None: + theta = rearrange(theta_raw, "b (h d) -> b h d", h=self.nheads) + cum_theta_offset = prev_Bx_flat_size + cum_theta = conv_state[:, cum_theta_offset:cum_theta_offset + self.nheads * half_d].view( + -1, self.nheads, half_d + ) + cum_theta = cum_theta + theta + conv_state[:, cum_theta_offset:cum_theta_offset + self.nheads * half_d] = cum_theta.view( + -1, self.nheads * half_d + ) + + cos_t = torch.cos(cum_theta) + sin_t = torch.sin(cum_theta) + + if self.is_mimo: + for r in range(self.mimo_rank): + B[:, :, :, r] = apply_rotary_emb(B[:, :, :, r], cos_t, sin_t) + C[:, :, :, r] = apply_rotary_emb(C[:, :, :, r], cos_t, sin_t) + else: + B = apply_rotary_emb(B, cos_t, sin_t) + C = apply_rotary_emb(C, cos_t, sin_t) + + # Process x + x = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if self.is_mimo: + x = self.x_mimo_proj(x) + x = rearrange(x, "b h (p r) -> b h p r", r=self.mimo_rank) + + # Compute trapezoidal weights + lam_val = None + gamma_scalar = None + beta_scalar = None + if self.use_trapezoidal and lam_raw is not None: + lam_val = torch.sigmoid(lam_raw) # (batch, nheads) + gamma_scalar = lam_val * dt # (batch, nheads) + beta_scalar = (1 - lam_val) * dt * dA # (batch, nheads) + else: + gamma_scalar = dt + + # Use Triton kernel for decode when available (supports both SISO and MIMO) + use_triton_decode = ( + mamba3_state_update is not None + and ssm_state.is_cuda + ) + + if use_triton_decode: + # Triton kernel handles: state update + output in one fused op + # B, C are already at head level after preprocessing — pass as ngroups=nheads + if self.is_mimo: + # MIMO: kernel returns (B, H, P, R), D and z not applied by kernel + y = mamba3_state_update( + ssm_state, x, dt, A, + B, C, + D=None, z=None, + prev_Bx=prev_Bx, + beta=beta_scalar, + gamma=gamma_scalar, + ) + D_val = self.D.to(dtype) + if self.D_has_hdim: + y = y + x * rearrange(D_val, "(h p) -> 1 h p 1", p=self.headdim) + else: + y = y + x * rearrange(D_val, "h -> 1 h 1 1") + y = self.mimo_out_proj(rearrange(y, "b h p r -> b h (p r)")) + y = rearrange(y, "b h p -> b (h p)") + else: + # SISO: kernel also applies D (scalar per head) + y = mamba3_state_update( + ssm_state, x, dt, A, + B, C, + D=self.D if not self.D_has_hdim else None, + z=None, # we handle norm+gate separately + prev_Bx=prev_Bx, + beta=beta_scalar, + gamma=gamma_scalar, + ) + if self.D_has_hdim: + y = y + rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim) * x + y = rearrange(y, "b h p -> b (h p)") + # prev_Bx was updated in-place by the kernel + else: + # PyTorch fallback (MIMO or no Triton) + if self.is_mimo: + Bx = torch.einsum("bhpr,bhnr->bhpn", x.float(), B.float()) + else: + Bx = torch.einsum("bhp,bhn->bhpn", x.float(), B.float()) + + gamma_4d = gamma_scalar.unsqueeze(-1).unsqueeze(-1) + if beta_scalar is not None: + beta_4d = beta_scalar.unsqueeze(-1).unsqueeze(-1) + ssm_state.copy_( + ssm_state * rearrange(dA, "b h -> b h 1 1") + + gamma_4d * Bx + + beta_4d * prev_Bx + ) + else: + ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + gamma_4d * Bx) + + # Store current Bx as prev_Bx for next step + conv_state[:, :prev_Bx_flat_size] = Bx.view(-1, prev_Bx_flat_size) + + # Output: y = C^T @ h + if self.is_mimo: + y = torch.einsum("bhpn,bhnr->bhpr", ssm_state.to(dtype), C) + if self.D_has_hdim: + y = y + x * rearrange(self.D.to(dtype), "(h p) -> 1 h p 1", p=self.headdim) + else: + y = y + x * rearrange(self.D.to(dtype), "h -> 1 h 1 1") + y = self.mimo_out_proj(rearrange(y, "b h p r -> b h (p r)")) + y = rearrange(y, "b h p -> b (h p)") + else: + y = torch.einsum("bhpn,bhn->bhp", ssm_state.to(dtype), C) + if self.D_has_hdim: + y = y + rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim) * x + else: + y = y + rearrange(self.D.to(dtype), "h -> h 1") * x + y = rearrange(y, "b h p -> b (h p)") + + # Gated output + if self.rmsnorm: + y = self.norm(y, z) + else: + y = y * F.silu(z) + + out = self.out_proj(y) + if self.process_group is not None: + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + out = reduce_fn(out, self.process_group) + return out.unsqueeze(1), conv_state, ssm_state + + def _conv_state_size(self): + """Size of the flattened conv_state for inference.""" + prev_Bx_size = self.nheads * self.headdim * self.d_state + theta_size = self.nheads * (self.d_state // 2) if self.use_rope else 0 + return prev_Bx_size + theta_size + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + # Always float32 to avoid precision loss when storing Bx and cumulative theta + conv_state = torch.zeros( + batch_size, self._conv_state_size(), device=device, dtype=torch.float32, + ) + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype, + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + conv_state = torch.zeros( + batch_size, self._conv_state_size(), + device=self.in_proj.weight.device, + dtype=torch.float32, + ) + ssm_state = torch.zeros( + batch_size, self.nheads, self.headdim, self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state diff --git a/mamba_ssm/modules/mamba3_simple.py b/mamba_ssm/modules/mamba3_simple.py new file mode 100644 index 000000000..20eba372a --- /dev/null +++ b/mamba_ssm/modules/mamba3_simple.py @@ -0,0 +1,358 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 simplified implementation (no TP, no inference/step support). + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +try: + from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated +except ImportError: + RMSNormGated = None + +from mamba_ssm.modules.mamba3 import apply_rotary_emb, compute_cumulative_rotary + +try: + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined +except ImportError: + mamba3_chunk_scan_combined = None + +try: + from mamba_ssm.ops.triton.mamba3_combined import mamba3_chunk_scan_combined_triton + _has_triton_combined = True +except ImportError: + mamba3_chunk_scan_combined_triton = None + _has_triton_combined = False + + +class Mamba3Simple(nn.Module): + def __init__( + self, + d_model, + d_state=64, + expand=2, + headdim=64, + ngroups=1, + A_init_range=(1, 16), + dt_min=0.001, + dt_max=0.1, + dt_init_floor=1e-4, + dt_limit=(0.0, float("inf")), + learnable_init_states=False, + bias=False, + # Mamba-3 specific + use_rope=True, + use_trapezoidal=True, + use_bc_norm=True, + use_bc_bias=True, + mimo_rank=0, + # Kernel options + chunk_size=256, + use_triton_fwd=True, # use Triton-accelerated forward when available + layer_idx=None, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.expand = expand + self.d_inner = self.expand * self.d_model + self.headdim = headdim + self.ngroups = ngroups + assert self.d_inner % self.headdim == 0 + self.nheads = self.d_inner // self.headdim + self.dt_limit = dt_limit + self.learnable_init_states = learnable_init_states + self.chunk_size = chunk_size + self.use_triton_fwd = use_triton_fwd + self.layer_idx = layer_idx + + # Mamba-3 specific + self.use_rope = use_rope + self.use_trapezoidal = use_trapezoidal + self.use_bc_norm = use_bc_norm + self.use_bc_bias = use_bc_bias + self.mimo_rank = mimo_rank + self.is_mimo = mimo_rank > 0 + + bc_dim = self.ngroups * self.d_state + bc_proj_dim = bc_dim * self.mimo_rank if self.is_mimo else bc_dim + theta_dim = self.nheads * (self.d_state // 2) if self.use_rope else 0 + lambda_dim = self.nheads if self.use_trapezoidal else 0 + + d_in_proj = ( + self.d_inner + self.d_inner + bc_proj_dim + bc_proj_dim + + self.nheads + theta_dim + lambda_dim + ) + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) + + if self.is_mimo: + self.x_mimo_proj = nn.Linear(self.headdim, self.headdim * self.mimo_rank, bias=False, **factory_kwargs) + self.mimo_out_proj = nn.Linear(self.headdim * self.mimo_rank, self.headdim, bias=False, **factory_kwargs) + + if self.learnable_init_states: + self.init_states = nn.Parameter( + torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs) + ) + self.init_states._no_weight_decay = True + + # dt bias + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + self.dt_bias._no_weight_decay = True + + # A + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) + self.A_log = nn.Parameter(torch.log(A).to(dtype=dtype)) + self.A_log._no_weight_decay = True + + # D + self.D = nn.Parameter(torch.ones(self.nheads, device=device)) + self.D._no_weight_decay = True + + # BC Norm + if self.use_bc_norm: + self.B_norm = nn.RMSNorm(self.d_state, eps=1e-5, **factory_kwargs) + self.C_norm = nn.RMSNorm(self.d_state, eps=1e-5, **factory_kwargs) + + # BC Bias (init=1.0 per paper Table 9a) + if self.use_bc_bias: + self.B_bias = nn.Parameter(torch.ones(self.nheads, self.d_state, **factory_kwargs)) + self.C_bias = nn.Parameter(torch.ones(self.nheads, self.d_state, **factory_kwargs)) + + # Output norm + assert RMSNormGated is not None + self.norm = RMSNormGated( + self.d_inner, eps=1e-5, norm_before_gate=False, + group_size=self.d_inner // self.ngroups, **factory_kwargs, + ) + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + # Split sizes + self._split_sizes = [self.d_inner, self.d_inner, bc_proj_dim, bc_proj_dim, self.nheads] + if self.use_rope: + self._split_sizes.append(theta_dim) + if self.use_trapezoidal: + self._split_sizes.append(lambda_dim) + + def forward(self, u, seq_idx=None): + """u: (B, L, D). Returns same shape.""" + batch, seqlen, dim = u.shape + + proj = self.in_proj(u) + A = -torch.exp(self.A_log.float()) + + # Split + splits = torch.split(proj, self._split_sizes, dim=-1) + idx = 0 + z = splits[idx]; idx += 1 + x = splits[idx]; idx += 1 + B_raw = splits[idx]; idx += 1 + C_raw = splits[idx]; idx += 1 + dt_raw = splits[idx]; idx += 1 + theta_raw = splits[idx] if self.use_rope else None; idx += (1 if self.use_rope else 0) + lam_raw = splits[idx] if self.use_trapezoidal else None + + # Process dt + dt = F.softplus(dt_raw + self.dt_bias) + if self.dt_limit != (0.0, float("inf")): + dt = dt.clamp(min=self.dt_limit[0], max=self.dt_limit[1]) + + # Reshape B, C + if self.is_mimo: + B = rearrange(B_raw, "b l (g n r) -> b l g n r", g=self.ngroups, r=self.mimo_rank) + C = rearrange(C_raw, "b l (g n r) -> b l g n r", g=self.ngroups, r=self.mimo_rank) + else: + B = rearrange(B_raw, "b l (g n) -> b l g n", g=self.ngroups) + C = rearrange(C_raw, "b l (g n) -> b l g n", g=self.ngroups) + + # BC Norm + if self.use_bc_norm: + orig = B.shape + if self.is_mimo: + # MIMO: (b, l, g, d_state, mimo_rank) — move rank before d_state to normalize correctly + B = self.B_norm(B.movedim(-1, -2).reshape(-1, self.d_state)).reshape( + *orig[:-2], orig[-1], orig[-2] + ).movedim(-1, -2) + C = self.C_norm(C.movedim(-1, -2).reshape(-1, self.d_state)).reshape( + *orig[:-2], orig[-1], orig[-2] + ).movedim(-1, -2) + else: + B = self.B_norm(B.reshape(-1, self.d_state)).reshape(orig) + C = self.C_norm(C.reshape(-1, self.d_state)).reshape(orig) + + # Expand B, C from groups to heads + nheads_per_group = self.nheads // self.ngroups + if self.is_mimo: + B = repeat(B, "b l g n r -> b l (g h) n r", h=nheads_per_group) + C = repeat(C, "b l g n r -> b l (g h) n r", h=nheads_per_group) + else: + B = repeat(B, "b l g n -> b l (g h) n", h=nheads_per_group) + C = repeat(C, "b l g n -> b l (g h) n", h=nheads_per_group) + + # BC Bias (applied per-head after expansion for true head-specificity) + if self.use_bc_bias: + if self.is_mimo: + B = B + self.B_bias.view(1, 1, self.nheads, self.d_state, 1) + C = C + self.C_bias.view(1, 1, self.nheads, self.d_state, 1) + else: + B = B + self.B_bias + C = C + self.C_bias + + # Apply RoPE + if self.use_rope and theta_raw is not None: + theta = rearrange(theta_raw, "b l (h d) -> b l h d", h=self.nheads) + theta_cumsum = torch.cumsum(theta, dim=1) + cos_t, sin_t = compute_cumulative_rotary(theta_cumsum, self.d_state) + if self.is_mimo: + B_parts = [apply_rotary_emb(B[:, :, :, :, r], cos_t, sin_t) for r in range(self.mimo_rank)] + C_parts = [apply_rotary_emb(C[:, :, :, :, r], cos_t, sin_t) for r in range(self.mimo_rank)] + B = torch.stack(B_parts, dim=-1) + C = torch.stack(C_parts, dim=-1) + else: + B = apply_rotary_emb(B, cos_t, sin_t) + C = apply_rotary_emb(C, cos_t, sin_t) + + # Process lambda + lam = torch.sigmoid(lam_raw) if self.use_trapezoidal and lam_raw is not None else None + + # Process x + x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim) + if self.is_mimo: + x = self.x_mimo_proj(x) + x = rearrange(x, "b l h (p r) -> b l h p r", r=self.mimo_rank) + + # Initial states + initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None + + # Choose between Triton-accelerated chunked path and step-by-step recurrence + _scan_fn = None + if _has_triton_combined and self.use_triton_fwd and mamba3_chunk_scan_combined_triton is not None: + _scan_fn = mamba3_chunk_scan_combined_triton + elif mamba3_chunk_scan_combined is not None: + _scan_fn = mamba3_chunk_scan_combined + + if _scan_fn is not None: + # Compute trapezoidal weights + gamma_val = None + beta_val = None + if self.use_trapezoidal and lam is not None: + gamma_val = lam * dt # (batch, seqlen, nheads) + beta_val = (1 - lam) * dt * torch.exp(dt * A.view(1, 1, self.nheads)) + else: + gamma_val = dt # Euler fallback + + # B, C are already expanded to head level, so ngroups=nheads + # theta has already been applied via RoPE above, so pass theta=None + # to avoid double-application. However, mamba3_chunk_scan_combined + # expects B/C at group level if theta is provided (it expands internally). + # Since we already expanded and applied RoPE, pass theta=None and ngroups=nheads. + y = _scan_fn( + x, dt, A, B, C, + chunk_size=self.chunk_size, + gamma=gamma_val, + beta=beta_val if self.use_trapezoidal else None, + theta=None, # RoPE already applied above + D=None, # D applied outside + initial_states=initial_states, + return_final_states=False, + ngroups=self.nheads, # B, C already at head level + seq_idx=seq_idx, + ) + else: + # Fall back to step-by-step recurrence (reference impl) + y = self._recurrence(x, dt, A, B, C, lam=lam, initial_states=initial_states, seq_idx=seq_idx) + + # D skip + flatten (cast D to y's dtype to avoid float32 promotion) + D = self.D.to(dtype=y.dtype) + if self.is_mimo: + # y: (B, L, H, P, R), x: (B, L, H, P, R) + y = y + x * repeat(D, "h -> 1 1 h 1 1") + y = self.mimo_out_proj(rearrange(y, "b l h p r -> b l h (p r)")) + y = rearrange(y, "b l h p -> b l (h p)") + else: + y = rearrange(y, "b l h p -> b l (h p)") + x_flat = rearrange(x, "b l h p -> b l (h p)") + y = y + x_flat * repeat(D, "h -> (h p)", p=self.headdim).reshape(1, 1, -1) + + # Norm + gate + y = self.norm(y, z) + out = self.out_proj(y) + return out + + def _recurrence(self, x, dt, A, B, C, lam=None, initial_states=None, seq_idx=None): + """Step-by-step reference recurrence with trapezoidal discretization.""" + batch, seqlen = x.shape[0], x.shape[1] + nheads = self.nheads + headdim = self.headdim + dstate = self.d_state + + alpha = torch.exp(dt.unsqueeze(-1) * A.view(1, 1, nheads, 1)) # (B, L, H, 1) + + if self.use_trapezoidal and lam is not None: + gamma = lam * dt # λ * Δt + beta = (1 - lam) * dt * torch.exp(dt * A.view(1, 1, nheads)) + else: + gamma = dt # Euler fallback + beta = None + + h = torch.zeros(batch, nheads, headdim, dstate, device=x.device, dtype=torch.float32) + if initial_states is not None: + h = initial_states.float() + + ys = [] + prev_Bx = None + + for t in range(seqlen): + # Reset state at document boundaries + if seq_idx is not None and t > 0: + boundary = (seq_idx[:, t] != seq_idx[:, t - 1]) # (batch,) + if boundary.any(): + mask = boundary.view(-1, 1, 1, 1).float() + h = h * (1 - mask) + if initial_states is not None: + h = h + mask * initial_states.float() + prev_Bx = None if boundary.all() else ( + prev_Bx * (1 - mask) if prev_Bx is not None else None + ) + + x_t = x[:, t] + B_t = B[:, t] + C_t = C[:, t] + + if self.is_mimo: + Bx_t = torch.einsum("bhpr,bhnr->bhpn", x_t.float(), B_t.float()) + else: + Bx_t = torch.einsum("bhp,bhn->bhpn", x_t.float(), B_t.float()) + + alpha_t = alpha[:, t].unsqueeze(-1) # (B, H, 1, 1) + gamma_t = gamma[:, t].unsqueeze(-1).unsqueeze(-1) + + h = alpha_t * h + gamma_t * Bx_t + + if beta is not None and prev_Bx is not None: + beta_t = beta[:, t].unsqueeze(-1).unsqueeze(-1) + h = h + beta_t * prev_Bx + + prev_Bx = Bx_t + + if self.is_mimo: + y_t = torch.einsum("bhpn,bhnr->bhpr", h.to(C_t.dtype), C_t) + else: + y_t = torch.einsum("bhpn,bhn->bhp", h.to(C_t.dtype), C_t) + + ys.append(y_t) + + return torch.stack(ys, dim=1) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359c..55ca28d15 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -17,7 +17,10 @@ from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd -import selective_scan_cuda +try: + import selective_scan_cuda +except ImportError: + selective_scan_cuda = None class SelectiveScanFn(torch.autograd.Function): diff --git a/mamba_ssm/ops/triton/mamba3_chunk_scan.py b/mamba_ssm/ops/triton/mamba3_chunk_scan.py new file mode 100644 index 000000000..90c084478 --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_chunk_scan.py @@ -0,0 +1,362 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 chunk scan Triton kernel (forward pass). +# +# Extends Mamba-2's _chunk_scan_fwd_kernel to support the trapezoidal +# discretization used in Mamba-3: +# +# Y_off[m] = C[m] @ prev_states * exp(dA_cs[m]) (inter-chunk, same as Mamba-2) +# +# Y_diag[m] = sum_k L[m,k] * gamma_k * CB[m,k] * x[k] (intra-chunk current term) +# + sum_k L[m,k] * beta_k * CB_s[m,k] * x_s[k] (intra-chunk lookback term) +# +# Where L[m,k] = exp(dA_cs[m] - dA_cs[k]) is the causal decay matrix. +# gamma replaces dt in Mamba-2's diagonal accumulation. +# The lookback term is optional (controlled by HAS_LOOKBACK constexpr). + +import math +from packaging import version + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from mamba_ssm.utils.determinism import autotune_configs + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=2, num_warps=2), + ]), + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _mamba3_chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, + dA_cumsum_ptr, gamma_ptr, seq_idx_ptr, + C_ptr, prev_states_ptr, D_ptr, + # Lookback pointers (may be null) + beta_ptr, cb_shifted_ptr, x_shifted_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # cb strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + # z strides + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + # out strides + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + # dA_cumsum strides (chunked layout: batch, chunk, head, csize in memory) + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # gamma strides (same chunked layout) + stride_gamma_batch, stride_gamma_chunk, stride_gamma_head, stride_gamma_csize, + # seq_idx strides + stride_seq_idx_batch, stride_seq_idx_seqlen, + # C strides + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + # prev_states strides + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + # D stride + stride_D_head, + # beta strides (same chunked layout) + stride_beta_batch, stride_beta_chunk, stride_beta_head, stride_beta_csize, + # cb_shifted strides (same as cb) + stride_cbs_batch, stride_cbs_chunk, stride_cbs_head, stride_cbs_csize_m, stride_cbs_csize_k, + # x_shifted strides (same as x) + stride_xs_batch, stride_xs_seqlen, stride_xs_head, stride_xs_hdim, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + HAS_LOOKBACK: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + # Advance base pointers to this batch, chunk, head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + gamma_ptr += pid_b * stride_gamma_batch + pid_c * stride_gamma_chunk + pid_h * stride_gamma_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + if HAS_LOOKBACK: + beta_ptr += pid_b * stride_beta_batch + pid_c * stride_beta_chunk + pid_h * stride_beta_head + cb_shifted_ptr += pid_b * stride_cbs_batch + pid_c * stride_cbs_chunk + (pid_h // nheads_ngroups_ratio) * stride_cbs_head + x_shifted_ptr += pid_b * stride_xs_batch + pid_c * chunk_size * stride_xs_seqlen + pid_h * stride_xs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # === Phase 1: Off-diagonal (inter-chunk) contribution === + # Y_off[m] = C[m] @ prev_states * exp(dA_cs[m]) + # This is identical to Mamba-2. + if IS_TRITON_22 or pid_c > -1: + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + # === Phase 2: Diagonal (intra-chunk) contribution === + # Current term: sum_k L[m,k] * gamma_k * CB[m,k] * x[k] + # Lookback term: sum_k L[m,k] * beta_k * CB_shifted[m,k] * x_shifted[k] + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + gamma_ptrs = gamma_ptr + offs_k * stride_gamma_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_LOOKBACK: + cbs_ptrs = cb_shifted_ptr + (offs_m[:, None] * stride_cbs_csize_m + offs_k[None, :] * stride_cbs_csize_k) + xs_ptrs = x_shifted_ptr + (offs_k[:, None] * stride_xs_seqlen + offs_n[None, :] * stride_xs_hdim) + beta_ptrs = beta_ptr + offs_k * stride_beta_csize + + K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + # Load CB values and compute decay + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + # If there's seq_idx, CB was already zeroed for cross-doc pairs by _bmm_chunk_fwd. + # L[m,k] = exp(dA_cs_m - dA_cs_k), clamped to avoid overflow + cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0)) + + # Scale by gamma (Mamba-3) instead of dt (Mamba-2) + gamma_k = tl.load(gamma_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= gamma_k + + # Causal mask + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + + x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + acc += tl.dot(cb, x) + + # Lookback term + if HAS_LOOKBACK: + cbs = tl.load(cbs_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) + cbs *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0)) + + beta_k = tl.load(beta_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cbs *= beta_k + + if IS_CAUSAL: + cbs = tl.where(mask, cbs, 0.0) + cbs = cbs.to(x_ptr.dtype.element_ty) + + xs = tl.load(xs_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + acc += tl.dot(cbs, xs) + + # Advance lookback pointers + cbs_ptrs += BLOCK_SIZE_K * stride_cbs_csize_k + xs_ptrs += BLOCK_SIZE_K * stride_xs_seqlen + beta_ptrs += BLOCK_SIZE_K * stride_beta_csize + + # Advance pointers + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + gamma_ptrs += BLOCK_SIZE_K * stride_gamma_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # === D skip connection === + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load( + x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0 + ).to(tl.float32) + acc += x_residual * D + + # === Z gating === + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + # Store output + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + +def _mamba3_chunk_scan_fwd(CB, x, dt, dA_cumsum, gamma, C, prev_states, + D=None, z=None, beta=None, CB_shifted=None, + x_shifted=None, seq_idx=None): + """ + Compute chunked scan output for Mamba-3 SSD. + + Arguments: + CB: (batch, nchunks, ngroups, chunk_size, chunk_size) -- C^T @ B per chunk + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) -- NOT directly used by kernel (kept for API compat) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + gamma: (batch, nheads, nchunks, chunk_size) -- current term weight (replaces dt in Mamba-2) + C: (batch, seqlen, ngroups, dstate) -- for off-diagonal computation + prev_states: (batch, nchunks, nheads, headdim, dstate) -- boundary states + D: (nheads,) or (nheads, headdim) or None -- skip connection + z: (batch, seqlen, nheads, headdim) or None -- gating + beta: (batch, nheads, nchunks, chunk_size) or None -- lookback weight + CB_shifted: (batch, nchunks, ngroups, chunk_size, chunk_size) or None -- C^T @ B_shifted + x_shifted: (batch, seqlen, nheads, headdim) or None -- shifted x + seq_idx: (batch, seqlen) or None -- document boundaries + + Returns: + out: (batch, seqlen, nheads, headdim) + out_x: (batch, seqlen, nheads, headdim) or None -- pre-gating output if z is present + """ + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert gamma.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + has_lookback = beta is not None and CB_shifted is not None and x_shifted is not None + if has_lookback: + assert beta.shape == (batch, nheads, nchunks, chunk_size) + assert CB_shifted.shape == CB.shape + assert x_shifted.shape == x.shape + + # Allocate output + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, + nheads, + ) + + z_strides = (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0) + + with torch.cuda.device(x.device.index): + _mamba3_chunk_scan_fwd_kernel[grid]( + # Core data pointers + CB, x, z, out, out_x, + dA_cumsum, gamma, seq_idx, + C, prev_states, D, + # Lookback pointers (None if not used) + beta, CB_shifted, x_shifted, + # Dimensions + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + # CB strides + CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4), + # x strides + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + # z strides + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + # out strides + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + # dA_cumsum strides (note: tensor is (b, nheads, nchunks, chunk_size), kernel expects batch, chunk, head, csize) + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + # gamma strides + gamma.stride(0), gamma.stride(2), gamma.stride(1), gamma.stride(3), + # seq_idx strides + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + # C strides + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + # prev_states strides + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + # D stride + D.stride(0) if D is not None else 0, + # beta strides + *((beta.stride(0), beta.stride(2), beta.stride(1), beta.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # CB_shifted strides + *((CB_shifted.stride(0), CB_shifted.stride(1), CB_shifted.stride(2), CB_shifted.stride(3), CB_shifted.stride(4)) + if has_lookback else (0, 0, 0, 0, 0)), + # x_shifted strides + *((x_shifted.stride(0), x_shifted.stride(1), x_shifted.stride(2), x_shifted.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # Constexpr meta-parameters + True, # IS_CAUSAL + D is not None, # HAS_D + D.dim() == 2 if D is not None else True, # D_HAS_HDIM + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + HAS_LOOKBACK=has_lookback, + IS_TRITON_22=TRITON_22, + ) + + return out, out_x diff --git a/mamba_ssm/ops/triton/mamba3_chunk_scan_bwd.py b/mamba_ssm/ops/triton/mamba3_chunk_scan_bwd.py new file mode 100644 index 000000000..374272725 --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_chunk_scan_bwd.py @@ -0,0 +1,1038 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 chunk scan Triton backward kernels. +# +# Extends Mamba-2's backward kernels from ssd_chunk_scan.py and ssd_combined.py +# to support the trapezoidal discretization used in Mamba-3: +# +# Y_diag[m] = sum_k L[m,k] * gamma[k] * CB[m,k] * x[k] (current term) +# + sum_k L[m,k] * beta[k] * CB_s[m,k] * x_s[k] (lookback term) +# +# Where L[m,k] = exp(dA_cs[m] - dA_cs[k]) is the causal decay matrix. +# gamma replaces dt in Mamba-2's intra-chunk, beta scales the lookback term. +# +# Three kernels: +# 1. _mamba3_chunk_scan_chunk_state_bwd_dx_kernel -- dx, dgamma, dbeta, dD from both paths +# 2. _mamba3_chunk_scan_bwd_dcb_kernel -- dCB and dCB_shifted +# 3. _mamba3_chunk_scan_bwd_ddAcs_stable_kernel -- ddA_cumsum (stable) + +import math +from packaging import version + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + finalize_tile_workspace, + use_deterministic_mode, + autotune_configs, +) + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +# ============================================================================= +# Kernel 1: Combined backward dx from intra-chunk (CB path) + inter-chunk (states path) +# ============================================================================= + +@triton.autotune( + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=2, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr", "dgamma_ptr", "dbeta_ptr"])), + ]), + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _mamba3_chunk_scan_chunk_state_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, + b_ptr, dstates_ptr, + # Mamba-3 specific pointers + gamma_ptr, beta_ptr, + cb_shifted_ptr, x_shifted_ptr, b_shifted_ptr, + # Output pointers + dx_ptr, ddt_ptr, dD_ptr, + dgamma_ptr, dbeta_ptr, dx_shifted_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + # cb strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + # dout strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + # dt strides (chunked layout: batch, chunk, head, csize) + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + # dA_cumsum strides + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # seq_idx strides + stride_seq_idx_batch, stride_seq_idx_seqlen, + # D stride + stride_D_head, + # B strides + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + # dstates strides + stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, + # gamma strides (chunked layout) + stride_gamma_batch, stride_gamma_chunk, stride_gamma_head, stride_gamma_csize, + # beta strides + stride_beta_batch, stride_beta_chunk, stride_beta_head, stride_beta_csize, + # cb_shifted strides + stride_cbs_batch, stride_cbs_chunk, stride_cbs_head, stride_cbs_csize_m, stride_cbs_csize_k, + # x_shifted strides + stride_xs_batch, stride_xs_seqlen, stride_xs_head, stride_xs_hdim, + # b_shifted strides + stride_bs_batch, stride_bs_seqlen, stride_bs_head, stride_bs_dstate, + # dx strides + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + # ddt strides + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, + # dD strides + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # dgamma strides + stride_dgamma_batch, stride_dgamma_chunk, stride_dgamma_head, stride_dgamma_csize, stride_dgamma_tile, + # dbeta strides + stride_dbeta_batch, stride_dbeta_chunk, stride_dbeta_head, stride_dbeta_csize, stride_dbeta_tile, + # dx_shifted strides + stride_dxs_batch, stride_dxs_seqlen, stride_dxs_head, stride_dxs_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + HAS_LOOKBACK: tl.constexpr, + HAS_GAMMA: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + # Advance base pointers + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + if HAS_GAMMA: + gamma_ptr += pid_b * stride_gamma_batch + pid_c * stride_gamma_chunk + pid_h * stride_gamma_head + dgamma_ptr += pid_b * stride_dgamma_batch + pid_c * stride_dgamma_chunk + pid_h * stride_dgamma_head + pid_n * stride_dgamma_tile + if HAS_LOOKBACK: + beta_ptr += pid_b * stride_beta_batch + pid_c * stride_beta_chunk + pid_h * stride_beta_head + dbeta_ptr += pid_b * stride_dbeta_batch + pid_c * stride_dbeta_chunk + pid_h * stride_dbeta_head + pid_n * stride_dbeta_tile + cb_shifted_ptr += pid_b * stride_cbs_batch + pid_c * stride_cbs_chunk + (pid_h // nheads_ngroups_ratio) * stride_cbs_head + x_shifted_ptr += pid_b * stride_xs_batch + pid_c * chunk_size * stride_xs_seqlen + pid_h * stride_xs_head + b_shifted_ptr += pid_b * stride_bs_batch + pid_c * chunk_size * stride_bs_seqlen + (pid_h // nheads_ngroups_ratio) * stride_bs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + # ======================================================================== + # Phase 1: Inter-chunk contribution (from dstates) + # dx_curr[m] += B[m]^T @ dstates * exp(dA_last - dA_m) * gamma[m] (current) + # dx_shift[m] += B_shifted[m]^T @ dstates * exp(dA_last - dA_m) * beta[m] (lookback) + # ======================================================================== + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + + if not HAS_SEQ_IDX: + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) + + # Compute B[m] @ dstates for current inter-chunk term + offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) + if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) + dstates_val = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates_val = dstates_val.to(b_ptr.dtype.element_ty) + acc = tl.dot(b, dstates_val) * scale[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0) + dstates_val = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates_val = dstates_val.to(b_ptr.dtype.element_ty) + acc += tl.dot(b, dstates_val) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate + acc *= scale[:, None] + + # acc now holds B[m] @ dstates * scale for inter-chunk. + # For lookback inter-chunk: B_shifted[m] @ dstates * scale * beta[m] + if HAS_LOOKBACK: + acc_shift = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + offs_dstate_lb = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + bs_ptrs = b_shifted_ptr + (offs_m[:, None] * stride_bs_seqlen + offs_dstate_lb[None, :] * stride_bs_dstate) + dstates_ptrs2 = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate_lb[:, None] * stride_dstates_dstate) + if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: + bs = tl.load(bs_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate_lb[None, :] < dstate), other=0.0) + dstates_val2 = tl.load(dstates_ptrs2, mask=(offs_dstate_lb[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates_val2 = dstates_val2.to(bs_ptrs.dtype.element_ty) + acc_shift = tl.dot(bs, dstates_val2) * scale[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + bs = tl.load(bs_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate_lb[None, :] < dstate - k), other=0.0) + dstates_val2 = tl.load(dstates_ptrs2, mask=(offs_dstate_lb[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates_val2 = dstates_val2.to(bs_ptrs.dtype.element_ty) + acc_shift += tl.dot(bs, dstates_val2) + bs_ptrs += BLOCK_SIZE_K * stride_bs_dstate + dstates_ptrs2 += BLOCK_SIZE_K * stride_dstates_dstate + acc_shift *= scale[:, None] + + # ======================================================================== + # Phase 2: Intra-chunk contribution (from CB path, transposed causal) + # dx_curr[m] += sum_k CB^T[k,m] * L^T[k,m] * gamma[m] * dout[k] (k >= m) + # dx_shift[m] += sum_k CB_shifted^T[k,m] * L^T[k,m] * beta[m] * dout[k] + # + # In the backward, we iterate over k >= m (upper triangle of CB, transposed). + # The CB matrix is stored as CB[row, col] with row=k, col=m in the transposed view. + # In the stored layout, cb_ptr[m, k] is accessed as cb_ptr + m * stride_csize_m + k * stride_csize_k. + # For the backward, we need CB^T[k, m] which is cb[m, k] since CB is stored with + # stride_cb_csize_m for row and stride_cb_csize_k for col. + # ======================================================================== + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + K_MAX = chunk_size_limit + K_MIN = pid_m * BLOCK_SIZE_M + cb_ptrs += K_MIN * stride_cb_csize_k + dout_ptrs += K_MIN * stride_dout_seqlen + dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize + + if HAS_LOOKBACK: + cbs_ptrs = cb_shifted_ptr + (offs_m[:, None] * stride_cbs_csize_m + offs_k[None, :] * stride_cbs_csize_k) + cbs_ptrs += K_MIN * stride_cbs_csize_k + + for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): + k = tl.multiple_of(k, BLOCK_SIZE_K) + # Load CB values and compute transposed causal decay + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + # L^T[k,m] = exp(dA_cs_k - dA_cs_m), but transposed: for backward k >= m + cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0)) + mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + + # Lookback intra-chunk contribution + if HAS_LOOKBACK: + cbs = tl.load(cbs_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + cbs *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0)) + cbs = tl.where(mask, cbs, 0.0) + cbs = cbs.to(dout_ptr.dtype.element_ty) + acc_shift += tl.dot(cbs, dout) + cbs_ptrs += BLOCK_SIZE_K * stride_cbs_csize_k + + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # ======================================================================== + # Phase 3: Scale by gamma/beta, compute outputs + # ======================================================================== + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_GAMMA: + gamma_ptrs = gamma_ptr + offs_m * stride_gamma_csize + gamma_m = tl.load(gamma_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * gamma_m[:, None] + else: + # Fallback to dt (Mamba-2 compatible) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + + # D skip connection gradient + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + # dD computation + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + if D_HAS_HDIM: + dD_ptrs_local = dD_ptr + offs_n * stride_dD_hdim + dD = tl.sum(dout_res * x, axis=0) + tl.store(dD_ptrs_local, dD, mask=offs_n < hdim) + else: + dD = tl.sum(dout_res * x) + if DETERMINISTIC_REDUCTION: + tl.store(dD_ptr + pid_n * stride_dD_hdim, dD) + else: + tl.atomic_add(dD_ptr, dD) + + # When HAS_GAMMA: scaling factor is gamma (not dt), so ddt=0 from this kernel. + # The dt gradient only comes from the ddA_cumsum path. + # When !HAS_GAMMA: dt IS the scaling factor (Mamba-2 fallback), so ddt = sum(acc*x). + if HAS_GAMMA: + # dgamma = sum(acc * x) — gradient w.r.t. the gamma scaling factor + dgamma = tl.sum(acc * x, axis=1) + dgamma_ptrs = dgamma_ptr + offs_m * stride_dgamma_csize + if DETERMINISTIC_REDUCTION: + tl.store(dgamma_ptrs, dgamma, mask=offs_m < chunk_size) + else: + tl.atomic_add(dgamma_ptrs, dgamma, mask=offs_m < chunk_size) + # ddt stays zero (from zero_init) — dt gradient only comes from ddA path + else: + # Mamba-2 fallback: dt IS the scaling factor, so ddt = sum(acc * x) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + if DETERMINISTIC_REDUCTION: + tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size) + else: + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + # Lookback: store dx_shifted and compute dbeta + if HAS_LOOKBACK: + beta_ptrs = beta_ptr + offs_m * stride_beta_csize + beta_m = tl.load(beta_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx_shift = acc_shift * beta_m[:, None] + dx_shifted_ptr += pid_b * stride_dxs_batch + pid_c * chunk_size * stride_dxs_seqlen + pid_h * stride_dxs_head + dxs_ptrs = dx_shifted_ptr + (offs_m[:, None] * stride_dxs_seqlen + offs_n[None, :] * stride_dxs_hdim) + tl.store(dxs_ptrs, dx_shift, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + # dbeta[m] = sum_n acc_shift_before_beta[m,n] * x_shifted[m,n] + xs_ptrs = x_shifted_ptr + (offs_m[:, None] * stride_xs_seqlen + offs_n[None, :] * stride_xs_hdim) + xs = tl.load(xs_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + dbeta = tl.sum(acc_shift * xs, axis=1) + dbeta_ptrs = dbeta_ptr + offs_m * stride_dbeta_csize + if DETERMINISTIC_REDUCTION: + tl.store(dbeta_ptrs, dbeta, mask=offs_m < chunk_size) + else: + tl.atomic_add(dbeta_ptrs, dbeta, mask=offs_m < chunk_size) + + +_MAMBA3_CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _mamba3_chunk_scan_chunk_state_bwd_dx_kernel.configs +) + + +def _mamba3_chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, + D=None, seq_idx=None, + gamma=None, beta=None, + CB_shifted=None, x_shifted=None, B_shifted=None): + """ + Combined backward for dx from intra-chunk (CB) and inter-chunk (states). + + Arguments: + x: (batch, seqlen, nheads, headdim) -- input + dt: (batch, nheads, nchunks, chunk_size) -- timestep + dA_cumsum: (batch, nheads, nchunks, chunk_size) -- cumulative dA + B: (batch, seqlen, ngroups, dstate) -- input projection + CB: (batch, nchunks, ngroups, chunk_size, chunk_size) -- C^T B product + dout: (batch, seqlen, nheads, headdim) -- output gradient + dstates: (batch, nchunks, nheads, headdim, dstate) -- state gradient + D: (nheads,) or (nheads, headdim) or None -- skip connection + seq_idx: (batch, seqlen) or None + gamma: (batch, nheads, nchunks, chunk_size) or None -- current term weight + beta: (batch, nheads, nchunks, chunk_size) or None -- lookback weight + CB_shifted: same shape as CB or None + x_shifted: same shape as x or None + B_shifted: same shape as B or None + + Returns: dx, ddt, dD, dgamma, dbeta, dx_shifted + """ + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + has_gamma = gamma is not None + has_lookback = (beta is not None and CB_shifted is not None + and x_shifted is not None and B_shifted is not None) + if has_gamma: + assert gamma.shape == (batch, nheads, nchunks, chunk_size) + if has_lookback: + assert beta.shape == (batch, nheads, nchunks, chunk_size) + assert CB_shifted.shape == CB.shape + assert x_shifted.shape == x.shape + assert B_shifted.shape == B.shape + + deterministic = use_deterministic_mode() + + # Allocate outputs + dx = torch.empty_like(x) + + tile_count = math.ceil(headdim / _MAMBA3_CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + + ddt, stride_ddt_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) + + # dD allocation + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + BLOCK_SIZE_min = 32 + pid_m_tiles = triton.cdiv(chunk_size, BLOCK_SIZE_min) + pid_n_tiles = math.ceil(headdim / _MAMBA3_CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N) + if D.dim() == 2: + dD_hdim = headdim + elif deterministic: + dD_hdim = pid_n_tiles + else: + dD_hdim = 1 + dD = torch.zeros(pid_m_tiles, batch, nchunks, nheads, dD_hdim, device=D.device, dtype=torch.float32) + dD_strides = (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + else: + dD = None + dD_strides = (0, 0, 0, 0, 0) + + # dgamma, dbeta allocation + if has_gamma: + dgamma, stride_dgamma_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) + else: + dgamma = None + stride_dgamma_tile = 0 + + if has_lookback: + dbeta, stride_dbeta_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + dout.device, + deterministic, + zero_init=True, + ) + dx_shifted = torch.empty_like(x) + else: + dbeta = None + stride_dbeta_tile = 0 + dx_shifted = None + + # Strides for optional tensors + gamma_strides = (gamma.stride(0), gamma.stride(2), gamma.stride(1), gamma.stride(3)) if has_gamma else (0, 0, 0, 0) + beta_strides = (beta.stride(0), beta.stride(2), beta.stride(1), beta.stride(3)) if has_lookback else (0, 0, 0, 0) + cbs_strides = (CB_shifted.stride(0), CB_shifted.stride(1), CB_shifted.stride(2), CB_shifted.stride(-1), CB_shifted.stride(-2)) if has_lookback else (0, 0, 0, 0, 0) + xs_strides = (x_shifted.stride(0), x_shifted.stride(1), x_shifted.stride(2), x_shifted.stride(3)) if has_lookback else (0, 0, 0, 0) + bs_strides = (B_shifted.stride(0), B_shifted.stride(1), B_shifted.stride(2), B_shifted.stride(3)) if has_lookback else (0, 0, 0, 0) + dxs_strides = (dx_shifted.stride(0), dx_shifted.stride(1), dx_shifted.stride(2), dx_shifted.stride(3)) if has_lookback else (0, 0, 0, 0) + + # dgamma/dbeta have shape (batch, nheads, nchunks, chunk_size) from alloc_tile_workspace + # Stride order: batch, chunk, head, csize (matching ddt convention) + dgamma_strides = (dgamma.stride(0), dgamma.stride(2), dgamma.stride(1), dgamma.stride(3)) if has_gamma else (0, 0, 0, 0) + dbeta_strides = (dbeta.stride(0), dbeta.stride(2), dbeta.stride(1), dbeta.stride(3)) if has_lookback else (0, 0, 0, 0) + + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _mamba3_chunk_scan_chunk_state_bwd_dx_kernel[grid_dx]( + x, CB, dout, dt, dA_cumsum, seq_idx, D, + B, dstates, + gamma, beta, + CB_shifted, x_shifted, B_shifted, + dx, ddt, dD, + dgamma, dbeta, dx_shifted, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + # x strides + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + # CB strides + CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2), + # dout strides + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + # dt strides + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + # dA_cumsum strides + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + # seq_idx strides + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + # D stride + D.stride(0) if D is not None else 0, + # B strides + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + # dstates strides + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + # gamma strides + *gamma_strides, + # beta strides + *beta_strides, + # CB_shifted strides + *cbs_strides, + # x_shifted strides + *xs_strides, + # B_shifted strides + *bs_strides, + # dx strides + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + # ddt strides + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile, + # dD strides + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + # dgamma strides + *dgamma_strides, stride_dgamma_tile, + # dbeta strides + *dbeta_strides, stride_dbeta_tile, + # dx_shifted strides + *dxs_strides, + # constexpr + D is not None, + D.dim() == 2 if D is not None else True, + HAS_SEQ_IDX=seq_idx is not None, + HAS_LOOKBACK=has_lookback, + HAS_GAMMA=has_gamma, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + IS_TRITON_22=TRITON_22, + DETERMINISTIC_REDUCTION=deterministic, + ) + + # Finalize reductions + ddt = finalize_tile_workspace(ddt, deterministic) + if D is not None: + BLOCK_SIZE_actual = _mamba3_chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)) + if D.dim() == 1: + dD = dD.sum(dim=-1) + dD = dD.to(dtype=D.dtype) + if has_gamma: + dgamma = finalize_tile_workspace(dgamma, deterministic) + if has_lookback: + dbeta = finalize_tile_workspace(dbeta, deterministic) + + return dx, ddt, dD, dgamma, dbeta, dx_shifted + + +# ============================================================================= +# Kernel 2: Backward dCB (and dCB_shifted) +# ============================================================================= + +@triton.autotune( + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + ]), + key=['chunk_size', 'hdim'], +) +@triton.jit +def _mamba3_chunk_scan_bwd_dcb_kernel( + # Pointers to matrices + x_ptr, dout_ptr, dA_cumsum_ptr, seq_idx_ptr, + gamma_ptr, beta_ptr, x_shifted_ptr, + # Output pointers + dcb_ptr, dcb_shifted_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + # dout strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + # dA_cumsum strides + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # seq_idx strides + stride_seq_idx_batch, stride_seq_idx_seqlen, + # gamma strides + stride_gamma_batch, stride_gamma_chunk, stride_gamma_head, stride_gamma_csize, + # beta strides + stride_beta_batch, stride_beta_chunk, stride_beta_head, stride_beta_csize, + # x_shifted strides + stride_xs_batch, stride_xs_seqlen, stride_xs_head, stride_xs_hdim, + # dcb strides + stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n, + # dcb_shifted strides + stride_dcbs_batch, stride_dcbs_chunk, stride_dcbs_split, stride_dcbs_group, stride_dcbs_csize_m, stride_dcbs_csize_n, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + HAS_LOOKBACK: tl.constexpr, + HAS_GAMMA: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_GAMMA: + gamma_ptr += pid_b * stride_gamma_batch + pid_c * stride_gamma_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_gamma_head + if HAS_LOOKBACK: + beta_ptr += pid_b * stride_beta_batch + pid_c * stride_beta_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_beta_head + x_shifted_ptr += pid_b * stride_xs_batch + pid_c * chunk_size * stride_xs_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_xs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + if HAS_GAMMA: + gamma_ptrs = gamma_ptr + offs_n * stride_gamma_csize + if HAS_LOOKBACK: + xs_ptrs = x_shifted_ptr + (offs_n[None, :] * stride_xs_seqlen + offs_k[:, None] * stride_xs_hdim) + beta_ptrs = beta_ptr + offs_n * stride_beta_csize + + # Early exit for blocks entirely above the causal diagonal + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs_out = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs_out, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + if HAS_LOOKBACK: + dcbs_ptr = dcb_shifted_ptr + pid_b * stride_dcbs_batch + pid_c * stride_dcbs_chunk + pid_g * stride_dcbs_group + pid_s * stride_dcbs_split + dcbs_ptrs_out = dcbs_ptr + (offs_m[:, None] * stride_dcbs_csize_m + offs_n[None, :] * stride_dcbs_csize_n) + tl.store(dcbs_ptrs_out, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcbs_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + return + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_LOOKBACK: + acc_shift = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + dcb = tl.dot(dout, x) + + # Scale by gamma[n] (replaces dt[n] in Mamba-2) + if HAS_GAMMA: + gamma_n = tl.load(gamma_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + dcb *= gamma_n + else: + # Mamba-2 fallback: would use dt, but in Mamba-3 context this shouldn't happen + pass + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) + dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) + acc += dcb + + # Lookback term: dCB_shifted[m,n] = sum_k dout[m,k] * x_shifted[n,k] * beta[n] * L[m,n] + if HAS_LOOKBACK: + xs = tl.load(xs_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + dcbs = tl.dot(dout, xs) + beta_n = tl.load(beta_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + dcbs *= beta_n + dcbs *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) + acc_shift += dcbs + xs_ptrs += stride_xs_head + beta_ptrs += stride_beta_head + + dout_ptrs += stride_dout_head + x_ptrs += stride_x_head + dA_cumsum_ptr += stride_dA_cs_head + if HAS_GAMMA: + gamma_ptrs += stride_gamma_head + + # Apply causal mask and seq_idx mask + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + if HAS_LOOKBACK: + acc_shift = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc_shift, 0.0) + mask = offs_m[:, None] >= offs_n[None, :] + acc = tl.where(mask, acc, 0.0) + + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs_out = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs_out, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + if HAS_LOOKBACK: + acc_shift = tl.where(mask, acc_shift, 0.0) + dcbs_ptr = dcb_shifted_ptr + pid_b * stride_dcbs_batch + pid_c * stride_dcbs_chunk + pid_g * stride_dcbs_group + pid_s * stride_dcbs_split + dcbs_ptrs_out = dcbs_ptr + (offs_m[:, None] * stride_dcbs_csize_m + offs_n[None, :] * stride_dcbs_csize_n) + tl.store(dcbs_ptrs_out, acc_shift, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +def _mamba3_chunk_scan_bwd_dcb(x, dA_cumsum, dout, seq_idx=None, + gamma=None, beta=None, x_shifted=None, ngroups=1): + """ + Backward for dCB (and dCB_shifted if lookback is enabled). + + Arguments: + x: (batch, seqlen, nheads, headdim) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + dout: (batch, seqlen, nheads, headdim) + seq_idx: (batch, seqlen) or None + gamma: (batch, nheads, nchunks, chunk_size) or None + beta: (batch, nheads, nchunks, chunk_size) or None + x_shifted: (batch, seqlen, nheads, headdim) or None + ngroups: int + + Returns: dCB, dCB_shifted (or dCB_shifted=None if no lookback) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + has_gamma = gamma is not None + has_lookback = beta is not None and x_shifted is not None + + if has_gamma: + assert gamma.shape == (batch, nheads, nchunks, chunk_size) + if has_lookback: + assert beta.shape == (batch, nheads, nchunks, chunk_size) + assert x_shifted.shape == x.shape + + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + + dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32) + if has_lookback: + dcb_shifted = torch.empty_like(dcb) + else: + dcb_shifted = None + + # Strides for optional tensors + gamma_strides = (gamma.stride(0), gamma.stride(2), gamma.stride(1), gamma.stride(3)) if has_gamma else (0, 0, 0, 0) + beta_strides = (beta.stride(0), beta.stride(2), beta.stride(1), beta.stride(3)) if has_lookback else (0, 0, 0, 0) + xs_strides = (x_shifted.stride(0), x_shifted.stride(1), x_shifted.stride(2), x_shifted.stride(3)) if has_lookback else (0, 0, 0, 0) + dcbs_strides = (dcb_shifted.stride(0), dcb_shifted.stride(1), dcb_shifted.stride(2), dcb_shifted.stride(3), dcb_shifted.stride(4), dcb_shifted.stride(5)) if has_lookback else (0, 0, 0, 0, 0, 0) + + grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _mamba3_chunk_scan_bwd_dcb_kernel[grid_dcb]( + x, dout, dA_cumsum, seq_idx, + gamma, beta, x_shifted, + dcb, dcb_shifted, + chunk_size, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # x strides + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + # dout strides + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + # dA_cumsum strides + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + # seq_idx strides + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + # gamma strides + *gamma_strides, + # beta strides + *beta_strides, + # x_shifted strides + *xs_strides, + # dcb strides + dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5), + # dcb_shifted strides + *dcbs_strides, + # constexpr + HAS_SEQ_IDX=seq_idx is not None, + HAS_LOOKBACK=has_lookback, + HAS_GAMMA=has_gamma, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + + dcb = dcb.sum(2) + if has_lookback: + dcb_shifted = dcb_shifted.sum(2) + + return dcb, dcb_shifted + + +# ============================================================================= +# Kernel 3: Backward ddA_cumsum (stable version) +# ============================================================================= + +@triton.autotune( + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + ]), + key=['chunk_size', 'hdim'], +) +@triton.jit +def _mamba3_chunk_scan_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, dout_ptr, dA_cumsum_ptr, cb_ptr, + gamma_ptr, beta_ptr, x_shifted_ptr, cb_shifted_ptr, + # Output pointer + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + # dout strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + # dA_cumsum strides + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # cb strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + # gamma strides + stride_gamma_batch, stride_gamma_chunk, stride_gamma_head, stride_gamma_csize, + # beta strides + stride_beta_batch, stride_beta_chunk, stride_beta_head, stride_beta_csize, + # x_shifted strides + stride_xs_batch, stride_xs_seqlen, stride_xs_head, stride_xs_hdim, + # cb_shifted strides + stride_cbs_batch, stride_cbs_chunk, stride_cbs_head, stride_cbs_csize_m, stride_cbs_csize_n, + # ddA_cumsum strides + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + HAS_LOOKBACK: tl.constexpr, + HAS_GAMMA: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + if HAS_GAMMA: + gamma_ptr += pid_b * stride_gamma_batch + pid_c * stride_gamma_chunk + pid_h * stride_gamma_head + if HAS_LOOKBACK: + beta_ptr += pid_b * stride_beta_batch + pid_c * stride_beta_chunk + pid_h * stride_beta_head + x_shifted_ptr += pid_b * stride_xs_batch + pid_c * chunk_size * stride_xs_seqlen + pid_h * stride_xs_head + cb_shifted_ptr += pid_b * stride_cbs_batch + pid_c * stride_cbs_chunk + (pid_h // nheads_ngroups_ratio) * stride_cbs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + if HAS_GAMMA: + gamma_ptrs = gamma_ptr + offs_n * stride_gamma_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + tl.store(ddA_cumsum_ptr, 0.0) + + if HAS_LOOKBACK: + xs_ptrs = x_shifted_ptr + (offs_n[None, :] * stride_xs_seqlen + offs_k[:, None] * stride_xs_hdim) + beta_ptrs = beta_ptr + offs_n * stride_beta_csize + cbs_ptrs = cb_shifted_ptr + (offs_m[:, None] * stride_cbs_csize_m + offs_n[None, :] * stride_cbs_csize_n) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M + + for start_n in range(lo, hi, BLOCK_SIZE_N): + start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) + + # Current term: dout[m] @ x[n]^T * gamma[n] * CB[m,n] * L[m,n] + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) + acc = tl.dot(dout, x) + + if HAS_GAMMA: + gamma_n = tl.load(gamma_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= gamma_n + # If there's seq_idx, CB was already zeroed for cross-doc pairs + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) + acc *= cb + dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) + + # Lookback term: dout[m] @ x_shifted[n]^T * beta[n] * CB_shifted[m,n] * L[m,n] + if HAS_LOOKBACK: + xs = tl.load(xs_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) + acc_lb = tl.dot(dout, xs) + beta_n = tl.load(beta_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc_lb *= beta_n + cbs = tl.load(cbs_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) + acc_lb *= cbs + acc_lb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0)) + acc += acc_lb + + # Apply causal mask and cumsum (same structure as Mamba-2) + mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + rowsum_new = rowsum + tl.sum(acc, axis=1) + acc = rowsum[:, None] + tl.cumsum(acc, axis=1) + rowsum = rowsum_new + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) + + # Advance pointers + x_ptrs += BLOCK_SIZE_N * stride_x_seqlen + cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + if HAS_GAMMA: + gamma_ptrs += BLOCK_SIZE_N * stride_gamma_csize + if HAS_LOOKBACK: + xs_ptrs += BLOCK_SIZE_N * stride_xs_seqlen + beta_ptrs += BLOCK_SIZE_N * stride_beta_csize + cbs_ptrs += BLOCK_SIZE_N * stride_cbs_csize_n + + # Zero out the rest (since we sum rows together later) + for start_n in range(hi, chunk_size, BLOCK_SIZE_N): + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1) + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + +def _mamba3_chunk_scan_bwd_ddAcs_stable(x, dA_cumsum, dout, CB, seq_idx=None, + gamma=None, beta=None, + x_shifted=None, CB_shifted=None, ngroups=1): + """ + Backward for ddA_cumsum (numerically stable version). + + Arguments: + x: (batch, seqlen, nheads, headdim) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + dout: (batch, seqlen, nheads, headdim) + CB: (batch, nchunks, ngroups, chunk_size, chunk_size) + seq_idx: (batch, seqlen) or None + gamma: (batch, nheads, nchunks, chunk_size) or None + beta: (batch, nheads, nchunks, chunk_size) or None + x_shifted: (batch, seqlen, nheads, headdim) or None + CB_shifted: (batch, nchunks, ngroups, chunk_size, chunk_size) or None + ngroups: int + + Returns: ddA_cumsum + """ + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + assert dout.shape == x.shape + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert nheads % ngroups == 0 + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + has_gamma = gamma is not None + has_lookback = (beta is not None and x_shifted is not None and CB_shifted is not None) + + if has_gamma: + assert gamma.shape == (batch, nheads, nchunks, chunk_size) + if has_lookback: + assert beta.shape == (batch, nheads, nchunks, chunk_size) + assert x_shifted.shape == x.shape + assert CB_shifted.shape == CB.shape + + # Strides for optional tensors + gamma_strides = (gamma.stride(0), gamma.stride(2), gamma.stride(1), gamma.stride(3)) if has_gamma else (0, 0, 0, 0) + beta_strides = (beta.stride(0), beta.stride(2), beta.stride(1), beta.stride(3)) if has_lookback else (0, 0, 0, 0) + xs_strides = (x_shifted.stride(0), x_shifted.stride(1), x_shifted.stride(2), x_shifted.stride(3)) if has_lookback else (0, 0, 0, 0) + cbs_strides = (CB_shifted.stride(0), CB_shifted.stride(1), CB_shifted.stride(2), CB_shifted.stride(3), CB_shifted.stride(4)) if has_lookback else (0, 0, 0, 0, 0) + + BLOCK_SIZE_M_min = 32 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _mamba3_chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( + x, dout, dA_cumsum, CB, + gamma, beta, x_shifted, CB_shifted, + ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + # x strides + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + # dout strides + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + # dA_cumsum strides + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + # cb strides + CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4), + # gamma strides + *gamma_strides, + # beta strides + *beta_strides, + # x_shifted strides + *xs_strides, + # cb_shifted strides + *cbs_strides, + # ddA_cumsum strides + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + # constexpr + HAS_LOOKBACK=has_lookback, + HAS_GAMMA=has_gamma, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + BLOCK_SIZE_M_actual = _mamba3_chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum diff --git a/mamba_ssm/ops/triton/mamba3_chunk_state.py b/mamba_ssm/ops/triton/mamba3_chunk_state.py new file mode 100644 index 000000000..39eae27a2 --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_chunk_state.py @@ -0,0 +1,293 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 chunk state Triton kernel (forward pass). +# +# Extends Mamba-2's _chunk_state_fwd_kernel to support the trapezoidal +# discretization used in Mamba-3: +# +# states[c] = sum_t exp(dA_last - dA_t) * gamma_t * B_t outer x_t (current term) +# + sum_t exp(dA_last - dA_t) * beta_t * B_shifted_t outer x_shifted_t (lookback term) +# +# The lookback term is optional (controlled by HAS_LOOKBACK constexpr). +# When gamma is None, falls back to dt scaling (Mamba-2 compatibility mode). + +import torch +import triton +import triton.language as tl + +from mamba_ssm.utils.determinism import autotune_configs + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=2, num_warps=2), + ]), + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _mamba3_chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, b_ptr, states_ptr, + dt_ptr, dA_cumsum_ptr, gamma_ptr, seq_idx_ptr, + # Lookback pointers (may be null) + beta_ptr, b_shifted_ptr, x_shifted_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nheads_ngroups_ratio, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + # b strides + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + # states strides + stride_states_batch, stride_states_chunk, stride_states_head, + stride_states_hdim, stride_states_dstate, + # dt strides (chunked layout: b, nheads, nchunks, chunk_size) + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + # dA_cumsum strides (same layout as dt) + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # gamma strides (same layout as dt) + stride_gamma_batch, stride_gamma_chunk, stride_gamma_head, stride_gamma_csize, + # seq_idx strides + stride_seq_idx_batch, stride_seq_idx_seqlen, + # beta strides (same layout as dt) + stride_beta_batch, stride_beta_chunk, stride_beta_head, stride_beta_csize, + # b_shifted strides (same layout as b) + stride_bs_batch, stride_bs_seqlen, stride_bs_head, stride_bs_dstate, + # x_shifted strides (same layout as x) + stride_xs_batch, stride_xs_seqlen, stride_xs_head, stride_xs_hdim, + # Meta-parameters + HAS_GAMMA: tl.constexpr, + HAS_LOOKBACK: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + # Advance base pointers to this batch, chunk, head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_GAMMA: + gamma_ptr += pid_b * stride_gamma_batch + pid_c * stride_gamma_chunk + pid_h * stride_gamma_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + if HAS_LOOKBACK: + beta_ptr += pid_b * stride_beta_batch + pid_c * stride_beta_chunk + pid_h * stride_beta_head + b_shifted_ptr += pid_b * stride_bs_batch + pid_c * chunk_size * stride_bs_seqlen + (pid_h // nheads_ngroups_ratio) * stride_bs_head + x_shifted_ptr += pid_b * stride_xs_batch + pid_c * chunk_size * stride_xs_seqlen + pid_h * stride_xs_head + + # Tile offsets + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Pointers for the inner loop + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_GAMMA: + gamma_ptrs = gamma_ptr + offs_k * stride_gamma_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + if HAS_LOOKBACK: + beta_ptrs = beta_ptr + offs_k * stride_beta_csize + bs_ptrs = b_shifted_ptr + (offs_n[None, :] * stride_bs_dstate + offs_k[:, None] * stride_bs_seqlen) + xs_ptrs = x_shifted_ptr + (offs_m[:, None] * stride_xs_hdim + offs_k[None, :] * stride_xs_seqlen) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + # Load x and B tiles + x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + + # Load dA cumsum and compute decay from position to end of chunk + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + + # Seq_idx masking: zero out positions from different documents + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + + # Compute scale = exp(dA_last - dA_k) * weight_k + # weight_k is gamma_k for Mamba-3, dt_k for Mamba-2 fallback + if HAS_GAMMA: + weight_k = tl.load(gamma_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + else: + weight_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + + if not HAS_SEQ_IDX: + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * weight_k + else: + scale = tl.where( + (seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), + tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * weight_k, + 0.0 + ) + + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + + # Lookback term: beta_k * B_shifted_k outer x_shifted_k + if HAS_LOOKBACK: + xs = tl.load(xs_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0) + bs = tl.load(bs_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + + beta_k = tl.load(beta_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) + + if not HAS_SEQ_IDX: + scale_lb = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * beta_k + else: + scale_lb = tl.where( + (seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), + tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * beta_k, + 0.0 + ) + + bs *= scale_lb[:, None] + bs = bs.to(x_ptr.dtype.element_ty) + acc += tl.dot(xs, bs) + + # Advance lookback pointers + xs_ptrs += BLOCK_SIZE_K * stride_xs_seqlen + bs_ptrs += BLOCK_SIZE_K * stride_bs_seqlen + beta_ptrs += BLOCK_SIZE_K * stride_beta_csize + + # Advance pointers + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_GAMMA: + gamma_ptrs += BLOCK_SIZE_K * stride_gamma_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + + # Store the output tile + states = acc.to(states_ptr.dtype.element_ty) + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _mamba3_chunk_state_fwd(B, x, dt, dA_cumsum, gamma=None, beta=None, + B_shifted=None, x_shifted=None, seq_idx=None, + states_in_fp32=True): + """ + Compute per-chunk states for Mamba-3 SSD. + + If gamma is None, falls back to dt scaling (Mamba-2 mode). + If beta/B_shifted/x_shifted are None, only current term (Euler mode). + + Arguments: + B: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) -- chunked layout + dA_cumsum: (batch, nheads, nchunks, chunk_size) -- cumulative dA within chunks + gamma: (batch, nheads, nchunks, chunk_size) or None -- current term weight + beta: (batch, nheads, nchunks, chunk_size) or None -- lookback term weight + B_shifted: (batch, seqlen, ngroups, dstate) or None -- B shifted by 1 + x_shifted: (batch, seqlen, nheads, headdim) or None -- x shifted by 1 + seq_idx: (batch, seqlen) or None -- document boundary indices + + Returns: + states: (batch, nchunks, nheads, headdim, dstate) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if gamma is not None: + assert gamma.shape == (batch, nheads, nchunks, chunk_size) + if beta is not None: + assert beta.shape == (batch, nheads, nchunks, chunk_size) + assert B_shifted is not None and x_shifted is not None + assert B_shifted.shape == B.shape + assert x_shifted.shape == x.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + has_lookback = beta is not None and B_shifted is not None and x_shifted is not None + has_gamma = gamma is not None + + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty( + (batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype + ) + + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, + nheads, + ) + + with torch.cuda.device(x.device.index): + _mamba3_chunk_state_fwd_kernel[grid]( + # Core data pointers + x, B, states, + dt, dA_cumsum, gamma, seq_idx, + # Lookback pointers (None if not used) + beta, B_shifted, x_shifted, + # Dimensions + headdim, dstate, chunk_size, + batch, seqlen, nheads // ngroups, + # x strides + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + # B strides + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + # states strides + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + # dt strides (note: dt is (b, nheads, nchunks, chunk_size), kernel expects batch, chunk, head, csize) + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + # dA_cumsum strides + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + # gamma strides + *((gamma.stride(0), gamma.stride(2), gamma.stride(1), gamma.stride(3)) + if has_gamma else (0, 0, 0, 0)), + # seq_idx strides + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + # beta strides + *((beta.stride(0), beta.stride(2), beta.stride(1), beta.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # B_shifted strides + *((B_shifted.stride(0), B_shifted.stride(1), B_shifted.stride(2), B_shifted.stride(-1)) + if has_lookback else (0, 0, 0, 0)), + # x_shifted strides + *((x_shifted.stride(0), x_shifted.stride(1), x_shifted.stride(2), x_shifted.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # Constexpr flags + HAS_GAMMA=has_gamma, + HAS_LOOKBACK=has_lookback, + HAS_SEQ_IDX=seq_idx is not None, + ) + + return states diff --git a/mamba_ssm/ops/triton/mamba3_chunk_state_bwd.py b/mamba_ssm/ops/triton/mamba3_chunk_state_bwd.py new file mode 100644 index 000000000..404700809 --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_chunk_state_bwd.py @@ -0,0 +1,711 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 chunk state backward Triton kernels. +# +# Extends Mamba-2's _chunk_state_bwd_db_kernel and _chunk_state_bwd_ddAcs_stable_kernel +# to support the trapezoidal discretization used in Mamba-3: +# +# dB[c,t,g,n] = sum_h sum_p x[c,t,h,p] * dstates[c,h,p,n] * exp(dA_last - dA_t) * gamma[t,h] +# dB_shifted[c,t,g,n] = sum_h sum_p x_shifted[c,t,h,p] * dstates[c,h,p,n] * exp(...) * beta[t,h] +# +# The lookback term is optional (controlled by HAS_LOOKBACK constexpr). +# When gamma is None, falls back to dt scaling (Mamba-2 compatibility mode). + +import math +import torch +import triton +import triton.language as tl + +from mamba_ssm.utils.determinism import ( + alloc_tile_workspace, + finalize_tile_workspace, + use_deterministic_mode, + autotune_configs, +) + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +# ============================================================================= +# Kernel 1: _mamba3_chunk_state_bwd_db_kernel +# Computes dB (and optionally dB_shifted) from dstates. +# ============================================================================= + +@triton.autotune( + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ]), + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _mamba3_chunk_state_bwd_db_kernel( + # Pointers to matrices + x_ptr, dstates_ptr, b_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Weight pointers (gamma or dt for fallback) + gamma_ptr, dt_ptr, + # Lookback pointers + beta_ptr, x_shifted_ptr, + # Output pointers + db_ptr, db_shifted_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + # dstates strides + stride_dstates_batch, stride_dstates_chunk, stride_states_head, + stride_states_hdim, stride_states_dstate, + # b strides (for ddA computation) + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + # dA_cumsum strides + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # seq_idx strides + stride_seq_idx_batch, stride_seq_idx_seqlen, + # gamma strides (same layout as dA_cumsum) + stride_gamma_batch, stride_gamma_chunk, stride_gamma_head, stride_gamma_csize, + # dt strides (same layout as dA_cumsum) + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + # beta strides + stride_beta_batch, stride_beta_chunk, stride_beta_head, stride_beta_csize, + # x_shifted strides + stride_xs_batch, stride_xs_seqlen, stride_xs_head, stride_xs_hdim, + # db strides + stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate, + # db_shifted strides + stride_dbs_batch, stride_dbs_seqlen, stride_dbs_split, stride_dbs_group, stride_dbs_dstate, + # ddA_cumsum strides + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, + # Meta-parameters + HAS_GAMMA: tl.constexpr, + HAS_LOOKBACK: tl.constexpr, + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + """Backward through chunk state computation to compute dB and dB_shifted. + + Grid: (cdiv(chunk_size, BSM) * cdiv(dstate, BSN), batch * nchunks, nsplits * ngroups) + + Each program computes a tile of dB (and dB_shifted) for one (batch, chunk, group, split) + combination, iterating over heads within the split and accumulating over headdim. + + The key change from Mamba-2: instead of scaling by dt, we scale by gamma (current term) + and optionally by beta (lookback term with x_shifted). + """ + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + # Advance pointers to this batch, chunk, group, head-split + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_GAMMA: + gamma_ptr += pid_b * stride_gamma_batch + pid_c * stride_gamma_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_gamma_head + else: + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + if HAS_DDA_CS: + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + if HAS_LOOKBACK: + beta_ptr += pid_b * stride_beta_batch + pid_c * stride_beta_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_beta_head + x_shifted_ptr += pid_b * stride_xs_batch + pid_c * chunk_size * stride_xs_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_xs_head + db_shifted_ptr += pid_b * stride_dbs_batch + pid_c * chunk_size * stride_dbs_seqlen + pid_g * stride_dbs_group + pid_s * stride_dbs_split + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Pointers for the inner loop over headdim (K dimension) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_GAMMA: + gamma_ptrs = gamma_ptr + offs_m * stride_gamma_csize + else: + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + if HAS_DDA_CS: + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + if HAS_LOOKBACK: + beta_ptrs = beta_ptr + offs_m * stride_beta_csize + xs_ptrs = x_shifted_ptr + (offs_m[:, None] * stride_xs_seqlen + offs_k[None, :] * stride_xs_hdim) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_LOOKBACK: + acc_shifted = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + # Load x tile: (BLOCK_SIZE_M, BLOCK_SIZE_K) -- chunk_size x headdim + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + # Load dstates tile: (BLOCK_SIZE_K, BLOCK_SIZE_N) -- headdim x dstate + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + dstates = dstates.to(x_ptrs.dtype.element_ty) + # db_raw = x @ dstates: (BLOCK_SIZE_M, BLOCK_SIZE_N) -- chunk_size x dstate + db = tl.dot(x, dstates) + + # Compute decay scale = exp(dA_last - dA_m) for this head + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + if not HAS_SEQ_IDX: + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) + else: + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) + + # Weight: gamma for Mamba-3, dt for Mamba-2 fallback + if HAS_GAMMA: + weight_m = tl.load(gamma_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + else: + weight_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + db *= (scale * weight_m)[:, None] + + if HAS_DDA_CS: + # Gradient wrt dA_cumsum: sum over dstate of db * b + ddA_cs = tl.sum(db * b, axis=1) + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + else: + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + + acc += db + + # Lookback term: beta * x_shifted @ dstates + if HAS_LOOKBACK: + xs = tl.load(xs_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + db_s = tl.dot(xs, dstates) + beta_m = tl.load(beta_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + db_s *= (scale * beta_m)[:, None] + + if HAS_DDA_CS: + # Lookback contribution to ddA from B_shifted (loaded outside this kernel) + # We compute the ddA contribution in _mamba3_chunk_state_bwd_ddAcs_stable_kernel + # so we skip it here to avoid needing B_shifted pointer + pass + + acc_shifted += db_s + + # Advance lookback pointers to next head + xs_ptrs += stride_xs_head + beta_ptrs += stride_beta_head + + # Advance to next head + x_ptrs += stride_x_head + dstates_ptrs += stride_states_head + dA_cumsum_ptr += stride_dA_cs_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_GAMMA: + gamma_ptrs += stride_gamma_head + else: + dt_ptrs += stride_dt_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + + # Store dB + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate) + tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + # Store dB_shifted + if HAS_LOOKBACK: + dbs_ptrs = db_shifted_ptr + (offs_m[:, None] * stride_dbs_seqlen + offs_n[None, :] * stride_dbs_dstate) + tl.store(dbs_ptrs, acc_shifted, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +_MAMBA3_CHUNK_STATE_BWD_DB_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _mamba3_chunk_state_bwd_db_kernel.configs +) + + +# ============================================================================= +# Kernel 2: _mamba3_chunk_state_bwd_ddAcs_stable_kernel +# Computes ddA_cumsum contribution from the chunk state computation path. +# ============================================================================= + +@triton.autotune( + configs=autotune_configs([ + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ]), + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _mamba3_chunk_state_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, b_ptr, dstates_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Weight pointers + gamma_ptr, dt_ptr, + # Lookback pointers + beta_ptr, b_shifted_ptr, x_shifted_ptr, + # Output + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + # b strides + stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, + # dstates strides + stride_dstates_batch, stride_dstates_chunk, stride_states_head, + stride_states_hdim, stride_states_dstate, + # dA_cumsum strides + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + # seq_idx strides + stride_seq_idx_batch, stride_seq_idx_seqlen, + # gamma strides + stride_gamma_batch, stride_gamma_chunk, stride_gamma_head, stride_gamma_csize, + # dt strides + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + # beta strides + stride_beta_batch, stride_beta_chunk, stride_beta_head, stride_beta_csize, + # b_shifted strides + stride_bs_batch, stride_bs_seqlen, stride_bs_head, stride_bs_dstate, + # x_shifted strides + stride_xs_batch, stride_xs_seqlen, stride_xs_head, stride_xs_hdim, + # ddA_cumsum strides + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile, + # Meta-parameters + HAS_GAMMA: tl.constexpr, + HAS_LOOKBACK: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + DETERMINISTIC_REDUCTION: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + """Backward through chunk state to compute ddA_cumsum. + + Grid: (cdiv(chunk_size, BSM) * cdiv(hdim, BSN), batch * nchunks, nheads) + + For each (batch, chunk, head), computes: + ddA[t] = sum_n sum_p B[t,g,n] * dstates[h,p,n] * x[t,h,p] * exp(dA_last - dA_t) * gamma[t] + + sum_n sum_p B_shifted[t,g,n] * dstates[h,p,n] * x_shifted[t,h,p] * exp(...) * beta[t] + """ + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head + dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile + if HAS_GAMMA: + gamma_ptr += pid_b * stride_gamma_batch + pid_c * stride_gamma_chunk + pid_h * stride_gamma_head + else: + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + if HAS_LOOKBACK: + beta_ptr += pid_b * stride_beta_batch + pid_c * stride_beta_chunk + pid_h * stride_beta_head + b_shifted_ptr += pid_b * stride_bs_batch + pid_c * chunk_size * stride_bs_seqlen + (pid_h // nheads_ngroups_ratio) * stride_bs_head + x_shifted_ptr += pid_b * stride_xs_batch + pid_c * chunk_size * stride_xs_seqlen + pid_h * stride_xs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + # --- Compute B @ dstates^T for this tile: (chunk_size_tile, hdim_tile) --- + # Use a single pass or loop over dstate dimension + offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate) + dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate) + + if BLOCK_SIZE_DSTATE <= 128: + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptrs.dtype.element_ty) + acc = tl.dot(b, dstates) + else: + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0) + dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates = dstates.to(b_ptrs.dtype.element_ty) + acc += tl.dot(b, dstates) + b_ptrs += BLOCK_SIZE_K * stride_b_dstate + dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate + + # --- Apply scale and compute ddA contribution for current term --- + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + + if not HAS_SEQ_IDX: + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) + else: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) + + acc *= scale[:, None] + + # Load x for this tile and compute ddA = sum over hdim of (B@dstates * scale) * x * weight + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + + if HAS_GAMMA: + weight_m = tl.load(gamma_ptr + offs_m * stride_gamma_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + else: + weight_m = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + # ddA_cs = sum_p (acc[m,p] * x[m,p]) * weight[m] + ddt = tl.sum(acc * x, axis=1) + ddA_cs = ddt * weight_m + + # --- Lookback contribution --- + if HAS_LOOKBACK: + # Compute B_shifted @ dstates^T + offs_k_lb = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + bs_ptrs = b_shifted_ptr + (offs_m[:, None] * stride_bs_seqlen + offs_k_lb[None, :] * stride_bs_dstate) + dstates_ptrs_lb = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_lb[:, None] * stride_states_dstate) + + if BLOCK_SIZE_DSTATE <= 128: + bs = tl.load(bs_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_lb[None, :] < dstate), other=0.0) + dstates_lb = tl.load(dstates_ptrs_lb, mask=(offs_k_lb[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + dstates_lb = dstates_lb.to(bs_ptrs.dtype.element_ty) + acc_lb = tl.dot(bs, dstates_lb) + else: + acc_lb = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, dstate, BLOCK_SIZE_K): + bs = tl.load(bs_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_lb[None, :] < dstate - k), other=0.0) + dstates_lb = tl.load(dstates_ptrs_lb, mask=(offs_k_lb[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + dstates_lb = dstates_lb.to(bs_ptrs.dtype.element_ty) + acc_lb += tl.dot(bs, dstates_lb) + bs_ptrs += BLOCK_SIZE_K * stride_bs_dstate + dstates_ptrs_lb += BLOCK_SIZE_K * stride_states_dstate + + acc_lb *= scale[:, None] + + xs_ptrs = x_shifted_ptr + (offs_m[:, None] * stride_xs_seqlen + offs_n[None, :] * stride_xs_hdim) + xs = tl.load(xs_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + beta_m = tl.load(beta_ptr + offs_m * stride_beta_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + ddt_lb = tl.sum(acc_lb * xs, axis=1) + ddA_cs += ddt_lb * beta_m + + # Store ddA_cumsum (shifted by 1 -- position 0 never contributes to state) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + if DETERMINISTIC_REDUCTION: + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + else: + tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1) + + +_MAMBA3_CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N = min( + cfg.kwargs['BLOCK_SIZE_N'] for cfg in _mamba3_chunk_state_bwd_ddAcs_stable_kernel.configs +) + + +# ============================================================================= +# Python wrappers +# ============================================================================= + +def _mamba3_chunk_state_bwd_db(x, dA_cumsum, dstates, seq_idx=None, B=None, + gamma=None, beta=None, x_shifted=None, ngroups=1): + """Compute dB and dB_shifted from dstates (backward through chunk state). + + Args: + x: (batch, seqlen, nheads, headdim) -- input + dA_cumsum: (batch, nheads, nchunks, chunk_size) -- cumulative dA + dstates: (batch, nchunks, nheads, headdim, dstate) -- gradient of states + seq_idx: (batch, seqlen) or None -- document boundaries + B: (batch, seqlen, ngroups, dstate) or None -- if provided, also compute ddA_cumsum + gamma: (batch, nheads, nchunks, chunk_size) or None -- Mamba-3 current weight + beta: (batch, nheads, nchunks, chunk_size) or None -- Mamba-3 lookback weight + x_shifted: (batch, seqlen, nheads, headdim) or None -- shifted input for lookback + ngroups: int + + Returns: + If B is None: dB (batch, seqlen, ngroups, dstate) + If B is not None: (dB, ddA_cumsum) + Additionally returns dB_shifted if lookback is active, as second element of tuple. + Full return: (dB, dB_shifted_or_None) or (dB, dB_shifted_or_None, ddA_cumsum) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + dstate = dstates.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + has_gamma = gamma is not None + has_lookback = beta is not None and x_shifted is not None + + if has_gamma: + assert gamma.shape == (batch, nheads, nchunks, chunk_size) + if has_lookback: + assert beta.shape == (batch, nheads, nchunks, chunk_size) + assert x_shifted.shape == x.shape + + deterministic = use_deterministic_mode() + + # B strides for ddA computation + if B is not None: + assert B.shape == (batch, seqlen, ngroups, dstate) + B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) + tile_count = math.ceil(dstate / _MAMBA3_CHUNK_STATE_BWD_DB_MIN_BLOCK_N) + ddA_cumsum_out, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + x.device, + deterministic, + zero_init=True, + ) + ddA_cumsum_strides = ( + ddA_cumsum_out.stride(0), ddA_cumsum_out.stride(2), + ddA_cumsum_out.stride(1), ddA_cumsum_out.stride(3), + ) + else: + B_strides = (0, 0, 0, 0) + ddA_cumsum_out = None + ddA_cumsum_strides = (0, 0, 0, 0) + stride_ddA_tile = 0 + + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + + # Allocate dB output with split dimension for reduction + dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32) + if has_lookback: + dB_shifted = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32) + else: + dB_shifted = None + + # We need a dummy dt tensor for the fallback path (HAS_GAMMA=False) + # In Mamba-2 mode, gamma is None and we use dA_cumsum as a stand-in for dt strides + # (the actual dt values come from the caller's dt tensor) + # For simplicity, when gamma is None we pass dA_cumsum as dt with matching strides + # This requires the caller to pass dt separately -- but the Mamba-2 bwd_db kernel + # uses dt directly. We'll use dA_cumsum as a placeholder for strides. + + grid_db = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, + nsplits * ngroups, + ) + + with torch.cuda.device(x.device.index): + _mamba3_chunk_state_bwd_db_kernel[grid_db]( + # Core pointers + x, dstates, B, dA_cumsum, seq_idx, + # Weight pointers + gamma, dA_cumsum, # dt_ptr placeholder (unused when HAS_GAMMA=True) + # Lookback pointers + beta, x_shifted, + # Output pointers + dB, dB_shifted, ddA_cumsum_out, + # Dimensions + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # x strides + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + # dstates strides + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + # B strides + *B_strides, + # dA_cumsum strides + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + # seq_idx strides + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + # gamma strides + *((gamma.stride(0), gamma.stride(2), gamma.stride(1), gamma.stride(3)) + if has_gamma else (0, 0, 0, 0)), + # dt strides (placeholder, unused when HAS_GAMMA=True) + *((dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3)) + if not has_gamma else (0, 0, 0, 0)), + # beta strides + *((beta.stride(0), beta.stride(2), beta.stride(1), beta.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # x_shifted strides + *((x_shifted.stride(0), x_shifted.stride(1), x_shifted.stride(2), x_shifted.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # dB strides + dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), + # dB_shifted strides + *((dB_shifted.stride(0), dB_shifted.stride(1), dB_shifted.stride(2), + dB_shifted.stride(3), dB_shifted.stride(4)) + if has_lookback else (0, 0, 0, 0, 0)), + # ddA_cumsum strides + *ddA_cumsum_strides, stride_ddA_tile, + # Constexpr flags + HAS_GAMMA=has_gamma, + HAS_LOOKBACK=has_lookback, + HAS_DDA_CS=ddA_cumsum_out is not None, + HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + + # Reduce over head splits + dB = dB.sum(2) + if has_lookback: + dB_shifted = dB_shifted.sum(2) + + if ddA_cumsum_out is not None: + ddA_cumsum_out = finalize_tile_workspace(ddA_cumsum_out, deterministic) + torch.cumsum(ddA_cumsum_out, dim=-1, out=ddA_cumsum_out) + + if B is None: + return dB, dB_shifted + else: + return dB, dB_shifted, ddA_cumsum_out + + +def _mamba3_chunk_state_bwd_ddAcs_stable(x, dA_cumsum, dstates, B, seq_idx=None, + gamma=None, beta=None, + x_shifted=None, B_shifted=None, ngroups=1): + """Compute ddA_cumsum from the chunk state backward path. + + This computes the gradient of the loss w.r.t. dA_cumsum through the state + computation. It extends Mamba-2's version to handle Mamba-3 trapezoidal terms. + + Args: + x: (batch, seqlen, nheads, headdim) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + dstates: (batch, nchunks, nheads, headdim, dstate) + B: (batch, seqlen, ngroups, dstate) + seq_idx: (batch, seqlen) or None + gamma: (batch, nheads, nchunks, chunk_size) or None -- Mamba-3 current weight + beta: (batch, nheads, nchunks, chunk_size) or None -- Mamba-3 lookback weight + x_shifted: (batch, seqlen, nheads, headdim) or None + B_shifted: (batch, seqlen, ngroups, dstate) or None + ngroups: int + + Returns: + ddA_cumsum: (batch, nheads, nchunks, chunk_size) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + _, _, ngroups_B, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups_B, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + has_gamma = gamma is not None + has_lookback = beta is not None and x_shifted is not None and B_shifted is not None + + if has_gamma: + assert gamma.shape == (batch, nheads, nchunks, chunk_size) + if has_lookback: + assert beta.shape == (batch, nheads, nchunks, chunk_size) + assert x_shifted.shape == x.shape + assert B_shifted.shape == B.shape + + deterministic = use_deterministic_mode() + tile_count = math.ceil(headdim / _MAMBA3_CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N) + ddA_cumsum_out, stride_ddA_tile = alloc_tile_workspace( + (batch, nheads, nchunks, chunk_size), + tile_count, + torch.float32, + x.device, + deterministic, + zero_init=True, + ) + + grid_ddtcs = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, + nheads, + ) + + with torch.cuda.device(x.device.index): + _mamba3_chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs]( + # Core pointers + x, B, dstates, dA_cumsum, seq_idx, + # Weight pointers + gamma, dA_cumsum, # dt placeholder + # Lookback pointers + beta, B_shifted, x_shifted, + # Output + ddA_cumsum_out, + # Dimensions + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + # x strides + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + # B strides + B.stride(0), B.stride(1), B.stride(2), B.stride(-1), + # dstates strides + dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), + # dA_cumsum strides + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + # seq_idx strides + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + # gamma strides + *((gamma.stride(0), gamma.stride(2), gamma.stride(1), gamma.stride(3)) + if has_gamma else (0, 0, 0, 0)), + # dt strides (placeholder) + *((dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3)) + if not has_gamma else (0, 0, 0, 0)), + # beta strides + *((beta.stride(0), beta.stride(2), beta.stride(1), beta.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # B_shifted strides + *((B_shifted.stride(0), B_shifted.stride(1), B_shifted.stride(2), B_shifted.stride(-1)) + if has_lookback else (0, 0, 0, 0)), + # x_shifted strides + *((x_shifted.stride(0), x_shifted.stride(1), x_shifted.stride(2), x_shifted.stride(3)) + if has_lookback else (0, 0, 0, 0)), + # ddA_cumsum strides + ddA_cumsum_out.stride(0), ddA_cumsum_out.stride(2), ddA_cumsum_out.stride(1), ddA_cumsum_out.stride(3), stride_ddA_tile, + # Constexpr flags + HAS_GAMMA=has_gamma, + HAS_LOOKBACK=has_lookback, + HAS_SEQ_IDX=seq_idx is not None, + DETERMINISTIC_REDUCTION=deterministic, + BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + ) + + ddA_cumsum_out = finalize_tile_workspace(ddA_cumsum_out, deterministic) + # Cumsum starting from position 1 (position 0 does not contribute to state) + torch.cumsum(ddA_cumsum_out[..., 1:], dim=-1, out=ddA_cumsum_out[..., 1:]) + return ddA_cumsum_out diff --git a/mamba_ssm/ops/triton/mamba3_combined.py b/mamba_ssm/ops/triton/mamba3_combined.py new file mode 100644 index 000000000..d4901df9c --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_combined.py @@ -0,0 +1,1156 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""Mamba-3 fused chunked SSD with Triton forward and backward. + +Uses Triton kernels for the forward pass (speed improvement over pure PyTorch). +Backward uses a Triton backward pipeline when available (SISO, CUDA), falling +back to PyTorch autograd recomputation otherwise (MIMO, CPU, missing kernels). + +Architecture: + - Forward: Triton kernels for SISO on CUDA; PyTorch fallback for MIMO or CPU. + - Backward: Triton backward pipeline for SISO on CUDA; PyTorch recompute fallback. + +This provides a drop-in replacement for mamba3_chunk_scan_combined. +""" + +import math + +import torch +import torch.nn.functional as F + +from mamba_ssm.utils.torch import custom_bwd, custom_fwd + +from einops import rearrange, repeat + +from mamba_ssm.ops.triton.mamba3_ssd import ( + mamba3_ssd_chunked, + mamba3_chunk_scan_combined as _mamba3_chunk_scan_combined_ref, + apply_rotary_emb_to_bc, +) + + +def _triton_forward_available(): + """Check if the Mamba-3 Triton forward kernels are importable. + + These kernels are provided by separate modules (mamba3_chunk_state, mamba3_chunk_scan) + and may not be available if they have not been compiled or if the GPU does not support them. + """ + try: + from mamba_ssm.ops.triton.mamba3_chunk_state import _mamba3_chunk_state_fwd # noqa: F401 + from mamba_ssm.ops.triton.mamba3_chunk_scan import _mamba3_chunk_scan_fwd # noqa: F401 + return True + except ImportError: + return False + + +def _triton_backward_available(): + """Check if the Mamba-3 Triton backward kernels are importable. + + These kernels are provided by separate modules and may not be available + if they have not been compiled. + """ + try: + from mamba_ssm.ops.triton.mamba3_chunk_scan_bwd import ( # noqa: F401 + _mamba3_chunk_scan_chunk_state_bwd_dx, + _mamba3_chunk_scan_bwd_dcb, + _mamba3_chunk_scan_bwd_ddAcs_stable, + ) + from mamba_ssm.ops.triton.mamba3_chunk_state_bwd import ( # noqa: F401 + _mamba3_chunk_state_bwd_db, + _mamba3_chunk_state_bwd_ddAcs_stable, + ) + return True + except ImportError: + return False + + +# Cache the availability checks so we only do them once. +_TRITON_FWD_AVAILABLE = None +_TRITON_BWD_AVAILABLE = None + + +def _check_triton_fwd(): + global _TRITON_FWD_AVAILABLE + if _TRITON_FWD_AVAILABLE is None: + _TRITON_FWD_AVAILABLE = _triton_forward_available() + return _TRITON_FWD_AVAILABLE + + +def _check_triton_bwd(): + global _TRITON_BWD_AVAILABLE + if _TRITON_BWD_AVAILABLE is None: + _TRITON_BWD_AVAILABLE = _triton_backward_available() + return _TRITON_BWD_AVAILABLE + + +def _mamba3_chunk_scan_combined_bwd( + dout, x, dt, A, B, C, out, chunk_size, + D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, + dt_softplus=False, dt_limit=(0.0, float("inf")), + gamma=None, beta=None, theta=None, initial_prev_Bx=None, + ngroups=1, + dfinal_states=None, +): + """Triton backward for Mamba-3 chunked SSD. + + Follows the Mamba-2 backward pattern from _mamba_chunk_scan_combined_bwd + with extensions for Mamba-3's trapezoidal discretization, RoPE, and shift. + + Steps: + 1. Pad seqlen to multiple of chunk_size (same as forward) + 2. Recompute forward intermediates (dA_cumsum, dt_out, states, CB) + 3. If z: compute dz via _chunk_scan_bwd_dz (reuse Mamba-2) + 4. Compute dstates via _chunk_scan_bwd_dstates (reuse Mamba-2) + 5. Backward state passing via _state_passing_bwd (reuse Mamba-2) + 6. Compute dx, dgamma, dbeta, ddt, dD, dx_shifted via _mamba3_chunk_scan_chunk_state_bwd_dx + 7. Accumulate dx_shifted into dx (shift backward) + 8. Compute dB, dB_shifted, ddA_next via _mamba3_chunk_state_bwd_db + 9. Compute dC, ddA_cumsum_prev via _chunk_scan_bwd_dC (reuse Mamba-2) + 10. Compute dCB, dCB_shifted via _mamba3_chunk_scan_bwd_dcb + 11. Convert dCB -> dB_scan, dC_scan via _bmm_chunk_bwd (reuse Mamba-2) + 12. Handle dCB_shifted -> additional dB, dC via _bmm_chunk_bwd + 13. Compute ddA_cumsum from scan path via _mamba3_chunk_scan_bwd_ddAcs_stable + 14. If initial_prev_Bx: compute backward through state+output corrections + 15. Accumulate all ddA contributions (ddA_next, ddA_prev, ddA_ipBx) + 16. Convert ddA -> ddt, dA, ddt_bias via _chunk_cumsum_bwd (reuse Mamba-2) + 17. If theta: backward through RoPE; else reduce heads to groups + 18. Convert dgamma_c, dbeta_c from chunked layout to flat layout + 19. Unpad all outputs to original seqlen + """ + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd + from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd + from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd + from mamba_ssm.ops.triton.ssd_chunk_scan import ( + _chunk_scan_bwd_dz, + _chunk_scan_bwd_dstates, + _chunk_scan_bwd_dC, + ) + from mamba_ssm.ops.triton.mamba3_chunk_state import _mamba3_chunk_state_fwd + from mamba_ssm.ops.triton.mamba3_chunk_scan_bwd import ( + _mamba3_chunk_scan_chunk_state_bwd_dx, + _mamba3_chunk_scan_bwd_dcb as _mamba3_chunk_scan_bwd_dcb_fn, + _mamba3_chunk_scan_bwd_ddAcs_stable as _mamba3_scan_bwd_ddAcs, + ) + from mamba_ssm.ops.triton.mamba3_chunk_state_bwd import ( + _mamba3_chunk_state_bwd_db, + _mamba3_chunk_state_bwd_ddAcs_stable, + ) + + if dout.stride(-1) != 1: + dout = dout.contiguous() + + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups_bc, dstate = B.shape + nheads_per_group = nheads // ngroups_bc + use_trapezoidal = gamma is not None + + assert dout.shape == (batch, seqlen, nheads, headdim) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert nheads % ngroups_bc == 0 + assert B.shape == (batch, seqlen, ngroups_bc, dstate) + assert C.shape == B.shape + assert out.shape == x.shape + + # ---- Step 1: Pad seqlen to multiple of chunk_size ---- + pad_len = (chunk_size - seqlen % chunk_size) % chunk_size + if pad_len > 0: + x = F.pad(x, (0, 0, 0, 0, 0, pad_len)) + dt = F.pad(dt, (0, 0, 0, pad_len)) + B = F.pad(B, (0, 0, 0, 0, 0, pad_len)) + C = F.pad(C, (0, 0, 0, 0, 0, pad_len)) + dout = F.pad(dout, (0, 0, 0, 0, 0, pad_len)) + out = F.pad(out, (0, 0, 0, 0, 0, pad_len)) + if gamma is not None: + gamma = F.pad(gamma, (0, 0, 0, pad_len)) + if beta is not None: + beta = F.pad(beta, (0, 0, 0, pad_len)) + if z is not None: + z = F.pad(z, (0, 0, 0, 0, 0, pad_len)) + if theta is not None: + theta = F.pad(theta, (0, 0, 0, 0, 0, pad_len)) + if seq_idx is not None: + seq_idx = F.pad(seq_idx, (0, pad_len), value=-1) + + padded_seqlen = seqlen + pad_len + nchunks = padded_seqlen // chunk_size + + # ---- Step 2: Recompute forward intermediates ---- + # Clone dt to avoid Triton context issues (same as Mamba-2) + dt_in = dt.clone() + dA_cumsum, dt_out = _chunk_cumsum_fwd( + dt_in, A, chunk_size, dt_bias=dt_bias, + dt_softplus=dt_softplus, dt_limit=dt_limit, + ) + + # RoPE + if theta is not None: + B_heads, C_heads = apply_rotary_emb_to_bc(B, C, theta, nheads, ngroups_bc) + else: + B_heads = repeat(B, "b l g n -> b l (g h) n", h=nheads_per_group) + C_heads = repeat(C, "b l g n -> b l (g h) n", h=nheads_per_group) + + if B_heads.stride(-1) != 1: + B_heads = B_heads.contiguous() + if C_heads.stride(-1) != 1: + C_heads = C_heads.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: + x = x.contiguous() + + # Shift computation for trapezoidal + if use_trapezoidal: + gamma_c = rearrange(gamma, "b (c l) h -> b h c l", l=chunk_size) + beta_c = rearrange(beta, "b (c l) h -> b h c l", l=chunk_size) if beta is not None else None + + B_shifted = torch.zeros_like(B_heads) + x_shifted = torch.zeros_like(x) + B_shifted[:, 1:] = B_heads[:, :-1] + x_shifted[:, 1:] = x[:, :-1] + + if seq_idx is not None: + shift_valid = torch.ones(batch, padded_seqlen, dtype=torch.bool, device=x.device) + shift_valid[:, 1:] = seq_idx[:, 1:] == seq_idx[:, :-1] + shift_valid[:, 0] = False + sv = shift_valid[:, :, None, None] + B_shifted = B_shifted * sv + x_shifted = x_shifted * sv + + if B_shifted.stride(-1) != 1: + B_shifted = B_shifted.contiguous() + if x_shifted.stride(-1) != 1 and x_shifted.stride(1) != 1: + x_shifted = x_shifted.contiguous() + else: + gamma_c = dt_out + beta_c = None + B_shifted = None + x_shifted = None + + # Chunk states + states = _mamba3_chunk_state_fwd( + B_heads, x, dt_out, dA_cumsum, + gamma=gamma_c, + beta=beta_c, B_shifted=B_shifted, x_shifted=x_shifted, + seq_idx=seq_idx, + states_in_fp32=True, + ) + + # initial_prev_Bx correction on chunk 0 state + if initial_prev_Bx is not None and use_trapezoidal and beta is not None: + beta_flat = rearrange(beta, "b (c l) h -> b c l h", l=chunk_size) + beta_0 = beta_flat[:, 0, 0, :] + correction = rearrange(beta_0, "b h -> b h 1 1") * initial_prev_Bx.float() + decay_states = torch.exp(dA_cumsum[:, :, 0, -1:] - dA_cumsum[:, :, 0, :]) + decay_from_0 = decay_states[:, :, 0] + states[:, 0] = states[:, 0] + rearrange(decay_from_0, "b h -> b h 1 1") * correction + + # State passing + states, _ = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size, + ) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + + # CB + CB = _bmm_chunk_fwd(C_heads, B_heads, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + CB_shifted = None + if use_trapezoidal and B_shifted is not None: + CB_shifted = _bmm_chunk_fwd(C_heads, B_shifted, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + + # ---- Step 3: dz computation (if z gating present) ---- + if z is not None: + dz, dout, dD, *rest = _chunk_scan_bwd_dz( + x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, + ) + else: + dz = None + + # ---- Step 4: dstates ---- + dstates = _chunk_scan_bwd_dstates(C_heads, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype) + + # ---- Step 5: Backward state passing ---- + dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + rearrange(dstates, "... p n -> ... (p n)"), + dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None, + seq_idx=seq_idx, + has_initial_states=initial_states is not None, + dstates_dtype=x.dtype, + states_dtype=x.dtype, + chunk_size=chunk_size, + ) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate) + dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None + + # ---- Step 6: dx, dgamma, dbeta, ddt, dD, dx_shifted ---- + # The Mamba-3 dx kernel handles both current and lookback terms internally, + # returning all 6 outputs: dx, ddt, dD, dgamma_c, dbeta_c, dx_shifted. + # B_heads is per-head, so we pass ngroups=nheads (nheads_ngroups_ratio=1). + dx, ddt, dD_from_x, dgamma_c, dbeta_c, dx_shifted = _mamba3_chunk_scan_chunk_state_bwd_dx( + x, dt_out, dA_cumsum, B_heads, CB, dout, dstates, + D=D, seq_idx=seq_idx, + gamma=gamma_c, beta=beta_c, + CB_shifted=CB_shifted, x_shifted=x_shifted, + B_shifted=B_shifted, + ) + + # ---- Step 7: Accumulate dx_shifted into dx (shift backward) ---- + # Contribution at position t in shifted maps back to position t-1 in original. + if dx_shifted is not None: + dx[:, :-1] += dx_shifted[:, 1:] + + # ---- Step 8: dB, dB_shifted, ddA_next from chunk state backward ---- + # B_heads is per-head (nheads groups of 1), so ngroups=nheads. + if use_trapezoidal: + # In trapezoidal mode, the db kernel only computes current term's ddA + # (lookback ddA is skipped). Use B=None to skip ddA in the db kernel, + # then call _mamba3_chunk_state_bwd_ddAcs_stable for the full ddA + # (both current and lookback terms). + dB_heads, dB_shifted_state = _mamba3_chunk_state_bwd_db( + x, dA_cumsum, dstates, + seq_idx=seq_idx, B=None, ngroups=nheads, + gamma=gamma_c, beta=beta_c, + x_shifted=x_shifted, + ) + ddA_next = _mamba3_chunk_state_bwd_ddAcs_stable( + x, dA_cumsum, dstates, B_heads, + seq_idx=seq_idx, ngroups=nheads, + gamma=gamma_c, beta=beta_c, + x_shifted=x_shifted, B_shifted=B_shifted, + ) + else: + # Mamba-2 mode: db kernel folds ddA into its return (no lookback). + dB_heads, dB_shifted_state, ddA_next = _mamba3_chunk_state_bwd_db( + x, dA_cumsum, dstates, + seq_idx=seq_idx, B=B_heads, ngroups=nheads, + gamma=gamma_c, beta=beta_c, + x_shifted=x_shifted, + ) + + # Accumulate dB_shifted_state into dB_heads (shift backward) + if dB_shifted_state is not None: + dB_heads[:, :-1] += dB_shifted_state[:, 1:] + + # ---- Step 9: dC via chunk scan backward (reuse Mamba-2) ---- + # C_heads is per-head, so pass ngroups=nheads for correct nheads_ngroups_ratio=1. + dC_heads, ddA_cumsum_prev = _chunk_scan_bwd_dC( + states.to(x.dtype), dA_cumsum, dout, + seq_idx=seq_idx, C=C_heads, ngroups=nheads, + ) + + # ---- Step 10: dCB, dCB_shifted via chunk scan backward ---- + # The Mamba-3 dcb kernel handles both current and lookback terms internally. + dCB, dCB_shifted = _mamba3_chunk_scan_bwd_dcb_fn( + x, dA_cumsum, dout, + seq_idx=seq_idx, ngroups=nheads, + gamma=gamma_c, beta=beta_c, + x_shifted=x_shifted, + ) + + # ---- Step 11: Convert dCB -> additional dB, dC via BMM backward ---- + dCB = dCB.to(CB.dtype) + # dCB[b,c,g,m,n]: C^T @ B product gradient + # dB_from_cb = C^T @ dCB (adds to dB_heads) + # dC_from_cb = dCB^T @ B (adds to dC_heads) + dB_scan = torch.empty_like(B_heads) + _bmm_chunk_bwd(C_heads, dCB, residual=dB_heads, out=dB_scan) + dC_scan = torch.empty_like(C_heads) + _bmm_chunk_bwd(B_heads, rearrange(dCB, "... l s -> ... s l"), residual=dC_heads, out=dC_scan) + + # ---- Step 12: Handle dCB_shifted -> additional dB, dC ---- + if dCB_shifted is not None: + dCB_shifted = dCB_shifted.to(CB_shifted.dtype) + # dB_shifted_from_cb = C^T @ dCB_shifted + dB_shifted_bmm = _bmm_chunk_bwd(C_heads, dCB_shifted) + # dC_from_lb_bmm = dCB_shifted^T @ B_shifted + dC_from_lb_bmm = _bmm_chunk_bwd(B_shifted, rearrange(dCB_shifted, "... l s -> ... s l")) + # Shift backward for dB_shifted: position t in shifted -> position t-1 in original + dB_scan[:, :-1] += dB_shifted_bmm[:, 1:] + # dC is NOT shifted (C appears unshifted in CB_shifted = C^T @ B_shifted) + dC_scan += dC_from_lb_bmm + + # If z is not None, dD was already computed in step 3 + if z is None: + dD = dD_from_x + + # ---- Step 13: ddA_cumsum from scan path ---- + ddA_scan = _mamba3_scan_bwd_ddAcs( + x, dA_cumsum, dout, CB, + seq_idx=seq_idx, + gamma=gamma_c, beta=beta_c, + x_shifted=x_shifted, CB_shifted=CB_shifted, + ngroups=nheads, + ) + + # Note: ddA from the state path (ddA_next) is computed in step 8. + # In trapezoidal mode, it uses _mamba3_chunk_state_bwd_ddAcs_stable + # (covers both current and lookback terms). In Mamba-2 mode, it's + # folded into _mamba3_chunk_state_bwd_db's return (current term only). + # ddA_cumsum_prev is already computed by _chunk_scan_bwd_dC (step 9). + + # ---- Step 14: initial_prev_Bx backward ---- + # Both corrections (state + output) are additive, so the existing Triton + # backward kernels produce correct gradients for all OTHER parameters. + # We compute gradients through the ipBx corrections and accumulate ddA/dC/dbeta. + # Must run before Step 15 (ddA accumulation) and Step 17 (RoPE/group reduction). + d_initial_prev_Bx = None + ddA_ipBx = None + dbeta_0_ipBx = None + + if initial_prev_Bx is not None and use_trapezoidal and beta is not None: + ipBx = initial_prev_Bx.float() + beta_flat = rearrange(beta, "b (c l) h -> b c l h", l=chunk_size) + beta_0 = beta_flat[:, 0, 0, :] # (batch, nheads) + beta_0_r = rearrange(beta_0, "b h -> b h 1 1") + correction = beta_0_r * ipBx # (b, h, P, N) + + # -- State correction backward -- + # Forward: states[:, 0] += decay_from_0 * correction + # where decay_from_0 = exp(dA_cumsum[:,:,0,-1] - dA_cumsum[:,:,0,0]) + decay_states_ipBx = torch.exp(dA_cumsum[:, :, 0, -1:] - dA_cumsum[:, :, 0, :]) + decay_from_0 = decay_states_ipBx[:, :, 0] # (b, h) + decay_from_0_r = rearrange(decay_from_0, "b h -> b h 1 1") + + dstates_0 = dstates[:, 0].float() # (b, h, P, N) + d_decay_from_0 = (dstates_0 * correction).sum(dim=(-2, -1)) # (b, h) + d_correction_state = decay_from_0_r * dstates_0 # (b, h, P, N) + d_ipBx_state = d_correction_state * beta_0_r # (b, h, P, N) + d_beta_0_state = (d_correction_state * ipBx).sum(dim=(-2, -1)) # (b, h) + ddA_from_state = d_decay_from_0 * decay_from_0 # (b, h) + + # -- Output correction backward -- + # Forward: Y_corr = einsum("blhn,bhpn,bhl->blhp", C_chunk0, correction, decay_0_to_m) + # out[:, :chunk_size] += Y_corr + decay_0_to_m = torch.exp(dA_cumsum[:, :, 0, :] - dA_cumsum[:, :, 0, 0:1]) # (b, h, L) + C_chunk0 = C_heads[:, :chunk_size].float() # (b, L, h, N) + dout_chunk0 = dout[:, :chunk_size].float() # (b, L, h, P) + + d_correction_out = torch.einsum( + "blhn,blhp,bhl->bhpn", C_chunk0, dout_chunk0, decay_0_to_m, + ) + dC_ipBx_chunk0 = torch.einsum( + "blhp,bhpn,bhl->blhn", dout_chunk0, correction, decay_0_to_m, + ) + d_decay_0_to_m = torch.einsum( + "blhn,bhpn,blhp->bhl", C_chunk0, correction, dout_chunk0, + ) + + d_ipBx_out = d_correction_out * beta_0_r # (b, h, P, N) + d_beta_0_out = (d_correction_out * ipBx).sum(dim=(-2, -1)) # (b, h) + ddA_from_output = d_decay_0_to_m * decay_0_to_m # (b, h, L) + + # -- Accumulate d_initial_prev_Bx -- + d_initial_prev_Bx = (d_ipBx_state + d_ipBx_out).to(initial_prev_Bx.dtype) + dbeta_0_ipBx = d_beta_0_state + d_beta_0_out # (b, h) + + # -- Build ddA_ipBx at chunk 0 -- + ddA_ipBx = torch.zeros_like(dA_cumsum) # (b, h, nchunks, L) + # State correction: d/d(dA_cumsum[:,:,0,-1]) += ddA_from_state, [:,:,0,0] -= ddA_from_state + ddA_ipBx[:, :, 0, -1] += ddA_from_state + ddA_ipBx[:, :, 0, 0] -= ddA_from_state + # Output correction: d/d(dA_cumsum[:,:,0,:]) += ddA_from_output, [:,:,0,0] -= sum + ddA_ipBx[:, :, 0, :] += ddA_from_output + ddA_ipBx[:, :, 0, 0] -= ddA_from_output.sum(dim=-1) + + # -- Accumulate dC from output correction (per-head, at chunk 0 positions) -- + dC_scan[:, :chunk_size] += dC_ipBx_chunk0.to(dC_scan.dtype) + + # ---- Step 15: Accumulate all ddA contributions ---- + # ddA_cumsum_prev is in cumsum space → convert to per-position dA space via reverse cumsum + ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum + if ddA_ipBx is not None: + # ddA_ipBx is also in cumsum space, merge before reverse cumsum + ddA_cumsum_prev = ddA_cumsum_prev + ddA_ipBx + ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1]) + + ddA = ddA_scan + ddA_next + ddA_prev + + # ---- Step 16: ddA -> ddt, dA, ddt_bias ---- + if use_trapezoidal: + ddt_for_cumsum = torch.zeros_like(ddt) + else: + ddt_for_cumsum = dgamma_c if dgamma_c is not None else ddt + ddt_out, dA, ddt_bias_out = _chunk_cumsum_bwd( + ddA, ddt_for_cumsum, dt_in, A, dt_bias=dt_bias, + dt_softplus=dt_softplus, dt_limit=dt_limit, + ) + + # ---- Step 17: RoPE backward ---- + # dB_scan and dC_scan are at head level. We need to convert back to group level + # and also backprop through RoPE if theta is present. + if theta is not None: + dtheta, dB_group, dC_group = _rope_bwd_pytorch( + dB_scan, dC_scan, B, C, theta, nheads, ngroups_bc, + ) + else: + dtheta = None + # Reduce from heads back to groups + dB_group = rearrange(dB_scan, "b l (g h) n -> b l g h n", g=ngroups_bc).sum(dim=3) + dC_group = rearrange(dC_scan, "b l (g h) n -> b l g h n", g=ngroups_bc).sum(dim=3) + + # ---- Step 18: Convert dgamma_c, dbeta_c from chunked layout ---- + dgamma = None + dbeta = None + if use_trapezoidal: + if dgamma_c is not None: + dgamma = rearrange(dgamma_c, "b h c l -> b (c l) h") + if dbeta_c is not None: + dbeta = rearrange(dbeta_c, "b h c l -> b (c l) h") + # Accumulate ipBx contribution to dbeta at position (chunk=0, pos=0) + if dbeta_0_ipBx is not None: + if dbeta is None: + dbeta = torch.zeros_like(dgamma) if dgamma is not None else torch.zeros(batch, padded_seqlen, nheads, device=x.device, dtype=x.dtype) + dbeta[:, 0] += dbeta_0_ipBx.to(dbeta.dtype) + + # ---- Step 19: Unpad all outputs to original seqlen ---- + if pad_len > 0: + dx = dx[:, :seqlen] + dB_group = dB_group[:, :seqlen] + dC_group = dC_group[:, :seqlen] + if dz is not None: + dz = dz[:, :seqlen] + if dgamma is not None: + dgamma = dgamma[:, :seqlen] + if dbeta is not None: + dbeta = dbeta[:, :seqlen] + if dtheta is not None: + dtheta = dtheta[:, :seqlen] + + # ddt_out is (batch, seqlen, nheads) from _chunk_cumsum_bwd -- already correct shape + # since dt_in was padded and _chunk_cumsum_bwd produces matching shape. + # We need to unpad it too. + if pad_len > 0: + ddt_out = ddt_out[:, :seqlen] + + return ( + dx, ddt_out, dA, dB_group, dC_group, dD, dz, ddt_bias_out, + dinitial_states, dgamma, dbeta, dtheta, d_initial_prev_Bx, + ) + + +def _rope_bwd_pytorch(dB_heads, dC_heads, B, C, theta, nheads, ngroups): + """Backward through RoPE using PyTorch autograd. + + This is a fallback for when Triton RoPE backward kernels are unavailable. + We recompute the RoPE forward with autograd enabled, then use + torch.autograd.grad to get gradients. + """ + B_detach = B.detach().requires_grad_(True) + C_detach = C.detach().requires_grad_(True) + theta_detach = theta.detach().requires_grad_(True) + + with torch.enable_grad(): + B_rot, C_rot = apply_rotary_emb_to_bc( + B_detach, C_detach, theta_detach, nheads, ngroups, + ) + + # Compute gradients + grads = torch.autograd.grad( + [B_rot, C_rot], + [B_detach, C_detach, theta_detach], + [dB_heads, dC_heads], + allow_unused=True, + ) + + dB_group = grads[0] + dC_group = grads[1] + dtheta = grads[2] + + return dtheta, dB_group, dC_group + + +class Mamba3ChunkScanCombinedFn(torch.autograd.Function): + """Autograd function for Mamba-3 chunked SSD with Triton-accelerated forward and backward. + + Forward: Uses Triton kernels when available (SISO, CUDA tensors). Falls back to + the PyTorch reference implementation for MIMO or when Triton is unavailable. + + Backward: Uses Triton backward pipeline when available (SISO, CUDA, kernels present). + Falls back to PyTorch autograd recomputation otherwise. + """ + + @staticmethod + @custom_fwd + def forward(ctx, x, dt, A, B, C, chunk_size, + D=None, z=None, dt_bias=None, + initial_states=None, seq_idx=None, + dt_softplus=False, dt_limit=(0.0, float("inf")), + return_final_states=False, + gamma=None, beta=None, theta=None, + initial_prev_Bx=None, mimo_rank=0, ngroups=1): + """Forward pass using Triton kernels when possible, PyTorch otherwise. + + Falls back to PyTorch for: + - MIMO (mimo_rank > 0): Triton kernels are SISO-only. + - CPU tensors: Triton requires CUDA. + - Missing Triton kernel modules. + + All non-tensor arguments are stored on ctx as attributes (not saved_tensors). + """ + # Determine whether to use Triton forward + use_triton = ( + mimo_rank == 0 + and _check_triton_fwd() + and x.is_cuda + ) + + # Track which output to save for backward (out_x for dz kernel when z is present) + out_for_bwd = None + + if use_triton: + try: + out, out_x, final_states = _mamba3_triton_fwd( + x, dt, A, B, C, chunk_size, + D=D, z=z, dt_bias=dt_bias, + initial_states=initial_states, seq_idx=seq_idx, + dt_softplus=dt_softplus, dt_limit=dt_limit, + return_final_states=return_final_states, + gamma=gamma, beta=beta, theta=theta, + initial_prev_Bx=initial_prev_Bx, + ngroups=ngroups, + ) + # Save pre-z output when z is present (needed by _chunk_scan_bwd_dz) + out_for_bwd = out if z is None else out_x + except Exception as e: + # Graceful fallback if Triton kernels fail at runtime + # (e.g., unsupported GPU, shape issues) + import warnings + warnings.warn( + f"Triton forward failed ({type(e).__name__}: {e}), " + "falling back to PyTorch forward.", + stacklevel=2, + ) + out, final_states = _mamba3_pytorch_fwd( + x, dt, A, B, C, chunk_size, + D=D, z=z, dt_bias=dt_bias, + initial_states=initial_states, seq_idx=seq_idx, + dt_softplus=dt_softplus, dt_limit=dt_limit, + return_final_states=return_final_states, + gamma=gamma, beta=beta, theta=theta, + initial_prev_Bx=initial_prev_Bx, + mimo_rank=mimo_rank, ngroups=ngroups, + ) + out_for_bwd = out + else: + out, final_states = _mamba3_pytorch_fwd( + x, dt, A, B, C, chunk_size, + D=D, z=z, dt_bias=dt_bias, + initial_states=initial_states, seq_idx=seq_idx, + dt_softplus=dt_softplus, dt_limit=dt_limit, + return_final_states=return_final_states, + gamma=gamma, beta=beta, theta=theta, + initial_prev_Bx=initial_prev_Bx, + mimo_rank=mimo_rank, ngroups=ngroups, + ) + out_for_bwd = out + + # Save inputs for backward (recompute strategy -- we only save tensors, + # non-tensor config goes on ctx as attributes). + # When z is present and we used the Triton path, out_for_bwd is out_x + # (the pre-z output), which is needed by _chunk_scan_bwd_dz. + # When z is None, out_for_bwd == out. + ctx.save_for_backward(x, dt, A, B, C, D, z, dt_bias, + initial_states, seq_idx, gamma, beta, theta, + initial_prev_Bx, out_for_bwd) + ctx.chunk_size = chunk_size + ctx.dt_softplus = dt_softplus + ctx.dt_limit = dt_limit + ctx.return_final_states = return_final_states + ctx.mimo_rank = mimo_rank + ctx.ngroups = ngroups + + if return_final_states: + return out, final_states + return out + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + """Backward pass using Triton backward pipeline when available. + + Falls back to PyTorch autograd recomputation for: + - MIMO (mimo_rank > 0) + - CPU tensors + - Missing Triton backward kernels + """ + (x, dt, A, B, C, D, z, dt_bias, + initial_states, seq_idx, gamma, beta, theta, + initial_prev_Bx, out) = ctx.saved_tensors + + dfinal_states = args[0] if ctx.return_final_states and len(args) > 0 else None + + # Determine whether to use Triton backward + # Note: z + initial_prev_Bx forces PyTorch backward because the Triton + # forward asserts z is None when ipBx is present (they interact in the + # output correction). If z is present with ipBx, the forward fell back + # to PyTorch, so the saved out is the final output (not pre-z), making + # _chunk_scan_bwd_dz incorrect. + use_triton_bwd = ( + ctx.mimo_rank == 0 + and x.is_cuda + and _check_triton_bwd() + and not (z is not None and initial_prev_Bx is not None) + ) + + if use_triton_bwd: + try: + grads = _mamba3_chunk_scan_combined_bwd( + dout, x, dt, A, B, C, out, ctx.chunk_size, + D=D, z=z, dt_bias=dt_bias, + initial_states=initial_states, seq_idx=seq_idx, + dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit, + gamma=gamma, beta=beta, theta=theta, + initial_prev_Bx=initial_prev_Bx, + ngroups=ctx.ngroups, + dfinal_states=dfinal_states, + ) + (dx, ddt, dA, dB, dC, dD_val, dz_val, ddt_bias, + dinitial_states, dgamma, dbeta, dtheta, + d_initial_prev_Bx) = grads + + return ( + dx, # x + ddt, # dt + dA, # A + dB, # B + dC, # C + None, # chunk_size (int) + dD_val, # D + dz_val, # z + ddt_bias, # dt_bias + dinitial_states, # initial_states + None, # seq_idx (int tensor) + None, # dt_softplus (bool) + None, # dt_limit (tuple) + None, # return_final_states (bool) + dgamma, # gamma + dbeta, # beta + dtheta, # theta + d_initial_prev_Bx, # initial_prev_Bx + None, # mimo_rank (int) + None, # ngroups (int) + ) + except Exception as e: + # Fall through to PyTorch backward. + # Log the exception so silent fallbacks are visible during development. + import warnings + warnings.warn( + f"Triton backward failed ({type(e).__name__}: {e}), " + "falling back to PyTorch recompute backward.", + stacklevel=2, + ) + + # ---- PyTorch recompute fallback ---- + return _mamba3_pytorch_backward( + ctx, dout, x, dt, A, B, C, D, z, dt_bias, + initial_states, seq_idx, gamma, beta, theta, + initial_prev_Bx, dfinal_states, + ) + + +def _mamba3_pytorch_backward(ctx, dout, x, dt, A, B, C, D, z, dt_bias, + initial_states, seq_idx, gamma, beta, theta, + initial_prev_Bx, dfinal_states): + """PyTorch autograd recomputation backward (fallback path). + + Strategy: + 1. Detach all saved tensors. + 2. Re-enable requires_grad on differentiable inputs. + 3. Run PyTorch reference forward with grad tracking. + 4. Use torch.autograd.grad to compute gradients w.r.t. differentiable inputs. + 5. Return gradients in the exact order of forward's arguments. + """ + tensor_inputs = [ + ("x", x, True), + ("dt", dt, True), + ("A", A, True), + ("B", B, True), + ("C", C, True), + ("D", D, D is not None), + ("z", z, z is not None), + ("dt_bias", dt_bias, dt_bias is not None), + ("initial_states", initial_states, initial_states is not None), + ("seq_idx", seq_idx, False), + ("gamma", gamma, gamma is not None), + ("beta", beta, beta is not None), + ("theta", theta, theta is not None), + ("initial_prev_Bx", initial_prev_Bx, initial_prev_Bx is not None), + ] + + recomp_tensors = {} + grad_tensors = [] + for name, tensor, needs_grad in tensor_inputs: + if tensor is None: + recomp_tensors[name] = None + else: + t = tensor.detach() + if needs_grad: + t = t.requires_grad_(True) + grad_tensors.append((name, t)) + recomp_tensors[name] = t + + with torch.enable_grad(): + result = _mamba3_chunk_scan_combined_ref( + recomp_tensors["x"], + recomp_tensors["dt"], + recomp_tensors["A"], + recomp_tensors["B"], + recomp_tensors["C"], + ctx.chunk_size, + gamma=recomp_tensors["gamma"], + beta=recomp_tensors["beta"], + theta=recomp_tensors["theta"], + D=recomp_tensors["D"], + z=recomp_tensors["z"], + dt_bias=recomp_tensors["dt_bias"], + dt_softplus=ctx.dt_softplus, + dt_limit=ctx.dt_limit, + initial_states=recomp_tensors["initial_states"], + initial_prev_Bx=recomp_tensors["initial_prev_Bx"], + return_final_states=ctx.return_final_states, + ngroups=ctx.ngroups, + seq_idx=recomp_tensors["seq_idx"], + ) + + if ctx.return_final_states: + recomp_out, recomp_final_states = result + else: + recomp_out = result + recomp_final_states = None + + outputs = [recomp_out] + grad_outputs = [dout] + + if ctx.return_final_states and recomp_final_states is not None and dfinal_states is not None: + outputs.append(recomp_final_states) + grad_outputs.append(dfinal_states) + + diff_inputs = [t for _, t in grad_tensors] + grads = torch.autograd.grad( + outputs, + diff_inputs, + grad_outputs, + allow_unused=True, + ) + + grad_map = {} + for (name, _), g in zip(grad_tensors, grads): + grad_map[name] = g + + return ( + grad_map.get("x"), + grad_map.get("dt"), + grad_map.get("A"), + grad_map.get("B"), + grad_map.get("C"), + None, # chunk_size (int) + grad_map.get("D"), + grad_map.get("z"), + grad_map.get("dt_bias"), + grad_map.get("initial_states"), + None, # seq_idx (int tensor) + None, # dt_softplus (bool) + None, # dt_limit (tuple) + None, # return_final_states (bool) + grad_map.get("gamma"), + grad_map.get("beta"), + grad_map.get("theta"), + grad_map.get("initial_prev_Bx"), + None, # mimo_rank (int) + None, # ngroups (int) + ) + + +def _mamba3_triton_fwd(x, dt, A, B, C, chunk_size, + D=None, z=None, dt_bias=None, + initial_states=None, seq_idx=None, + dt_softplus=False, dt_limit=(0.0, float("inf")), + return_final_states=False, + gamma=None, beta=None, theta=None, + initial_prev_Bx=None, ngroups=1): + """Forward pass using Triton kernels for SISO Mamba-3 chunked SSD. + + Pipeline: + 1. dt preprocessing (reuse _chunk_cumsum_fwd from Mamba-2) + 2. RoPE on B, C (PyTorch -- will be Triton-ized later) + 3. Compute shifted B, x for trapezoidal (PyTorch preprocessing) + 4. Chunk state computation (_mamba3_chunk_state_fwd -- Triton) + 5. State passing (_state_passing_fwd -- reuse from Mamba-2) + 6. BMM: C^T @ B and C^T @ B_shifted (_bmm_chunk_fwd -- reuse from Mamba-2) + 7. Chunk scan output (_mamba3_chunk_scan_fwd -- Triton) + 8. initial_prev_Bx correction (PyTorch) + + This function is SISO-only (mimo_rank=0). MIMO falls back to PyTorch. + """ + from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd + from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd + from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd + from mamba_ssm.ops.triton.mamba3_chunk_state import _mamba3_chunk_state_fwd + from mamba_ssm.ops.triton.mamba3_chunk_scan import _mamba3_chunk_scan_fwd + + batch, seqlen, nheads, headdim = x.shape + ngroups_bc = B.shape[2] # B is (batch, seqlen, ngroups, dstate) + dstate = B.shape[-1] + nheads_per_group = nheads // ngroups_bc + out_dtype = x.dtype + + # Pad sequence length to multiple of chunk_size + pad_len = (chunk_size - seqlen % chunk_size) % chunk_size + if pad_len > 0: + x = F.pad(x, (0, 0, 0, 0, 0, pad_len)) + dt = F.pad(dt, (0, 0, 0, pad_len)) + B = F.pad(B, (0, 0, 0, 0, 0, pad_len)) + C = F.pad(C, (0, 0, 0, 0, 0, pad_len)) + if gamma is not None: + gamma = F.pad(gamma, (0, 0, 0, pad_len)) + if beta is not None: + beta = F.pad(beta, (0, 0, 0, pad_len)) + if z is not None: + z = F.pad(z, (0, 0, 0, 0, 0, pad_len)) + if theta is not None: + theta = F.pad(theta, (0, 0, 0, 0, 0, pad_len)) + if seq_idx is not None: + seq_idx = F.pad(seq_idx, (0, pad_len), value=-1) + + padded_seqlen = seqlen + pad_len + nchunks = padded_seqlen // chunk_size + + # ---- Stage 1: dt preprocessing ---- + # Reuse Mamba-2's _chunk_cumsum_fwd for dt bias, softplus, limit, and cumsum + dA_cumsum, dt_out = _chunk_cumsum_fwd( + dt.contiguous(), A.contiguous(), chunk_size, + dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, + ) + # dA_cumsum: (batch, nheads, nchunks, chunk_size) + # dt_out: (batch, nheads, nchunks, chunk_size) -- processed dt values + + # ---- Stage 2: RoPE on B, C (PyTorch) ---- + if theta is not None: + B_heads, C_heads = apply_rotary_emb_to_bc(B, C, theta, nheads, ngroups_bc) + else: + # Expand B, C from groups to heads without RoPE + B_heads = repeat(B, "b l g n -> b l (g h) n", h=nheads_per_group) + C_heads = repeat(C, "b l g n -> b l (g h) n", h=nheads_per_group) + + # Make contiguous for kernel consumption + if B_heads.stride(-1) != 1: + B_heads = B_heads.contiguous() + if C_heads.stride(-1) != 1: + C_heads = C_heads.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: + x = x.contiguous() + + # ---- Stage 3: Compute gamma/beta and shifted tensors (PyTorch) ---- + use_trapezoidal = gamma is not None + + if use_trapezoidal: + # Reshape gamma, beta to match dt_out layout: (batch, nheads, nchunks, chunk_size) + gamma_c = rearrange(gamma, "b (c l) h -> b h c l", l=chunk_size) + beta_c = rearrange(beta, "b (c l) h -> b h c l", l=chunk_size) if beta is not None else None + + # Compute shifted B and x for the lookback term + B_shifted = torch.zeros_like(B_heads) + x_shifted = torch.zeros_like(x) + # Within-sequence shift by 1 + B_shifted[:, 1:] = B_heads[:, :-1] + x_shifted[:, 1:] = x[:, :-1] + + # Handle seq_idx masking on shifted tensors + if seq_idx is not None: + shift_valid = torch.ones(batch, padded_seqlen, dtype=torch.bool, device=x.device) + shift_valid[:, 1:] = seq_idx[:, 1:] == seq_idx[:, :-1] + shift_valid[:, 0] = False # no lookback for first position + # Mask invalid shifts + sv = shift_valid[:, :, None, None] # (b, l, 1, 1) + B_shifted = B_shifted * sv + x_shifted = x_shifted * sv + + if B_shifted.stride(-1) != 1: + B_shifted = B_shifted.contiguous() + if x_shifted.stride(-1) != 1 and x_shifted.stride(1) != 1: + x_shifted = x_shifted.contiguous() + else: + # Euler mode: gamma = dt (already in dt_out) + gamma_c = dt_out + beta_c = None + B_shifted = None + x_shifted = None + + # ---- Stage 4: Chunk state computation (Triton) ---- + # Compute per-chunk states using the Mamba-3 variant that handles + # both the current (gamma) and lookback (beta) terms. + states = _mamba3_chunk_state_fwd( + B_heads, x, dt_out, dA_cumsum, + gamma=gamma_c, + beta=beta_c, B_shifted=B_shifted, x_shifted=x_shifted, + seq_idx=seq_idx, + states_in_fp32=True, + ) + + # Handle initial_prev_Bx correction on chunk 0 state + if initial_prev_Bx is not None and use_trapezoidal and beta is not None: + beta_flat = rearrange(beta, "b (c l) h -> b c l h", l=chunk_size) + beta_0 = beta_flat[:, 0, 0, :] # (batch, nheads) + correction = rearrange(beta_0, "b h -> b h 1 1") * initial_prev_Bx.float() + # Decay correction from position 0 to end of chunk 0 + decay_states = torch.exp(dA_cumsum[:, :, 0, -1:] - dA_cumsum[:, :, 0, :]) + decay_from_0 = decay_states[:, :, 0] # decay from pos 0 to end of chunk + states[:, 0] = states[:, 0] + rearrange(decay_from_0, "b h -> b h 1 1") * correction + + # ---- Stage 5: State passing (reuse Mamba-2 Triton kernel) ---- + states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, + seq_idx=seq_idx, chunk_size=chunk_size, + out_dtype=C_heads.dtype, + ) + states, final_states = [ + rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states] + ] + + # ---- Stage 6: BMM for C^T @ B products (reuse Mamba-2 Triton kernel) ---- + # Main CB product for current-time term + CB = _bmm_chunk_fwd(C_heads, B_heads, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + + # Shifted CB product for trapezoidal lookback term + CB_shifted = None + if use_trapezoidal and B_shifted is not None: + CB_shifted = _bmm_chunk_fwd(C_heads, B_shifted, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + + # ---- Stage 7: Chunk scan output (Triton) ---- + out, out_x = _mamba3_chunk_scan_fwd( + CB, x, dt_out, dA_cumsum, gamma_c, C_heads, states, + D=D, z=z, + beta=beta_c, CB_shifted=CB_shifted, x_shifted=x_shifted, + seq_idx=seq_idx, + ) + + # ---- Stage 8: initial_prev_Bx output correction (PyTorch) ---- + # NOTE: This correction is applied after the scan kernel, which may have applied z gating. + # If z is not None, the correction would be added after z gating, which is incorrect. + # In practice, z is always None here (modules apply z gating externally). + if initial_prev_Bx is not None and use_trapezoidal and beta is not None: + assert z is None, ( + "initial_prev_Bx correction is not compatible with z gating in the Triton path. " + "Pass z=None and apply z gating externally." + ) + beta_flat = rearrange(beta, "b (c l) h -> b c l h", l=chunk_size) + beta_0 = beta_flat[:, 0, 0, :] # (batch, nheads) + correction = rearrange(beta_0, "b h -> b h 1 1") * initial_prev_Bx.float() + # Decay from position 0 to each position m within chunk 0 + decay_0_to_m = torch.exp(dA_cumsum[:, :, 0, :] - dA_cumsum[:, :, 0, 0:1]) # (b, h, chunk_size) + + # C_heads for chunk 0: (batch, chunk_size, nheads, dstate) + C_chunk0 = C_heads[:, :chunk_size] + # Y_correction = C_chunk0^T @ correction, weighted by decay + Y_correction = torch.einsum( + "blhn,bhpn,bhl->blhp", + C_chunk0.float(), correction, decay_0_to_m, + ) + out[:, :chunk_size] = out[:, :chunk_size] + Y_correction.to(out.dtype) + + # Un-pad if necessary + if pad_len > 0: + out = out[:, :seqlen] + if out_x is not None: + out_x = out_x[:, :seqlen] + + out = out.to(out_dtype) + + if return_final_states: + return out, out_x, final_states + return out, out_x, None + + +def _mamba3_pytorch_fwd(x, dt, A, B, C, chunk_size, + D=None, z=None, dt_bias=None, + initial_states=None, seq_idx=None, + dt_softplus=False, dt_limit=(0.0, float("inf")), + return_final_states=False, + gamma=None, beta=None, theta=None, + initial_prev_Bx=None, mimo_rank=0, ngroups=1): + """Forward pass using PyTorch reference (for backward recompute and MIMO fallback). + + Delegates to the existing mamba3_chunk_scan_combined reference implementation. + Returns (out, final_states) tuple where final_states is None if not requested. + """ + # Note: mamba3_chunk_scan_combined does not accept mimo_rank as a kwarg; + # it detects MIMO from B.dim() == 5. We do not pass mimo_rank here. + result = _mamba3_chunk_scan_combined_ref( + x, dt, A, B, C, chunk_size, + gamma=gamma, + beta=beta, + theta=theta, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + initial_states=initial_states, + initial_prev_Bx=initial_prev_Bx, + return_final_states=return_final_states, + ngroups=ngroups, + seq_idx=seq_idx, + ) + + if return_final_states: + return result # already (out, final_states) + else: + return result, None # normalize to (out, None) + + +def mamba3_chunk_scan_combined_triton(x, dt, A, B, C, chunk_size, **kwargs): + """Drop-in replacement for mamba3_chunk_scan_combined with Triton acceleration. + + Uses Triton kernels for the forward pass (SISO on CUDA) and Triton backward + pipeline when available. Falls back to pure PyTorch for MIMO or when Triton + kernels are unavailable. + + Usage: + Replace calls to mamba3_chunk_scan_combined with this function. + The signature and return values are identical. + + Args: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads,) -- negative + B: (batch, seqlen, ngroups, d_state) + C: (batch, seqlen, ngroups, d_state) + chunk_size: int + **kwargs: All keyword arguments from mamba3_chunk_scan_combined: + gamma, beta, theta, D, z, dt_bias, dt_softplus, dt_limit, + initial_states, initial_prev_Bx, return_final_states, + ngroups, seq_idx, mimo_rank. + Returns: + Same as mamba3_chunk_scan_combined: + - Y: (batch, seqlen, nheads, headdim[, mimo_rank]) if not return_final_states + - (Y, final_state) if return_final_states + """ + # Extract mimo_rank to decide if we need to handle MIMO return shapes + mimo_rank = kwargs.get("mimo_rank", 0) + + # For the "ngroups" kwarg, mamba3_chunk_scan_combined uses it but it's also + # derivable from B.shape. We pass it through. + result = Mamba3ChunkScanCombinedFn.apply( + x, dt, A, B, C, chunk_size, + kwargs.get("D"), + kwargs.get("z"), + kwargs.get("dt_bias"), + kwargs.get("initial_states"), + kwargs.get("seq_idx"), + kwargs.get("dt_softplus", False), + kwargs.get("dt_limit", (0.0, float("inf"))), + kwargs.get("return_final_states", False), + kwargs.get("gamma"), + kwargs.get("beta"), + kwargs.get("theta"), + kwargs.get("initial_prev_Bx"), + mimo_rank, + kwargs.get("ngroups", 1), + ) + + return result diff --git a/mamba_ssm/ops/triton/mamba3_rope.py b/mamba_ssm/ops/triton/mamba3_rope.py new file mode 100644 index 000000000..0d63574db --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_rope.py @@ -0,0 +1,363 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 fused RoPE (Rotary Position Embedding) Triton kernels for B and C. +# +# Replaces `apply_rotary_emb_to_bc` in mamba3_ssd.py with fused Triton kernels: +# 1. Forward: expand B/C from groups to heads + apply sincos rotation +# 2. Backward: reverse rotation + reduce heads to groups + dtheta gradient +# +# Strategy (approach a): the cumulative sum of theta is computed in PyTorch +# (one kernel launch), then the Triton kernel handles the group->head expansion +# and sincos rotation. This avoids in-kernel sequential scan complexity. + +import torch +import triton +import triton.language as tl + + +# ============================================================================= +# Forward kernel: group->head expansion + RoPE rotation +# ============================================================================= + +@triton.jit +def _mamba3_rope_fwd_kernel( + # Input pointers + b_ptr, c_ptr, cos_ptr, sin_ptr, + # Output pointers + b_out_ptr, c_out_ptr, + # Dimensions + batch, seqlen, nheads, ngroups, dstate, half_d, + nheads_per_group, + # B strides (batch, seqlen, ngroups, dstate) + stride_b_batch, stride_b_seqlen, stride_b_group, stride_b_dstate, + # C strides (same layout as B) + stride_c_batch, stride_c_seqlen, stride_c_group, stride_c_dstate, + # cos/sin strides (batch, seqlen, nheads, half_d) + stride_cs_batch, stride_cs_seqlen, stride_cs_head, stride_cs_halfd, + # B_out strides (batch, seqlen, nheads, dstate) + stride_bo_batch, stride_bo_seqlen, stride_bo_head, stride_bo_dstate, + # C_out strides (same layout as B_out) + stride_co_batch, stride_co_seqlen, stride_co_head, stride_co_dstate, + # Meta-parameters + BLOCK_L: tl.constexpr, BLOCK_D: tl.constexpr, +): + """Fused group->head expansion and RoPE rotation for B and C. + + Grid: (batch, cdiv(seqlen, BLOCK_L), nheads) + + For each (batch, seqlen_tile, head): + 1. Load B from the corresponding group (head // nheads_per_group) + 2. Split into even (first half_d) and odd (second half_d) halves + 3. Apply rotation: B_out_even = B_even*cos - B_odd*sin + B_out_odd = B_even*sin + B_odd*cos + 4. Same for C + """ + pid_b = tl.program_id(axis=0) + pid_l = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + # Which group does this head belong to? + pid_g = pid_h // nheads_per_group + + # Sequence offsets for this tile + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + # Half-dstate offsets + offs_d = tl.arange(0, BLOCK_D) + l_mask = offs_l < seqlen + d_mask = offs_d < half_d + + # --- Load cos and sin for this head --- + cs_base = pid_b * stride_cs_batch + pid_h * stride_cs_head + cos_ptrs = cos_ptr + cs_base + offs_l[:, None] * stride_cs_seqlen + offs_d[None, :] * stride_cs_halfd + sin_ptrs = sin_ptr + cs_base + offs_l[:, None] * stride_cs_seqlen + offs_d[None, :] * stride_cs_halfd + mask_ld = l_mask[:, None] & d_mask[None, :] + cos_val = tl.load(cos_ptrs, mask=mask_ld, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptrs, mask=mask_ld, other=0.0).to(tl.float32) + + # --- Process B --- + b_base = pid_b * stride_b_batch + pid_g * stride_b_group + # Even half (first half_d elements of dstate) + b_even_ptrs = b_ptr + b_base + offs_l[:, None] * stride_b_seqlen + offs_d[None, :] * stride_b_dstate + # Odd half (second half_d elements of dstate) + b_odd_ptrs = b_ptr + b_base + offs_l[:, None] * stride_b_seqlen + (offs_d[None, :] + half_d) * stride_b_dstate + b_even = tl.load(b_even_ptrs, mask=mask_ld, other=0.0).to(tl.float32) + b_odd = tl.load(b_odd_ptrs, mask=mask_ld, other=0.0).to(tl.float32) + + # Rotation + b_out_even = b_even * cos_val - b_odd * sin_val + b_out_odd = b_even * sin_val + b_odd * cos_val + + # Store to B_out at head level + bo_base = pid_b * stride_bo_batch + pid_h * stride_bo_head + bo_even_ptrs = b_out_ptr + bo_base + offs_l[:, None] * stride_bo_seqlen + offs_d[None, :] * stride_bo_dstate + bo_odd_ptrs = b_out_ptr + bo_base + offs_l[:, None] * stride_bo_seqlen + (offs_d[None, :] + half_d) * stride_bo_dstate + tl.store(bo_even_ptrs, b_out_even.to(b_out_ptr.dtype.element_ty), mask=mask_ld) + tl.store(bo_odd_ptrs, b_out_odd.to(b_out_ptr.dtype.element_ty), mask=mask_ld) + + # --- Process C (identical logic) --- + c_base = pid_b * stride_c_batch + pid_g * stride_c_group + c_even_ptrs = c_ptr + c_base + offs_l[:, None] * stride_c_seqlen + offs_d[None, :] * stride_c_dstate + c_odd_ptrs = c_ptr + c_base + offs_l[:, None] * stride_c_seqlen + (offs_d[None, :] + half_d) * stride_c_dstate + c_even = tl.load(c_even_ptrs, mask=mask_ld, other=0.0).to(tl.float32) + c_odd = tl.load(c_odd_ptrs, mask=mask_ld, other=0.0).to(tl.float32) + + c_out_even = c_even * cos_val - c_odd * sin_val + c_out_odd = c_even * sin_val + c_odd * cos_val + + co_base = pid_b * stride_co_batch + pid_h * stride_co_head + co_even_ptrs = c_out_ptr + co_base + offs_l[:, None] * stride_co_seqlen + offs_d[None, :] * stride_co_dstate + co_odd_ptrs = c_out_ptr + co_base + offs_l[:, None] * stride_co_seqlen + (offs_d[None, :] + half_d) * stride_co_dstate + tl.store(co_even_ptrs, c_out_even.to(c_out_ptr.dtype.element_ty), mask=mask_ld) + tl.store(co_odd_ptrs, c_out_odd.to(c_out_ptr.dtype.element_ty), mask=mask_ld) + + +def _mamba3_rope_fwd(B, C, theta, nheads, ngroups): + """Apply RoPE to B and C, expanding from groups to heads. + + Args: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + theta: (batch, seqlen, nheads, dstate//2) -- per-step rotation angles + + Returns: + B_heads: (batch, seqlen, nheads, dstate) + C_heads: (batch, seqlen, nheads, dstate) + theta_cumsum: (batch, seqlen, nheads, dstate//2) -- saved for backward + """ + batch, seqlen, ngroups_B, dstate = B.shape + assert C.shape == B.shape + assert theta.shape == (batch, seqlen, nheads, dstate // 2) + assert nheads % ngroups == 0 + half_d = dstate // 2 + nheads_per_group = nheads // ngroups + + # Step 1: Compute cumulative sum of theta in PyTorch + theta_cumsum = torch.cumsum(theta.float(), dim=1) # (batch, seqlen, nheads, half_d) + cos_theta = torch.cos(theta_cumsum) + sin_theta = torch.sin(theta_cumsum) + + # Step 2: Allocate output at head level + B_heads = torch.empty(batch, seqlen, nheads, dstate, device=B.device, dtype=B.dtype) + C_heads = torch.empty(batch, seqlen, nheads, dstate, device=C.device, dtype=C.dtype) + + # Choose block sizes + BLOCK_L = min(triton.next_power_of_2(seqlen), 128) + BLOCK_D = triton.next_power_of_2(half_d) + + grid = (batch, triton.cdiv(seqlen, BLOCK_L), nheads) + + with torch.cuda.device(B.device.index): + _mamba3_rope_fwd_kernel[grid]( + B, C, cos_theta, sin_theta, + B_heads, C_heads, + batch, seqlen, nheads, ngroups, dstate, half_d, + nheads_per_group, + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + cos_theta.stride(0), cos_theta.stride(1), cos_theta.stride(2), cos_theta.stride(3), + B_heads.stride(0), B_heads.stride(1), B_heads.stride(2), B_heads.stride(3), + C_heads.stride(0), C_heads.stride(1), C_heads.stride(2), C_heads.stride(3), + BLOCK_L=BLOCK_L, BLOCK_D=BLOCK_D, + ) + + return B_heads, C_heads, theta_cumsum + + +# ============================================================================= +# Backward kernel: reverse rotation + head->group reduction + dtheta +# ============================================================================= + +@triton.jit +def _mamba3_rope_bwd_kernel( + # Gradient inputs (at head level) + db_heads_ptr, dc_heads_ptr, + # Forward outputs (for dtheta computation) + b_heads_ptr, c_heads_ptr, + # cos/sin from forward + cos_ptr, sin_ptr, + # Outputs + db_ptr, dc_ptr, dtheta_ptr, + # Dimensions + batch, seqlen, nheads, ngroups, dstate, half_d, + nheads_per_group, + # dB_heads strides (batch, seqlen, nheads, dstate) + stride_dbh_batch, stride_dbh_seqlen, stride_dbh_head, stride_dbh_dstate, + # dC_heads strides + stride_dch_batch, stride_dch_seqlen, stride_dch_head, stride_dch_dstate, + # B_heads strides (forward outputs) + stride_bh_batch, stride_bh_seqlen, stride_bh_head, stride_bh_dstate, + # C_heads strides + stride_ch_batch, stride_ch_seqlen, stride_ch_head, stride_ch_dstate, + # cos/sin strides + stride_cs_batch, stride_cs_seqlen, stride_cs_head, stride_cs_halfd, + # dB output strides (batch, seqlen, ngroups, dstate) + stride_db_batch, stride_db_seqlen, stride_db_group, stride_db_dstate, + # dC output strides + stride_dc_batch, stride_dc_seqlen, stride_dc_group, stride_dc_dstate, + # dtheta strides (batch, seqlen, nheads, half_d) + stride_dth_batch, stride_dth_seqlen, stride_dth_head, stride_dth_halfd, + # Meta-parameters + BLOCK_L: tl.constexpr, BLOCK_D: tl.constexpr, +): + """Backward through RoPE for B and C. + + Grid: (batch, cdiv(seqlen, BLOCK_L), nheads) + + Computes: + dB[group] += reverse_rotate(dB_heads[head]) for all heads in group + dC[group] += reverse_rotate(dC_heads[head]) for all heads in group + dtheta[head] = (-B_out_odd)*dB_out_even + B_out_even*dB_out_odd + + (-C_out_odd)*dC_out_even + C_out_even*dC_out_odd + + Note: dtheta here is the gradient w.r.t. theta_cumsum. The reverse cumsum + to get gradient w.r.t. theta is done in PyTorch outside this kernel. + + For dB reduction across heads in a group, we use atomic_add since multiple + heads map to the same group. + """ + pid_b = tl.program_id(axis=0) + pid_l = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_g = pid_h // nheads_per_group + + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + offs_d = tl.arange(0, BLOCK_D) + l_mask = offs_l < seqlen + d_mask = offs_d < half_d + mask_ld = l_mask[:, None] & d_mask[None, :] + + # Load cos/sin + cs_base = pid_b * stride_cs_batch + pid_h * stride_cs_head + cos_val = tl.load(cos_ptr + cs_base + offs_l[:, None] * stride_cs_seqlen + offs_d[None, :] * stride_cs_halfd, + mask=mask_ld, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + cs_base + offs_l[:, None] * stride_cs_seqlen + offs_d[None, :] * stride_cs_halfd, + mask=mask_ld, other=0.0).to(tl.float32) + + # --- dB backward --- + dbh_base = pid_b * stride_dbh_batch + pid_h * stride_dbh_head + db_out_even = tl.load(db_heads_ptr + dbh_base + offs_l[:, None] * stride_dbh_seqlen + offs_d[None, :] * stride_dbh_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + db_out_odd = tl.load(db_heads_ptr + dbh_base + offs_l[:, None] * stride_dbh_seqlen + (offs_d[None, :] + half_d) * stride_dbh_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + + # Reverse rotation: dB_even = dB_out_even*cos + dB_out_odd*sin + # dB_odd = -dB_out_even*sin + dB_out_odd*cos + db_even = db_out_even * cos_val + db_out_odd * sin_val + db_odd = -db_out_even * sin_val + db_out_odd * cos_val + + # Atomic add to dB at group level (multiple heads contribute to same group) + db_base = pid_b * stride_db_batch + pid_g * stride_db_group + db_even_ptrs = db_ptr + db_base + offs_l[:, None] * stride_db_seqlen + offs_d[None, :] * stride_db_dstate + db_odd_ptrs = db_ptr + db_base + offs_l[:, None] * stride_db_seqlen + (offs_d[None, :] + half_d) * stride_db_dstate + tl.atomic_add(db_even_ptrs, db_even.to(db_ptr.dtype.element_ty), mask=mask_ld) + tl.atomic_add(db_odd_ptrs, db_odd.to(db_ptr.dtype.element_ty), mask=mask_ld) + + # --- dC backward (identical structure) --- + dch_base = pid_b * stride_dch_batch + pid_h * stride_dch_head + dc_out_even = tl.load(dc_heads_ptr + dch_base + offs_l[:, None] * stride_dch_seqlen + offs_d[None, :] * stride_dch_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + dc_out_odd = tl.load(dc_heads_ptr + dch_base + offs_l[:, None] * stride_dch_seqlen + (offs_d[None, :] + half_d) * stride_dch_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + + dc_even = dc_out_even * cos_val + dc_out_odd * sin_val + dc_odd = -dc_out_even * sin_val + dc_out_odd * cos_val + + dc_base = pid_b * stride_dc_batch + pid_g * stride_dc_group + dc_even_ptrs = dc_ptr + dc_base + offs_l[:, None] * stride_dc_seqlen + offs_d[None, :] * stride_dc_dstate + dc_odd_ptrs = dc_ptr + dc_base + offs_l[:, None] * stride_dc_seqlen + (offs_d[None, :] + half_d) * stride_dc_dstate + tl.atomic_add(dc_even_ptrs, dc_even.to(dc_ptr.dtype.element_ty), mask=mask_ld) + tl.atomic_add(dc_odd_ptrs, dc_odd.to(dc_ptr.dtype.element_ty), mask=mask_ld) + + # --- dtheta computation --- + # dtheta_cumsum = (-B_out_odd)*dB_out_even + B_out_even*dB_out_odd + # + (-C_out_odd)*dC_out_even + C_out_even*dC_out_odd + bh_base = pid_b * stride_bh_batch + pid_h * stride_bh_head + b_out_even = tl.load(b_heads_ptr + bh_base + offs_l[:, None] * stride_bh_seqlen + offs_d[None, :] * stride_bh_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + b_out_odd = tl.load(b_heads_ptr + bh_base + offs_l[:, None] * stride_bh_seqlen + (offs_d[None, :] + half_d) * stride_bh_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + + ch_base = pid_b * stride_ch_batch + pid_h * stride_ch_head + c_out_even = tl.load(c_heads_ptr + ch_base + offs_l[:, None] * stride_ch_seqlen + offs_d[None, :] * stride_ch_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + c_out_odd = tl.load(c_heads_ptr + ch_base + offs_l[:, None] * stride_ch_seqlen + (offs_d[None, :] + half_d) * stride_ch_dstate, + mask=mask_ld, other=0.0).to(tl.float32) + + dtheta_cs = ((-b_out_odd) * db_out_even + b_out_even * db_out_odd + + (-c_out_odd) * dc_out_even + c_out_even * dc_out_odd) + + dth_base = pid_b * stride_dth_batch + pid_h * stride_dth_head + dth_ptrs = dtheta_ptr + dth_base + offs_l[:, None] * stride_dth_seqlen + offs_d[None, :] * stride_dth_halfd + tl.store(dth_ptrs, dtheta_cs.to(dtheta_ptr.dtype.element_ty), mask=mask_ld) + + +def _mamba3_rope_bwd(dB_heads, dC_heads, B_heads, C_heads, theta_cumsum, ngroups): + """Backward through RoPE. + + Args: + dB_heads: (batch, seqlen, nheads, dstate) -- gradient of rotated B + dC_heads: (batch, seqlen, nheads, dstate) -- gradient of rotated C + B_heads: (batch, seqlen, nheads, dstate) -- forward output (rotated B) + C_heads: (batch, seqlen, nheads, dstate) -- forward output (rotated C) + theta_cumsum: (batch, seqlen, nheads, dstate//2) -- cumulative theta from forward + ngroups: int + + Returns: + dB: (batch, seqlen, ngroups, dstate) + dC: (batch, seqlen, ngroups, dstate) + dtheta: (batch, seqlen, nheads, dstate//2) + """ + batch, seqlen, nheads, dstate = dB_heads.shape + assert dC_heads.shape == dB_heads.shape + assert B_heads.shape == dB_heads.shape + assert C_heads.shape == dB_heads.shape + half_d = dstate // 2 + assert theta_cumsum.shape == (batch, seqlen, nheads, half_d) + assert nheads % ngroups == 0 + nheads_per_group = nheads // ngroups + + cos_theta = torch.cos(theta_cumsum) + sin_theta = torch.sin(theta_cumsum) + + # Allocate outputs -- dB and dC are zero-initialized for atomic_add + dB = torch.zeros(batch, seqlen, ngroups, dstate, device=dB_heads.device, dtype=torch.float32) + dC = torch.zeros(batch, seqlen, ngroups, dstate, device=dC_heads.device, dtype=torch.float32) + dtheta_cumsum = torch.empty(batch, seqlen, nheads, half_d, device=dB_heads.device, dtype=torch.float32) + + BLOCK_L = min(triton.next_power_of_2(seqlen), 128) + BLOCK_D = triton.next_power_of_2(half_d) + + grid = (batch, triton.cdiv(seqlen, BLOCK_L), nheads) + + with torch.cuda.device(dB_heads.device.index): + _mamba3_rope_bwd_kernel[grid]( + dB_heads, dC_heads, + B_heads, C_heads, + cos_theta, sin_theta, + dB, dC, dtheta_cumsum, + batch, seqlen, nheads, ngroups, dstate, half_d, + nheads_per_group, + # dB_heads strides + dB_heads.stride(0), dB_heads.stride(1), dB_heads.stride(2), dB_heads.stride(3), + # dC_heads strides + dC_heads.stride(0), dC_heads.stride(1), dC_heads.stride(2), dC_heads.stride(3), + # B_heads strides + B_heads.stride(0), B_heads.stride(1), B_heads.stride(2), B_heads.stride(3), + # C_heads strides + C_heads.stride(0), C_heads.stride(1), C_heads.stride(2), C_heads.stride(3), + # cos/sin strides + cos_theta.stride(0), cos_theta.stride(1), cos_theta.stride(2), cos_theta.stride(3), + # dB strides + dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), + # dC strides + dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), + # dtheta strides + dtheta_cumsum.stride(0), dtheta_cumsum.stride(1), dtheta_cumsum.stride(2), dtheta_cumsum.stride(3), + BLOCK_L=BLOCK_L, BLOCK_D=BLOCK_D, + ) + + # Reverse cumsum to get dtheta from dtheta_cumsum: + # dtheta[t] = sum_{s>=t} dtheta_cumsum[s] + # = flip(cumsum(flip(dtheta_cumsum, dim=1), dim=1), dim=1) + dtheta = dtheta_cumsum.flip(1).cumsum(1).flip(1) + + return dB, dC, dtheta diff --git a/mamba_ssm/ops/triton/mamba3_shift.py b/mamba_ssm/ops/triton/mamba3_shift.py new file mode 100644 index 000000000..59296bba9 --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_shift.py @@ -0,0 +1,251 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 shift-by-1 Triton kernels with seq_idx masking. +# +# Simple memory-bound kernels for shifting tensors by 1 position along the +# sequence dimension, used for the trapezoidal lookback term in Mamba-3. +# Supports seq_idx masking to zero out shifted values at document boundaries. +# +# Works for any tensor with seqlen as dim 1 (e.g. B: (b,l,h,n) or x: (b,l,h,p)). +# Everything after dim 1 is treated as a flat dimension. + +import torch +import triton +import triton.language as tl + + +# ============================================================================= +# Forward kernel: shift by 1 position +# ============================================================================= + +@triton.jit +def _mamba3_shift_fwd_kernel( + # Input/output pointers + x_ptr, out_ptr, seq_idx_ptr, initial_ptr, + # Dimensions + batch, seqlen, flat_dim, + # x strides + stride_x_batch, stride_x_seqlen, stride_x_flat, + # out strides + stride_out_batch, stride_out_seqlen, stride_out_flat, + # seq_idx strides + stride_si_batch, stride_si_seqlen, + # initial strides (batch, flat_dim) or scalar 0 + stride_init_batch, stride_init_flat, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + HAS_INITIAL: tl.constexpr, + BLOCK_L: tl.constexpr, BLOCK_D: tl.constexpr, +): + """Shift tensor by 1 along seqlen dimension. + + Grid: (batch, cdiv(seqlen, BLOCK_L), cdiv(flat_dim, BLOCK_D)) + + out[:,0,:] = initial (or 0) + out[:,t,:] = x[:,t-1,:] for t > 0 + if seq_idx: out[:,t,:] = 0 where seq_idx[:,t] != seq_idx[:,t-1] + """ + pid_b = tl.program_id(axis=0) + pid_l = tl.program_id(axis=1) + pid_d = tl.program_id(axis=2) + + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + + l_mask = offs_l < seqlen + d_mask = offs_d < flat_dim + mask = l_mask[:, None] & d_mask[None, :] + + # For t > 0: load from t-1 + # For t == 0: load initial or zero + src_l = offs_l - 1 # source positions (t-1) + src_valid = src_l >= 0 + + # Load source values (from position t-1) + x_base = pid_b * stride_x_batch + src_ptrs = x_ptr + x_base + src_l[:, None] * stride_x_seqlen + offs_d[None, :] * stride_x_flat + src_mask = mask & src_valid[:, None] + vals = tl.load(src_ptrs, mask=src_mask, other=0.0) + + # Handle position 0: use initial if provided + if HAS_INITIAL: + init_ptrs = initial_ptr + pid_b * stride_init_batch + offs_d * stride_init_flat + init_vals = tl.load(init_ptrs, mask=d_mask, other=0.0) + # Broadcast init_vals to (1, BLOCK_D) and apply where src_l < 0 + is_pos_zero = (offs_l == 0) + vals = tl.where(is_pos_zero[:, None] & d_mask[None, :], init_vals[None, :], vals) + + # seq_idx masking: zero out where document boundary crossed + if HAS_SEQ_IDX: + si_base = pid_b * stride_si_batch + si_cur = tl.load(seq_idx_ptr + si_base + offs_l * stride_si_seqlen, mask=l_mask, other=-1) + si_prev = tl.load(seq_idx_ptr + si_base + src_l * stride_si_seqlen, mask=l_mask & src_valid, other=-2) + # Zero out where current != previous (document boundary) or position 0 + boundary = ~src_valid | (si_cur != si_prev) + vals = tl.where(boundary[:, None], 0.0, vals) + + # Store output + out_base = pid_b * stride_out_batch + out_ptrs = out_ptr + out_base + offs_l[:, None] * stride_out_seqlen + offs_d[None, :] * stride_out_flat + tl.store(out_ptrs, vals, mask=mask) + + +def _mamba3_shift_fwd(x, seq_idx=None, initial=None): + """Shift tensor by 1 position along seqlen dim, with seq_idx masking. + + Args: + x: (batch, seqlen, ...) -- any shape with seqlen as dim 1 + seq_idx: (batch, seqlen) -- document boundaries, or None + initial: (batch, ...) value for position 0 (default: zero), or None + + Returns: + x_shifted: same shape as x, shifted by 1 + """ + batch = x.shape[0] + seqlen = x.shape[1] + # Flatten everything after dim 1 + orig_shape = x.shape + flat_dim = 1 + for s in x.shape[2:]: + flat_dim *= s + x_flat = x.reshape(batch, seqlen, flat_dim) + + if not x_flat.is_contiguous(): + x_flat = x_flat.contiguous() + + out = torch.empty_like(x_flat) + + has_initial = initial is not None + if has_initial: + initial_flat = initial.reshape(batch, flat_dim) + if not initial_flat.is_contiguous(): + initial_flat = initial_flat.contiguous() + else: + initial_flat = None + + BLOCK_L = min(triton.next_power_of_2(seqlen), 256) + BLOCK_D = min(triton.next_power_of_2(flat_dim), 256) + + grid = (batch, triton.cdiv(seqlen, BLOCK_L), triton.cdiv(flat_dim, BLOCK_D)) + + with torch.cuda.device(x.device.index): + _mamba3_shift_fwd_kernel[grid]( + x_flat, out, seq_idx, initial_flat, + batch, seqlen, flat_dim, + x_flat.stride(0), x_flat.stride(1), x_flat.stride(2), + out.stride(0), out.stride(1), out.stride(2), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + *((initial_flat.stride(0), initial_flat.stride(1)) if has_initial else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + HAS_INITIAL=has_initial, + BLOCK_L=BLOCK_L, BLOCK_D=BLOCK_D, + ) + + return out.reshape(orig_shape) + + +# ============================================================================= +# Backward kernel: reverse shift +# ============================================================================= + +@triton.jit +def _mamba3_shift_bwd_kernel( + # Input/output pointers + dout_ptr, dx_ptr, seq_idx_ptr, + # Dimensions + batch, seqlen, flat_dim, + # dout strides + stride_dout_batch, stride_dout_seqlen, stride_dout_flat, + # dx strides + stride_dx_batch, stride_dx_seqlen, stride_dx_flat, + # seq_idx strides + stride_si_batch, stride_si_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_L: tl.constexpr, BLOCK_D: tl.constexpr, +): + """Backward of shift: reverse the shift. + + Grid: (batch, cdiv(seqlen, BLOCK_L), cdiv(flat_dim, BLOCK_D)) + + dx[:,t,:] = dout[:,t+1,:] for t < seqlen-1 + dx[:,seqlen-1,:] = 0 + if seq_idx: dx[:,t,:] = 0 where seq_idx[:,t+1] != seq_idx[:,t] + """ + pid_b = tl.program_id(axis=0) + pid_l = tl.program_id(axis=1) + pid_d = tl.program_id(axis=2) + + offs_l = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + + l_mask = offs_l < seqlen + d_mask = offs_d < flat_dim + mask = l_mask[:, None] & d_mask[None, :] + + # For position t: gradient comes from dout at position t+1 + # (because forward: out[t+1] = x[t]) + dst_l = offs_l + 1 # where this position's value went to in forward + dst_valid = dst_l < seqlen + + # Load gradient from position t+1 + dout_base = pid_b * stride_dout_batch + dout_ptrs = dout_ptr + dout_base + dst_l[:, None] * stride_dout_seqlen + offs_d[None, :] * stride_dout_flat + src_mask = mask & dst_valid[:, None] + vals = tl.load(dout_ptrs, mask=src_mask, other=0.0) + + # seq_idx masking: zero out where forward would have zeroed + if HAS_SEQ_IDX: + si_base = pid_b * stride_si_batch + si_cur = tl.load(seq_idx_ptr + si_base + offs_l * stride_si_seqlen, mask=l_mask, other=-1) + si_next = tl.load(seq_idx_ptr + si_base + dst_l * stride_si_seqlen, mask=l_mask & dst_valid, other=-2) + # In forward: out[t+1] = x[t] only if seq_idx[t+1] == seq_idx[t] + # So gradient flows back only when seq_idx matches + boundary = ~dst_valid | (si_next != si_cur) + vals = tl.where(boundary[:, None], 0.0, vals) + + # Store dx + dx_base = pid_b * stride_dx_batch + dx_ptrs = dx_ptr + dx_base + offs_l[:, None] * stride_dx_seqlen + offs_d[None, :] * stride_dx_flat + tl.store(dx_ptrs, vals, mask=mask) + + +def _mamba3_shift_bwd(dx_shifted, seq_idx=None): + """Backward through shift: reverse the shift to get gradient w.r.t. input. + + Args: + dx_shifted: (batch, seqlen, ...) -- gradient of shifted output + seq_idx: (batch, seqlen) -- document boundaries, or None + + Returns: + dx: same shape as dx_shifted, unshifted gradient + """ + batch = dx_shifted.shape[0] + seqlen = dx_shifted.shape[1] + orig_shape = dx_shifted.shape + flat_dim = 1 + for s in dx_shifted.shape[2:]: + flat_dim *= s + dout_flat = dx_shifted.reshape(batch, seqlen, flat_dim) + + if not dout_flat.is_contiguous(): + dout_flat = dout_flat.contiguous() + + dx = torch.empty_like(dout_flat) + + BLOCK_L = min(triton.next_power_of_2(seqlen), 256) + BLOCK_D = min(triton.next_power_of_2(flat_dim), 256) + + grid = (batch, triton.cdiv(seqlen, BLOCK_L), triton.cdiv(flat_dim, BLOCK_D)) + + with torch.cuda.device(dx_shifted.device.index): + _mamba3_shift_bwd_kernel[grid]( + dout_flat, dx, seq_idx, + batch, seqlen, flat_dim, + dout_flat.stride(0), dout_flat.stride(1), dout_flat.stride(2), + dx.stride(0), dx.stride(1), dx.stride(2), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_L=BLOCK_L, BLOCK_D=BLOCK_D, + ) + + return dx.reshape(orig_shape) diff --git a/mamba_ssm/ops/triton/mamba3_ssd.py b/mamba_ssm/ops/triton/mamba3_ssd.py new file mode 100644 index 000000000..094784671 --- /dev/null +++ b/mamba_ssm/ops/triton/mamba3_ssd.py @@ -0,0 +1,842 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Mamba-3 SSD operations: chunked parallel forward + Triton decode kernel. +# +# Chunked parallel SSD with exponential-trapezoidal discretization: +# h_t = α_t * h_{t-1} + β_t * B_{t-1} * x_{t-1} + γ_t * B_t * x_t +# +# Strategy: +# 1. Pre-convolve the state-input: v_t = γ_t * B_t ⊗ x_t + β_t * B_{t-1} ⊗ x_{t-1} +# Then the recurrence is h_t = α_t * h_{t-1} + v_t (standard linear recurrence). +# 2. Apply RoPE to B, C before chunked computation (doesn't affect kernel structure). +# 3. Use the SSD chunked algorithm (matmuls within chunks, sequential across chunks). + +import math +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +try: + import triton + import triton.language as tl + from mamba_ssm.ops.triton.softplus import softplus + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# ============================================================================ +# Chunked Parallel SSD for Mamba-3 (PyTorch reference, differentiable) +# ============================================================================ + +def segsum(x): + """Stable segment sum for causal decay matrix.""" + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def mamba3_ssd_chunked( + X, dt, A, B, C, + block_len, + gamma=None, + beta=None, + D=None, + z=None, + initial_states=None, + return_final_states=False, + initial_prev_Bx=None, + seq_idx=None, +): + """ + Chunked parallel SSD for Mamba-3 with exponential-trapezoidal discretization. + + Supports both SISO and MIMO. For MIMO, B/C/X have an extra trailing rank dimension. + + Arguments: + X: (batch, length, n_heads, d_head[, mimo_rank]) + dt: (batch, length, n_heads) + A: (n_heads,) -- negative SSM eigenvalues + B: (batch, length, n_heads, d_state[, mimo_rank]) + C: (batch, length, n_heads, d_state[, mimo_rank]) + block_len: int -- chunk size + gamma: (batch, length, n_heads) or None -- trapezoidal current weight (λ * dt) + beta: (batch, length, n_heads) or None -- trapezoidal lookback weight ((1-λ) * dt * α) + D: (n_heads,) or (n_heads, d_head) or None -- skip connection + initial_states: (batch, n_heads, d_head, d_state) or None + return_final_states: bool + initial_prev_Bx: (batch, n_heads, d_head, d_state) or None -- prev B*x for trapezoidal t=0 + seq_idx: (batch, length) int or None -- document indices for packed training + Return: + Y: (batch, length, n_heads, d_head) + final_state: (batch, n_heads, d_head, d_state) if return_final_states + """ + is_mimo = B.dim() == 5 + batch, seqlen, nheads, headdim = X.shape[:4] + mimo_rank = X.shape[4] if is_mimo else 0 + dstate = B.shape[-2] if is_mimo else B.shape[-1] + + assert seqlen % block_len == 0 + nchunks = seqlen // block_len + out_dtype = X.dtype # preserve original dtype for output + + # Cast to float32 for numerical stability in matmuls + X = X.float() + B = B.float() + C = C.float() + dt = dt.float() + if gamma is not None: + gamma = gamma.float() + if beta is not None: + beta = beta.float() + + # Compute dA = dt * A per timestep + dA = dt * A.float().view(1, 1, nheads) # (batch, seqlen, nheads) + + # Prepare seq_idx masks for cross-document boundary handling + has_seq_idx = seq_idx is not None + if has_seq_idx: + seq_idx_c = rearrange(seq_idx, "b (c l) -> b c l", l=block_len) + + # Reshape into chunks + if is_mimo: + X_c = rearrange(X, "b (c l) h p r -> b c l h p r", l=block_len) + B_c = rearrange(B, "b (c l) h n r -> b c l h n r", l=block_len) + C_c = rearrange(C, "b (c l) h n r -> b c l h n r", l=block_len) + else: + X_c = rearrange(X, "b (c l) h p -> b c l h p", l=block_len) + B_c = rearrange(B, "b (c l) h n -> b c l h n", l=block_len) + C_c = rearrange(C, "b (c l) h n -> b c l h n", l=block_len) + dA_c = rearrange(dA, "b (c l) h -> b h c l", l=block_len) + dt_c = rearrange(dt, "b (c l) h -> b c l h", l=block_len) + + # Cumsum of dA within each chunk + dA_cumsum = torch.cumsum(dA_c, dim=-1) # (batch, nheads, nchunks, block_len) + + # For SISO, C_c is used directly. For MIMO, we compute per output rank. + + # === Handle trapezoidal convolution on state-input === + use_trapezoidal = gamma is not None and beta is not None + + if use_trapezoidal: + gamma_c = rearrange(gamma, "b (c l) h -> b c l h", l=block_len) + beta_c = rearrange(beta, "b (c l) h -> b c l h", l=block_len) + + # Shifted B and X for the lookback term (shift by 1 within each chunk) + B_shifted = torch.zeros_like(B_c) + X_shifted = torch.zeros_like(X_c) + B_shifted[:, :, 1:] = B_c[:, :, :-1] + X_shifted[:, :, 1:] = X_c[:, :, :-1] + # Cross-chunk boundary: position 0 of chunk c gets last position of chunk c-1 + B_shifted[:, 1:, 0] = B_c[:, :-1, -1] + X_shifted[:, 1:, 0] = X_c[:, :-1, -1] + + # Mask shifted values at document boundaries (no lookback across documents) + if has_seq_idx: + # shift_valid[b,c,t] = True if position t can look back to t-1 + shift_valid = torch.ones(batch, nchunks, block_len, dtype=torch.bool, device=X.device) + shift_valid[:, :, 1:] = seq_idx_c[:, :, 1:] == seq_idx_c[:, :, :-1] + shift_valid[:, 1:, 0] = seq_idx_c[:, 1:, 0] == seq_idx_c[:, :-1, -1] + shift_valid[:, 0, 0] = False # no lookback for very first position + if is_mimo: + sv = shift_valid[:, :, :, None, None, None] + else: + sv = shift_valid[:, :, :, None, None] + B_shifted = B_shifted * sv + X_shifted = X_shifted * sv + + # Streaming: position 0 of chunk 0 from initial_prev_Bx + # We can't split prev_Bx into separate B and x, so handle as state correction below + else: + gamma_c = dt_c # Fall back to Euler: γ = dt + beta_c = None + + # === 1. Intra-chunk computation (diagonal blocks) === + L = torch.exp(segsum(dA_c)) # (batch, nheads, nchunks, block_len, block_len) + + # Mask L for cross-document boundaries: L[i,j]=0 when seq_idx[i] != seq_idx[j] + if has_seq_idx: + seq_mask = seq_idx_c[:, :, :, None] == seq_idx_c[:, :, None, :] # (b, c, l, l) + L = L * rearrange(seq_mask.float(), "b c l s -> b 1 c l s") + + gamma_scale = rearrange(gamma_c, "b c l h -> b h c 1 l") + + if is_mimo: + # Per output rank: Y[r_out] = Σ_{r_in} L * γ * C[r_out]^T B[r_in] * X[r_in] + Y_diag = torch.zeros(batch, nchunks, block_len, nheads, headdim, mimo_rank, + device=X.device, dtype=X.dtype) + for r_out in range(mimo_rank): + for r_in in range(mimo_rank): + CB_r = torch.einsum("bclhn,bcshn->bhcls", C_c[..., r_out], B_c[..., r_in]) + Y_diag[..., r_out] = Y_diag[..., r_out] + torch.einsum( + "bhcls,bhcls,bcshp->bclhp", + L * gamma_scale, CB_r, X_c[..., r_in], + ) + if use_trapezoidal: + beta_scale = rearrange(beta_c, "b c l h -> b h c 1 l") + for r_out in range(mimo_rank): + for r_in in range(mimo_rank): + CB_shifted_r = torch.einsum("bclhn,bcshn->bhcls", C_c[..., r_out], B_shifted[..., r_in]) + Y_diag[..., r_out] = Y_diag[..., r_out] + torch.einsum( + "bhcls,bhcls,bcshp->bclhp", + L * beta_scale, CB_shifted_r, X_shifted[..., r_in], + ) + else: + CB = torch.einsum("bclhn,bcshn->bhcls", C_c, B_c) + Y_diag = torch.einsum( + "bhcls,bhcls,bcshp->bclhp", + L * gamma_scale, CB, X_c, + ) + if use_trapezoidal: + CB_shifted = torch.einsum("bclhn,bcshn->bhcls", C_c, B_shifted) + beta_scale = rearrange(beta_c, "b c l h -> b h c 1 l") + Y_diag = Y_diag + torch.einsum( + "bhcls,bhcls,bcshp->bclhp", + L * beta_scale, CB_shifted, X_shifted, + ) + + # === 2. Per-chunk state computation === + decay_states = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum) # (batch, nheads, nchunks, block_len) + + # Mask: only accumulate state from tokens in same document as chunk's last token + if has_seq_idx: + state_doc_mask = (seq_idx_c == seq_idx_c[:, :, -1:]).float() # (b, c, l) + decay_states = decay_states * rearrange(state_doc_mask, "b c l -> b 1 c l") + + gamma_decay = decay_states * rearrange(gamma_c, "b c l h -> b h c l") + + if is_mimo: + states = torch.zeros(batch, nchunks, nheads, headdim, dstate, + device=X.device, dtype=torch.float32) + for r in range(mimo_rank): + states = states + torch.einsum( + "bclhn,bhcl,bclhp->bchpn", + B_c[..., r].float(), gamma_decay, X_c[..., r].float(), + ) + if use_trapezoidal: + beta_decay = decay_states * rearrange(beta_c, "b c l h -> b h c l") + for r in range(mimo_rank): + states = states + torch.einsum( + "bclhn,bhcl,bclhp->bchpn", + B_shifted[..., r].float(), beta_decay, X_shifted[..., r].float(), + ) + else: + states = torch.einsum( + "bclhn,bhcl,bclhp->bchpn", + B_c, gamma_decay, X_c, + ) + if use_trapezoidal: + beta_decay = decay_states * rearrange(beta_c, "b c l h -> b h c l") + states_trap = torch.einsum( + "bclhn,bhcl,bclhp->bchpn", + B_shifted, beta_decay, X_shifted, + ) + states = states + states_trap + + # === Handle initial_prev_Bx correction === + # At t=0, trapezoidal adds β_0 * initial_prev_Bx to state. This was missed by + # B_shifted/X_shifted (which are zero at position 0 of chunk 0). + # We add the correction both to the state and the output. + if initial_prev_Bx is not None and use_trapezoidal: + beta_0 = beta[:, 0, :] # (batch, nheads) + correction = rearrange(beta_0, "b h -> b h 1 1") * initial_prev_Bx.float() + # Decay correction from position 0 to end of chunk 0 + decay_from_0 = decay_states[:, :, 0, 0] + states[:, 0] = states[:, 0] + rearrange(decay_from_0, "b h -> b h 1 1") * correction + # Output correction within chunk 0 + decay_0_to_m = torch.exp(dA_cumsum[:, :, 0, :] - dA_cumsum[:, :, 0, 0:1]) # (b, h, block_len) + if is_mimo: + for r_out in range(mimo_rank): + Y_corr = torch.einsum( + "bclhn,bhpn,bhcl->bclhp", + C_c[:, 0:1, :, :, :, r_out], correction.to(C_c.dtype), decay_0_to_m.unsqueeze(2), + ) + Y_diag[:, 0:1, :, :, :, r_out] = Y_diag[:, 0:1, :, :, :, r_out] + Y_corr + else: + Y_correction = torch.einsum( + "bclhn,bhpn,bhcl->bclhp", + C_c[:, 0:1], correction.to(C_c.dtype), decay_0_to_m.unsqueeze(2), + ) + Y_diag[:, 0:1] = Y_diag[:, 0:1] + Y_correction + + # === 3. Inter-chunk recurrence === + if initial_states is None: + initial_states_flat = torch.zeros( + batch, nheads, headdim * dstate, device=X.device, dtype=torch.float32, + ) + else: + initial_states_flat = rearrange(initial_states.float(), "b h p n -> b h (p n)") + + states_flat = rearrange(states.float(), "b c h p n -> b c h (p n)") + + # Total decay per chunk + dA_chunk_cumsum = dA_cumsum[:, :, :, -1] # (batch, nheads, nchunks) + + # Mask inter-chunk propagation at document boundaries + if has_seq_idx: + # Compare last token of each chunk with last token of previous chunk + chunk_end_idx = seq_idx_c[:, :, -1] # (b, c) + # chunk_same[b, c] = True if chunk c's last token is same doc as chunk c-1's last token + chunk_same = torch.ones(batch, nchunks, dtype=torch.bool, device=X.device) + chunk_same[:, 1:] = chunk_end_idx[:, 1:] == chunk_end_idx[:, :-1] + chunk_same[:, 0] = initial_states is not None # propagate initial state only if provided + chunk_propagate = rearrange(chunk_same.float(), "b c -> b 1 c") # (b, 1, c) for heads + + # Sequential scan across chunks + all_states = [] + prev_state = initial_states_flat + for c in range(nchunks): + scale = torch.exp(dA_chunk_cumsum[:, :, c]).unsqueeze(-1) # (batch, nheads, 1) + if has_seq_idx: + scale = scale * chunk_propagate[:, :, c].unsqueeze(-1) + prev_state = scale * prev_state + states_flat[:, c] + all_states.append(prev_state) + + # Propagated states at chunk boundaries + boundary_states = [initial_states_flat] + all_states[:-1] + boundary_states = torch.stack(boundary_states, dim=1) # (batch, nchunks, nheads, headdim*dstate) + boundary_states = rearrange(boundary_states, "b c h (p n) -> b c h p n", p=headdim, n=dstate) + + final_state = rearrange(all_states[-1], "b h (p n) -> b h p n", p=headdim, n=dstate) + + # === 4. Inter-chunk output (off-diagonal blocks) === + state_decay_out = torch.exp(dA_cumsum) # (batch, nheads, nchunks, block_len) + + # Mask: only apply boundary state to positions in same document as chunk start + if has_seq_idx: + # For each position in a chunk, check if it's in the same document as + # the boundary state (which was propagated from the previous chunk's last token) + # The boundary state represents the document at the END of the previous chunk. + # We need: seq_idx at each position == seq_idx at start of current doc segment in this chunk. + # Simpler correct approach: mask positions where seq_idx differs from chunk's first token + # of the same document run that includes the boundary. + # Actually, the correct mask is: state_decay_out[pos] = 0 if pos belongs to a different + # document than the boundary state. The boundary state is from the previous chunk's last token. + boundary_doc = torch.zeros(batch, nchunks, dtype=seq_idx.dtype, device=seq_idx.device) + boundary_doc[:, 0] = seq_idx_c[:, 0, 0] # initial state's document (or first token) + boundary_doc[:, 1:] = seq_idx_c[:, :-1, -1] # previous chunk's last token + # mask[b,c,t] = (seq_idx_c[b,c,t] == boundary_doc[b,c]) + off_diag_mask = (seq_idx_c == boundary_doc[:, :, None]).float() + state_decay_out = state_decay_out * rearrange(off_diag_mask, "b c l -> b 1 c l") + if is_mimo: + Y_off = torch.zeros(batch, nchunks, block_len, nheads, headdim, mimo_rank, + device=X.device, dtype=X.dtype) + for r_out in range(mimo_rank): + Y_off[..., r_out] = torch.einsum( + "bclhn,bchpn,bhcl->bclhp", + C_c[..., r_out], boundary_states.to(C_c.dtype), state_decay_out, + ) + else: + Y_off = torch.einsum( + "bclhn,bchpn,bhcl->bclhp", + C_c, boundary_states.to(C_c.dtype), state_decay_out, + ) + + # === Combine === + if is_mimo: + Y = rearrange(Y_diag + Y_off, "b c l h p r -> b (c l) h p r") + else: + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + + # D skip connection + if D is not None: + if is_mimo: + if D.dim() == 1: + Y = Y + X * rearrange(D, "h -> 1 1 h 1 1") + else: + Y = Y + X * rearrange(D, "h p -> 1 1 h p 1") + else: + if D.dim() == 1: + Y = Y + X * rearrange(D, "h -> 1 1 h 1") + else: + Y = Y + X * rearrange(D, "h p -> 1 1 h p") + + # z gating (SiLU) + if z is not None: + z = z.float() + Y = Y * F.silu(z) + + # Cast back to original dtype + Y = Y.to(out_dtype) + + if return_final_states: + return Y, final_state + return Y + + +def apply_rotary_emb_to_bc(B, C, theta, nheads, ngroups): + """Apply cumulative data-dependent RoPE to B and C. + + RoPE is applied AFTER expanding B, C from groups to heads so that each head + gets its own rotation (matching the reference recurrence). This is called + before group→head expansion in mamba3_chunk_scan_combined, so we expand + internally, apply per-head RoPE, then return at head level. + + Args: + B: (batch, seqlen, ngroups, d_state) or (..., d_state, mimo_rank) + C: same as B + theta: (batch, seqlen, nheads, d_state//2) -- per-step rotation angles + nheads: int + ngroups: int + Returns: + B_rot, C_rot at head level: (batch, seqlen, nheads, d_state[, mimo_rank]) + """ + if theta is None: + return B, C + + batch, seqlen = theta.shape[:2] + is_mimo = B.dim() == 5 + dstate = B.shape[-2] if is_mimo else B.shape[-1] + nheads_per_group = nheads // ngroups + half_d = dstate // 2 + + # Expand B, C from groups to heads BEFORE applying RoPE + if is_mimo: + B = repeat(B, "b l g n r -> b l (g h) n r", h=nheads_per_group) + C = repeat(C, "b l g n r -> b l (g h) n r", h=nheads_per_group) + else: + B = repeat(B, "b l g n -> b l (g h) n", h=nheads_per_group) + C = repeat(C, "b l g n -> b l (g h) n", h=nheads_per_group) + + # Cumulative sum of per-head angles + theta_cumsum = torch.cumsum(theta, dim=1) # (batch, seqlen, nheads, dstate//2) + cos_h = torch.cos(theta_cumsum) # (batch, seqlen, nheads, dstate//2) + sin_h = torch.sin(theta_cumsum) + + if is_mimo: + # B: (batch, seqlen, nheads, d_state, mimo_rank) + B1, B2 = B[..., :half_d, :], B[..., half_d:, :] + B_rot = torch.cat([ + B1 * cos_h.unsqueeze(-1) - B2 * sin_h.unsqueeze(-1), + B1 * sin_h.unsqueeze(-1) + B2 * cos_h.unsqueeze(-1), + ], dim=-2) + C1, C2 = C[..., :half_d, :], C[..., half_d:, :] + C_rot = torch.cat([ + C1 * cos_h.unsqueeze(-1) - C2 * sin_h.unsqueeze(-1), + C1 * sin_h.unsqueeze(-1) + C2 * cos_h.unsqueeze(-1), + ], dim=-2) + else: + # B: (batch, seqlen, nheads, d_state) + B1, B2 = B[..., :half_d], B[..., half_d:] + B_rot = torch.cat([B1 * cos_h - B2 * sin_h, B1 * sin_h + B2 * cos_h], dim=-1) + C1, C2 = C[..., :half_d], C[..., half_d:] + C_rot = torch.cat([C1 * cos_h - C2 * sin_h, C1 * sin_h + C2 * cos_h], dim=-1) + + return B_rot, C_rot + + +def mamba3_chunk_scan_combined( + x, dt, A, B, C, + chunk_size, + gamma=None, + beta=None, + theta=None, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + initial_states=None, + initial_prev_Bx=None, + return_final_states=False, + ngroups=1, + seq_idx=None, +): + """ + Combined chunked SSD for Mamba-3. + + Args: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads,) -- negative + B: (batch, seqlen, ngroups, d_state) + C: (batch, seqlen, ngroups, d_state) + chunk_size: int + gamma: (batch, seqlen, nheads) -- λ * dt (trapezoidal current weight) + beta: (batch, seqlen, nheads) -- (1-λ) * dt * exp(dt*A) (trapezoidal lookback weight) + theta: (batch, seqlen, nheads, d_state//2) -- RoPE angles + D, z, dt_bias, dt_softplus, dt_limit: same as mamba_chunk_scan_combined + initial_states: (batch, nheads, headdim, d_state) + return_final_states: bool + ngroups: int + """ + batch, seqlen, nheads, headdim = x.shape[:4] + is_mimo = B.dim() == 5 + dstate = B.shape[-2] if is_mimo else B.shape[-1] + nheads_per_group = nheads // ngroups + + # Process dt + if dt_bias is not None: + dt = dt + dt_bias.view(1, 1, nheads) + if dt_softplus: + dt = F.softplus(dt) + if dt_limit != (0.0, float("inf")): + dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]) + + # Apply RoPE to B, C before chunked computation + # apply_rotary_emb_to_bc expands from groups→heads internally (per-head RoPE) + if theta is not None: + B, C = apply_rotary_emb_to_bc(B, C, theta, nheads, ngroups) + is_mimo = B.dim() == 5 # refresh after expansion + else: + # No RoPE: expand B, C from groups to heads + is_mimo = B.dim() == 5 + if is_mimo: + B = repeat(B, "b l g n r -> b l (g h) n r", h=nheads_per_group) + C = repeat(C, "b l g n r -> b l (g h) n r", h=nheads_per_group) + else: + B = repeat(B, "b l g n -> b l (g h) n", h=nheads_per_group) + C = repeat(C, "b l g n -> b l (g h) n", h=nheads_per_group) + + # Pad sequence to multiple of chunk_size + pad_len = (chunk_size - seqlen % chunk_size) % chunk_size + if pad_len > 0: + x = F.pad(x, (0, 0, 0, 0, 0, pad_len)) if not is_mimo else \ + F.pad(x, (0, 0, 0, 0, 0, 0, 0, pad_len)) + dt = F.pad(dt, (0, 0, 0, pad_len)) + B = F.pad(B, (0, 0, 0, 0, 0, pad_len)) if not is_mimo else \ + F.pad(B, (0, 0, 0, 0, 0, 0, 0, pad_len)) + C = F.pad(C, (0, 0, 0, 0, 0, pad_len)) if not is_mimo else \ + F.pad(C, (0, 0, 0, 0, 0, 0, 0, pad_len)) + if gamma is not None: + gamma = F.pad(gamma, (0, 0, 0, pad_len)) + if beta is not None: + beta = F.pad(beta, (0, 0, 0, pad_len)) + if z is not None: + z = F.pad(z, (0, 0, 0, 0, 0, pad_len)) if not is_mimo else \ + F.pad(z, (0, 0, 0, 0, 0, 0, 0, pad_len)) + if seq_idx is not None: + # Pad with -1 so padded positions are never equal to real doc indices + seq_idx = F.pad(seq_idx, (0, pad_len), value=-1) + + result = mamba3_ssd_chunked( + x, dt, A, B, C, + block_len=chunk_size, + gamma=gamma, + beta=beta, + D=D, + z=z, + initial_states=initial_states, + return_final_states=return_final_states, + initial_prev_Bx=initial_prev_Bx, + seq_idx=seq_idx, + ) + + # Un-pad + if pad_len > 0: + if return_final_states: + Y, final_state = result + Y = Y[:, :seqlen] + result = (Y, final_state) + else: + result = result[:, :seqlen] + + return result + + +# ============================================================================ +# Triton kernel for Mamba-3 single-step decode +# ============================================================================ +# LIMITATION: This kernel operates on B, C at group level without RoPE, BCNorm, +# or BC bias. It is a low-level primitive — the caller must pre-process B, C +# (apply norm, expand to heads, apply bias, apply RoPE) before calling. +# The step() method in mamba3.py handles this correctly in PyTorch. +# Supports both SISO and MIMO decode (per-rank B, C, X for MIMO). +# ============================================================================ + +if HAS_TRITON: + + @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) + @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) + @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) + @triton.heuristics({"HAS_PREV_BX": lambda args: args["prev_Bx_ptr"] is not None}) + @triton.heuristics({"HAS_BETA": lambda args: args["beta_ptr"] is not None}) + @triton.heuristics({"HAS_GAMMA": lambda args: args["gamma_ptr"] is not None}) + @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) + @triton.heuristics({"IS_MIMO": lambda args: args["mimo_rank"] > 0}) + @triton.heuristics({"MIMO_RANK": lambda args: args["mimo_rank"]}) + @triton.jit + def _mamba3_state_update_kernel( + # Pointers + state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, + B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, + prev_Bx_ptr, beta_ptr, gamma_ptr, + # Dims + batch, nheads, dim, dstate, nheads_ngroups_ratio, mimo_rank, + # Strides + stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate, + stride_x_batch, stride_x_head, stride_x_dim, + stride_dt_batch, stride_dt_head, + stride_A_head, + stride_B_batch, stride_B_group, stride_B_dstate, + stride_C_batch, stride_C_group, stride_C_dstate, + stride_D_head, + stride_z_batch, stride_z_head, stride_z_dim, + stride_out_batch, stride_out_head, stride_out_dim, + stride_prev_Bx_batch, stride_prev_Bx_head, stride_prev_Bx_dim, stride_prev_Bx_dstate, + stride_beta_batch, stride_beta_head, + stride_gamma_batch, stride_gamma_head, + # MIMO strides (only used when IS_MIMO) + stride_x_rank, stride_B_rank, stride_C_rank, stride_out_rank, + # Meta + DT_SOFTPLUS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_PREV_BX: tl.constexpr, + HAS_BETA: tl.constexpr, + HAS_GAMMA: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_MIMO: tl.constexpr, + MIMO_RANK: tl.constexpr, + ): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + + # Load dt and A (scalar per head) + dt = tl.load(dt_ptr + pid_b * stride_dt_batch + pid_h * stride_dt_head).to(tl.float32) + if HAS_DT_BIAS: + dt_bias_stride = tl.load(dt_bias_ptr + pid_h).to(tl.float32) + dt += dt_bias_stride + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + + A = tl.load(A_ptr + pid_h * stride_A_head).to(tl.float32) + dA = tl.exp(A * dt) # decay uses original dt + + # Load gamma (input scaling) — separate from dt for trapezoidal + if HAS_GAMMA: + input_scale = tl.load(gamma_ptr + pid_b * stride_gamma_batch + pid_h * stride_gamma_head).to(tl.float32) + else: + input_scale = dt # Euler mode: gamma = dt + + # Load state + state_ptr_base = state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head + state_ptrs = state_ptr_base + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + + # Load x, B and compute Bx (unscaled) and dBx (scaled by input_scale) + x_ptr_base = x_ptr + pid_b * stride_x_batch + pid_h * stride_x_head + B_ptr_base = B_ptr + pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr_base = C_ptr + pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + + if IS_MIMO: + # MIMO: Bx = Σ_r x[m,r] * B[n,r], summed over rank + Bx_unscaled = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_DSTATE], dtype=tl.float32) + for r in range(MIMO_RANK): + x_r = tl.load(x_ptr_base + offs_m * stride_x_dim + r * stride_x_rank, + mask=offs_m < dim, other=0.0).to(tl.float32) + B_r = tl.load(B_ptr_base + offs_n * stride_B_dstate + r * stride_B_rank, + mask=offs_n < dstate, other=0.0).to(tl.float32) + Bx_unscaled += x_r[:, None] * B_r[None, :] + dBx = Bx_unscaled * input_scale + else: + # SISO: dBx = input_scale * B * x + x = tl.load(x_ptr_base + offs_m * stride_x_dim, mask=offs_m < dim, other=0.0).to(tl.float32) + B = tl.load(B_ptr_base + offs_n * stride_B_dstate, mask=offs_n < dstate, other=0.0).to(tl.float32) + dBx = B[None, :] * input_scale * x[:, None] + + # State update: h = dA * h + dBx + state = state * dA + dBx + + # Add trapezoidal lookback: + beta * prev_Bx + if HAS_PREV_BX and HAS_BETA: + beta = tl.load(beta_ptr + pid_b * stride_beta_batch + pid_h * stride_beta_head).to(tl.float32) + prev_Bx_ptrs = (prev_Bx_ptr + pid_b * stride_prev_Bx_batch + pid_h * stride_prev_Bx_head + + offs_m[:, None] * stride_prev_Bx_dim + offs_n[None, :] * stride_prev_Bx_dstate) + prev_Bx = tl.load(prev_Bx_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + state = state + beta * prev_Bx + + # Store updated state + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + + # Store current Bx (unscaled) for next step's trapezoidal lookback + if HAS_PREV_BX: + prev_Bx_ptrs_s = (prev_Bx_ptr + pid_b * stride_prev_Bx_batch + pid_h * stride_prev_Bx_head + + offs_m[:, None] * stride_prev_Bx_dim + offs_n[None, :] * stride_prev_Bx_dstate) + if IS_MIMO: + tl.store(prev_Bx_ptrs_s, Bx_unscaled, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + else: + raw_Bx = B[None, :] * x[:, None] + tl.store(prev_Bx_ptrs_s, raw_Bx, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + + # Output + out_ptr_base = out_ptr + pid_b * stride_out_batch + pid_h * stride_out_head + if IS_MIMO: + # Per-rank output: y[p,r] = Σ_n state[p,n] * C[n,r] + for r in range(MIMO_RANK): + C_r = tl.load(C_ptr_base + offs_n * stride_C_dstate + r * stride_C_rank, + mask=offs_n < dstate, other=0.0).to(tl.float32) + out_r = tl.sum(state * C_r[None, :], axis=1) + tl.store(out_ptr_base + offs_m * stride_out_dim + r * stride_out_rank, + out_r, mask=offs_m < dim) + else: + C = tl.load(C_ptr_base + offs_n * stride_C_dstate, mask=offs_n < dstate, other=0.0).to(tl.float32) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out += x * D + if HAS_Z: + z = tl.load(z_ptr + pid_b * stride_z_batch + pid_h * stride_z_head + offs_m * stride_z_dim, + mask=offs_m < dim, other=0.0).to(tl.float32) + out *= z * tl.sigmoid(z) + tl.store(out_ptr_base + offs_m * stride_out_dim, out, mask=offs_m < dim) + + +def mamba3_state_update( + state, x, dt, A, B, C, + D=None, z=None, dt_bias=None, dt_softplus=False, + prev_Bx=None, beta=None, gamma=None, +): + """ + Mamba-3 single-step decode with fused Triton kernel. Supports SISO and MIMO. + + Args: + state: (batch, nheads, dim, dstate) + x: (batch, nheads, dim) for SISO, (batch, nheads, dim, mimo_rank) for MIMO + dt: (batch, nheads) + A: (nheads,) + B: (batch, ngroups, dstate) for SISO, (batch, ngroups, dstate, mimo_rank) for MIMO + C: (batch, ngroups, dstate) for SISO, (batch, ngroups, dstate, mimo_rank) for MIMO + D: (nheads,) or None -- NOT applied for MIMO (handled outside) + z: (batch, nheads, dim) or None -- NOT applied for MIMO (handled outside) + dt_bias: (nheads,) or None + dt_softplus: bool + prev_Bx: (batch, nheads, dim, dstate) or None + beta: (batch, nheads) or None + gamma: (batch, nheads) or None + Returns: + out: (batch, nheads, dim) for SISO, (batch, nheads, dim, mimo_rank) for MIMO + """ + is_mimo = x.dim() == 4 + mr = x.shape[3] if is_mimo else 0 + + if not HAS_TRITON: + return _mamba3_state_update_ref(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, prev_Bx, beta, gamma) + + batch, nheads, dim, dstate = state.shape + ngroups = B.shape[1] + assert nheads % ngroups == 0 + + if is_mimo: + out = torch.empty((batch, nheads, dim, mr), device=x.device, dtype=x.dtype) + else: + out = torch.empty((batch, nheads, dim), device=x.device, dtype=x.dtype) + + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + BLOCK_SIZE_M = 8 if dstate <= 32 else (4 if dstate <= 128 else 2) + + with torch.cuda.device(x.device.index): + _mamba3_state_update_kernel[grid]( + state, x, dt, dt_bias, A, B, C, D, z, out, + prev_Bx, beta, gamma, + batch, nheads, dim, dstate, nheads // ngroups, mr, + # state strides + state.stride(0), state.stride(1), state.stride(2), state.stride(3), + # x strides + x.stride(0), x.stride(1), x.stride(2), + # dt strides + dt.stride(0), dt.stride(1), + # A stride + A.stride(0), + # B strides + B.stride(0), B.stride(1), B.stride(2), + # C strides + C.stride(0), C.stride(1), C.stride(2), + # D stride + D.stride(0) if D is not None else 0, + # z strides + *((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)), + # out strides + out.stride(0), out.stride(1), out.stride(2), + # prev_Bx strides + *((prev_Bx.stride(0), prev_Bx.stride(1), prev_Bx.stride(2), prev_Bx.stride(3)) + if prev_Bx is not None else (0, 0, 0, 0)), + # beta strides + *((beta.stride(0), beta.stride(1)) if beta is not None else (0, 0)), + # gamma strides + *((gamma.stride(0), gamma.stride(1)) if gamma is not None else (0, 0)), + # MIMO strides + x.stride(3) if is_mimo else 0, + B.stride(3) if is_mimo else 0, + C.stride(3) if is_mimo else 0, + out.stride(3) if is_mimo else 0, + # Meta + dt_softplus, + BLOCK_SIZE_M, + num_warps=4, + ) + return out + + +def _mamba3_state_update_ref(state, x, dt, A, B, C, D=None, z=None, + dt_bias=None, dt_softplus=False, + prev_Bx=None, beta=None, gamma=None): + """Reference PyTorch implementation for Mamba-3 decode step. Supports SISO and MIMO.""" + batch, nheads, dim, dstate = state.shape + ngroups = B.shape[1] + is_mimo = x.dim() == 4 + + if dt_bias is not None: + dt = dt + dt_bias + if dt_softplus: + dt = F.softplus(dt) + + dA = torch.exp(dt * A) # (batch, nheads) + + # Input scaling: gamma for trapezoidal, dt for Euler + input_scale = gamma if gamma is not None else dt + + # Expand B, C from groups to heads + nheads_per_group = nheads // ngroups + if is_mimo: + B_exp = repeat(B, "b g n r -> b (g h) n r", h=nheads_per_group) + C_exp = repeat(C, "b g n r -> b (g h) n r", h=nheads_per_group) + else: + B_exp = repeat(B, "b g n -> b (g h) n", h=nheads_per_group) + C_exp = repeat(C, "b g n -> b (g h) n", h=nheads_per_group) + + # Compute Bx (unscaled): sum over rank for MIMO + if is_mimo: + Bx = torch.einsum("bhpr,bhnr->bhpn", x.float(), B_exp.float()) + else: + Bx = torch.einsum("bhp,bhn->bhpn", x.float(), B_exp.float()) + + # Scaled dBx + dBx = rearrange(input_scale, "b h -> b h 1 1") * Bx + + # State update + state.copy_(state * rearrange(dA, "b h -> b h 1 1") + dBx) + + # Trapezoidal lookback + if prev_Bx is not None and beta is not None: + state.add_(rearrange(beta, "b h -> b h 1 1") * prev_Bx) + + # Store current raw Bx for next step (unscaled) + if prev_Bx is not None: + prev_Bx.copy_(Bx) + + # Output + if is_mimo: + out = torch.einsum("bhpn,bhnr->bhpr", state.to(C_exp.dtype), C_exp) + # D and z not applied for MIMO (handled outside) + else: + out = torch.einsum("bhpn,bhn->bhp", state.to(C_exp.dtype), C_exp) + if D is not None: + out = out + x * rearrange(D, "h -> 1 h 1") + if z is not None: + out = out * F.silu(z) + + return out diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py index d6aa53c96..359ad4d53 100644 --- a/mamba_ssm/ops/triton/ssd_state_passing.py +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -68,7 +68,8 @@ def _state_passing_fwd_kernel( states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - seq_idx = 0 + if HAS_SEQ_IDX: + seq_idx = tl.load(seq_idx_ptr).to(tl.int64) for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) diff --git a/tests/test_mamba3_cpu.py b/tests/test_mamba3_cpu.py new file mode 100644 index 000000000..692a27458 --- /dev/null +++ b/tests/test_mamba3_cpu.py @@ -0,0 +1,974 @@ +"""CPU-only tests for Mamba-3 implementation. + +Tests numerical consistency between: +1. Step-by-step recurrence vs chunked parallel SSD +2. Mamba3Simple vs Mamba3 (reference vs full) +3. Prefill (forward) vs decode (step) consistency +4. SISO and MIMO variants + +No CUDA required — all tests run on CPU with PyTorch fallbacks. +""" + +import pytest +import sys +import os +import importlib +from unittest.mock import MagicMock +from types import ModuleType + +# ============================================================================ +# Heavy-duty mocking to allow CPU-only import of mamba3 modules +# without triton/CUDA. We intercept the entire triton ecosystem. +# ============================================================================ + +class _TritonMock(ModuleType): + """A mock that acts as a module and supports attribute access.""" + def __init__(self, name="triton"): + super().__init__(name) + self.__version__ = "3.0.0" + + def __getattr__(self, name): + if name.startswith("__") and name.endswith("__"): + raise AttributeError(name) + # Return a callable mock for decorators, functions, etc. + mock = MagicMock() + setattr(self, name, mock) + return mock + + def jit(self, fn=None, **kwargs): + return fn if fn else (lambda f: f) + + def heuristics(self, mapping): + return lambda fn: fn + + def autotune(self, **kwargs): + return lambda fn: fn + + def next_power_of_2(self, x): + return 1 << (x - 1).bit_length() if x > 0 else 1 + + def cdiv(self, a, b): + return (a + b - 1) // b + + +# Install triton mock +_tmock = _TritonMock("triton") +_tl_mock = _TritonMock("triton.language") +_tl_mock.constexpr = type # tl.constexpr used in type hints + +# Create mocks for all Triton-dependent and CUDA modules +_mods_to_mock = [ + "triton", "triton.language", + "causal_conv1d", "causal_conv1d.causal_conv1d_varlen", + "flash_attn", "flash_attn.ops", "flash_attn.ops.triton", + "flash_attn.ops.triton.layer_norm", + "selective_scan_cuda", "causal_conv1d_cuda", +] +for m in _mods_to_mock: + if m == "triton": + sys.modules[m] = _tmock + elif m == "triton.language": + sys.modules[m] = _tl_mock + elif m not in sys.modules: + sys.modules[m] = MagicMock() + +# Pre-mock all mamba_ssm.ops.triton modules that use triton at module-level +# to prevent import errors from triton autotuning/config pruning +_triton_ops_to_mock = [ + "mamba_ssm.ops.triton.layer_norm", + "mamba_ssm.ops.triton.layernorm_gated", + "mamba_ssm.ops.triton.selective_state_update", + "mamba_ssm.ops.triton.ssd_combined", + "mamba_ssm.ops.triton.softplus", + "mamba_ssm.ops.selective_scan_interface", +] +for m in _triton_ops_to_mock: + _mod_mock = MagicMock() + # Provide commonly needed attributes + _mod_mock.RMSNorm = None + _mod_mock._layer_norm_fwd = None + _mod_mock.rms_norm_fn = None + _mod_mock.layer_norm_fn = None + _mod_mock.selective_scan_fn = None + _mod_mock.mamba_inner_fn = None + _mod_mock.mamba_chunk_scan_combined = None + _mod_mock.selective_state_update = None + sys.modules[m] = _mod_mock + +# Also mock heavy deps that come through mamba_ssm.__init__ -> generation -> transformers +_tf_mock = MagicMock() +for m in [ + "transformers", "transformers.generation", "transformers.utils", + "transformers.utils.hub", +]: + sys.modules[m] = _tf_mock + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class _RMSNormGatedCPU(nn.Module): + """CPU fallback for the Triton RMSNormGated kernel.""" + def __init__(self, d, eps=1e-5, norm_before_gate=False, group_size=0, + device=None, dtype=None): + super().__init__() + self.eps = eps + self.norm_before_gate = norm_before_gate + self.weight = nn.Parameter(torch.ones(d, device=device, dtype=dtype)) + self.bias = None + + def forward(self, x, z=None): + # RMSNorm + rms = torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) + x_normed = (x.float() / rms).to(x.dtype) * self.weight + if z is not None: + x_normed = x_normed * F.silu(z) + return x_normed + + +# Inject CPU fallback for RMSNormGated into the mock +_layernorm_mock = sys.modules["mamba_ssm.ops.triton.layernorm_gated"] +_layernorm_mock.RMSNorm = _RMSNormGatedCPU + +# Also set it in the already-imported layer_norm module +_layer_norm_mock = sys.modules["mamba_ssm.ops.triton.layer_norm"] +_layer_norm_mock.RMSNorm = nn.RMSNorm +_layer_norm_mock.layer_norm_fn = None +_layer_norm_mock.rms_norm_fn = None + +# Now we can import mamba_ssm modules (Triton ops will gracefully degrade) +from mamba_ssm.modules.mamba3 import apply_rotary_emb, compute_cumulative_rotary +from mamba_ssm.modules.mamba3_simple import Mamba3Simple +from mamba_ssm.ops.triton.mamba3_ssd import mamba3_ssd_chunked + + +DEVICE = "cpu" +DTYPE = torch.float32 # Use float32 on CPU for numerical stability + + +# ============================================================================ +# Helpers +# ============================================================================ + +def make_mamba3_simple(d_model=32, d_state=16, expand=2, headdim=16, ngroups=1, + use_rope=True, use_trapezoidal=True, mimo_rank=0, **kwargs): + """Create a Mamba3Simple on CPU.""" + return Mamba3Simple( + d_model=d_model, d_state=d_state, expand=expand, headdim=headdim, + ngroups=ngroups, use_rope=use_rope, use_trapezoidal=use_trapezoidal, + use_bc_norm=kwargs.pop("use_bc_norm", True), + use_bc_bias=kwargs.pop("use_bc_bias", True), + mimo_rank=mimo_rank, + device=DEVICE, dtype=DTYPE, **kwargs, + ).eval() + + +def make_input(batch, seqlen, d_model): + return torch.randn(batch, seqlen, d_model, device=DEVICE, dtype=DTYPE) + + +# ============================================================================ +# Test: Mamba3Simple forward (smoke test) +# ============================================================================ + +class TestMamba3SimpleSmoke: + """Basic smoke tests that Mamba3Simple runs without error.""" + + def test_siso_forward(self): + model = make_mamba3_simple() + u = make_input(2, 64, 32) + y = model(u) + assert y.shape == u.shape + + def test_mimo_forward(self): + model = make_mamba3_simple(mimo_rank=2) + u = make_input(2, 64, 32) + y = model(u) + assert y.shape == u.shape + + def test_no_rope(self): + model = make_mamba3_simple(use_rope=False) + u = make_input(2, 64, 32) + y = model(u) + assert y.shape == u.shape + + def test_no_trapezoidal(self): + model = make_mamba3_simple(use_trapezoidal=False) + u = make_input(2, 64, 32) + y = model(u) + assert y.shape == u.shape + + def test_euler_fallback(self): + """No trapezoidal, no rope — pure Euler discretization.""" + model = make_mamba3_simple(use_rope=False, use_trapezoidal=False) + u = make_input(2, 64, 32) + y = model(u) + assert y.shape == u.shape + + +# ============================================================================ +# Test: Chunked SSD vs Step-by-step Recurrence +# ============================================================================ + +class TestChunkedVsRecurrence: + """Verify that the chunked parallel SSD matches step-by-step recurrence.""" + + def _run_comparison(self, use_rope=True, use_trapezoidal=True, mimo_rank=0): + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 64, 4, 8, 16 + chunk_size = 16 + + # Generate inputs + X = torch.randn(batch, seqlen, nheads, headdim, device=DEVICE, dtype=DTYPE) + dt = torch.rand(batch, seqlen, nheads, device=DEVICE, dtype=DTYPE) * 0.1 + 0.01 + A = -torch.rand(nheads, device=DEVICE, dtype=DTYPE) * 5 - 1 # negative + B = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE, dtype=DTYPE) + C = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE, dtype=DTYPE) + + theta = None + if use_rope: + theta = torch.randn(batch, seqlen, nheads, dstate // 2, device=DEVICE, dtype=DTYPE) * 0.1 + + lam = None + gamma = dt.clone() + beta = None + if use_trapezoidal: + lam = torch.sigmoid(torch.randn(batch, seqlen, nheads, device=DEVICE, dtype=DTYPE)) + gamma = lam * dt + beta = (1 - lam) * dt * torch.exp(dt * A.view(1, 1, nheads)) + + if mimo_rank > 0: + R = mimo_rank + X = torch.randn(batch, seqlen, nheads, headdim, R, device=DEVICE, dtype=DTYPE) + B = torch.randn(batch, seqlen, nheads, dstate, R, device=DEVICE, dtype=DTYPE) + C = torch.randn(batch, seqlen, nheads, dstate, R, device=DEVICE, dtype=DTYPE) + + # Apply RoPE to B, C (both paths need rotated B, C) + if theta is not None: + theta_cumsum = torch.cumsum(theta, dim=1) + cos_t, sin_t = compute_cumulative_rotary(theta_cumsum, dstate) + if mimo_rank > 0: + for r in range(mimo_rank): + B[:, :, :, :, r] = apply_rotary_emb(B[:, :, :, :, r], cos_t, sin_t) + C[:, :, :, :, r] = apply_rotary_emb(C[:, :, :, :, r], cos_t, sin_t) + else: + B = apply_rotary_emb(B, cos_t, sin_t) + C = apply_rotary_emb(C, cos_t, sin_t) + + # --- Chunked path --- + Y_chunked = mamba3_ssd_chunked( + X, dt, A, B, C, + block_len=chunk_size, + gamma=gamma, + beta=beta, + ) + + # --- Step-by-step recurrence --- + is_mimo = mimo_rank > 0 + alpha = torch.exp(dt.unsqueeze(-1) * A.view(1, 1, nheads, 1)) + h = torch.zeros(batch, nheads, headdim, dstate, device=DEVICE, dtype=torch.float32) + ys = [] + prev_Bx = None + + for t in range(seqlen): + x_t = X[:, t] + B_t = B[:, t] + C_t = C[:, t] + + if is_mimo: + Bx_t = torch.einsum("bhpr,bhnr->bhpn", x_t.float(), B_t.float()) + else: + Bx_t = torch.einsum("bhp,bhn->bhpn", x_t.float(), B_t.float()) + + alpha_t = alpha[:, t].unsqueeze(-1) + gamma_t = gamma[:, t].unsqueeze(-1).unsqueeze(-1) + h = alpha_t * h + gamma_t * Bx_t + + if beta is not None and prev_Bx is not None: + beta_t = beta[:, t].unsqueeze(-1).unsqueeze(-1) + h = h + beta_t * prev_Bx + + prev_Bx = Bx_t + + if is_mimo: + y_t = torch.einsum("bhpn,bhnr->bhpr", h.to(DTYPE), C_t) + else: + y_t = torch.einsum("bhpn,bhn->bhp", h.to(DTYPE), C_t) + ys.append(y_t) + + Y_recurrence = torch.stack(ys, dim=1) + + # Compare + torch.testing.assert_close(Y_chunked, Y_recurrence, atol=1e-4, rtol=1e-3) + + def test_siso_euler(self): + self._run_comparison(use_rope=False, use_trapezoidal=False) + + def test_siso_trapezoidal(self): + self._run_comparison(use_rope=False, use_trapezoidal=True) + + def test_siso_rope(self): + self._run_comparison(use_rope=True, use_trapezoidal=False) + + def test_siso_full(self): + self._run_comparison(use_rope=True, use_trapezoidal=True) + + def test_mimo_euler(self): + self._run_comparison(use_rope=False, use_trapezoidal=False, mimo_rank=2) + + def test_mimo_trapezoidal(self): + self._run_comparison(use_rope=False, use_trapezoidal=True, mimo_rank=2) + + def test_mimo_rope(self): + self._run_comparison(use_rope=True, use_trapezoidal=False, mimo_rank=2) + + def test_mimo_full(self): + self._run_comparison(use_rope=True, use_trapezoidal=True, mimo_rank=2) + + +# ============================================================================ +# Test: Mamba3Simple recurrence consistency +# ============================================================================ + +class TestMamba3SimpleRecurrence: + """Test that Mamba3Simple's internal recurrence produces consistent outputs.""" + + def test_siso_gradient_flows(self): + model = make_mamba3_simple() + u = make_input(2, 32, 32) + u.requires_grad_(True) + y = model(u) + loss = y.sum() + loss.backward() + assert u.grad is not None + assert not torch.isnan(u.grad).any() + + def test_mimo_gradient_flows(self): + model = make_mamba3_simple(mimo_rank=2) + u = make_input(2, 32, 32) + u.requires_grad_(True) + y = model(u) + loss = y.sum() + loss.backward() + assert u.grad is not None + assert not torch.isnan(u.grad).any() + + +# ============================================================================ +# Test: RoPE correctness +# ============================================================================ + +class TestRoPE: + """Test rotary embedding functions.""" + + def test_apply_rotary_emb_identity(self): + """cos=1, sin=0 should be identity.""" + x = torch.randn(2, 4, 8) + cos = torch.ones(2, 4, 4) + sin = torch.zeros(2, 4, 4) + out = apply_rotary_emb(x, cos, sin) + torch.testing.assert_close(out, x) + + def test_apply_rotary_emb_rotation(self): + """90-degree rotation: cos=0, sin=1 swaps halves.""" + x = torch.randn(2, 4, 8) + cos = torch.zeros(2, 4, 4) + sin = torch.ones(2, 4, 4) + out = apply_rotary_emb(x, cos, sin) + x1, x2 = x[..., :4], x[..., 4:] + expected = torch.cat([-x2, x1], dim=-1) + torch.testing.assert_close(out, expected) + + def test_cumulative_rotary(self): + theta = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]]) # (1, 3, 2) + # Add head dim + theta = theta.unsqueeze(2) # (1, 3, 1, 2) + cumsum = torch.cumsum(theta, dim=1) + cos, sin = compute_cumulative_rotary(cumsum, 4) + assert cos.shape == (1, 3, 1, 2) + # Verify cumulative property + torch.testing.assert_close(cos[:, 0], torch.cos(theta[:, 0])) + torch.testing.assert_close(cos[:, 1], torch.cos(theta[:, 0] + theta[:, 1])) + + +# ============================================================================ +# Test: Trapezoidal discretization correctness +# ============================================================================ + +class TestTrapezoidalDiscretization: + """Test that trapezoidal discretization produces correct recurrence.""" + + def test_euler_equivalence(self): + """When lambda=1 exactly, trapezoidal should equal Euler (beta=0, gamma=dt).""" + torch.manual_seed(123) + batch, seqlen, d_model = 1, 16, 32 + model = make_mamba3_simple(d_model=d_model, use_rope=False, use_trapezoidal=True) + u = make_input(batch, seqlen, d_model) + + # Run recurrence manually with forced lambda=1 + with torch.no_grad(): + proj = model.in_proj(u) + splits = torch.split(proj, model._split_sizes, dim=-1) + z, x_raw, B_raw, C_raw, dt_raw = splits[0], splits[1], splits[2], splits[3], splits[4] + # Force lambda to 1.0 (sigmoid(large) ≈ 1) + lam_forced = torch.ones(batch, seqlen, model.nheads) + + dt = F.softplus(dt_raw + model.dt_bias) + A = -torch.exp(model.A_log.float()) + alpha = torch.exp(dt.unsqueeze(-1) * A.view(1, 1, model.nheads, 1)) + gamma = lam_forced * dt # = dt (Euler) + beta = (1 - lam_forced) * dt * torch.exp(dt * A.view(1, 1, model.nheads)) # = 0 + + # beta should be zero when lambda=1 + assert torch.allclose(beta, torch.zeros_like(beta), atol=1e-7) + # gamma should equal dt when lambda=1 + assert torch.allclose(gamma, dt, atol=1e-7) + + # Also verify the model produces finite output + with torch.no_grad(): + y = model(u) + assert torch.isfinite(y).all() + + def test_trapezoidal_vs_euler_different(self): + """Trapezoidal and Euler models should produce numerically different outputs.""" + torch.manual_seed(42) + d_model = 32 + u = make_input(1, 32, d_model) + + model_trap = make_mamba3_simple(d_model=d_model, use_rope=False, use_trapezoidal=True) + model_euler = make_mamba3_simple(d_model=d_model, use_rope=False, use_trapezoidal=False) + + # Structural check: trapezoidal has extra lambda projection + assert model_trap.in_proj.weight.shape[0] != model_euler.in_proj.weight.shape[0] + + # Copy shared weights so only trapezoidal vs Euler differs + with torch.no_grad(): + euler_dim = model_euler.in_proj.weight.shape[0] + # Copy the shared prefix of in_proj (z, x, B, C, dt — everything except lambda) + model_euler.in_proj.weight.copy_(model_trap.in_proj.weight[:euler_dim]) + if model_euler.in_proj.bias is not None: + model_euler.in_proj.bias.copy_(model_trap.in_proj.bias[:euler_dim]) + model_euler.out_proj.weight.copy_(model_trap.out_proj.weight) + model_euler.A_log.copy_(model_trap.A_log) + model_euler.D.copy_(model_trap.D) + model_euler.dt_bias.copy_(model_trap.dt_bias) + model_euler.B_norm.weight.copy_(model_trap.B_norm.weight) + model_euler.C_norm.weight.copy_(model_trap.C_norm.weight) + model_euler.B_bias.copy_(model_trap.B_bias) + model_euler.C_bias.copy_(model_trap.C_bias) + model_euler.norm.weight.copy_(model_trap.norm.weight) + + y_trap = model_trap(u) + y_euler = model_euler(u) + + # They should differ (trapezoidal uses lookback term) + assert not torch.allclose(y_trap, y_euler, atol=1e-5), \ + "Trapezoidal and Euler outputs should differ" + + +# ============================================================================ +# Test: MIMO output projection +# ============================================================================ + +class TestMIMOOutputProjection: + """Test that MIMO output projection is a learned linear, not sum.""" + + def test_mimo_out_proj_exists(self): + model = make_mamba3_simple(mimo_rank=4) + assert hasattr(model, 'mimo_out_proj') + assert isinstance(model.mimo_out_proj, nn.Linear) + assert model.mimo_out_proj.in_features == model.headdim * 4 + assert model.mimo_out_proj.out_features == model.headdim + + def test_mimo_out_proj_not_identity(self): + """Verify mimo_out_proj affects output (not bypassed).""" + torch.manual_seed(42) + model = make_mamba3_simple(mimo_rank=2) + u = make_input(1, 32, 32) + + with torch.no_grad(): + y1 = model(u).clone() + # Perturb the projection + model.mimo_out_proj.weight.data += 0.5 + y2 = model(u) + + assert not torch.allclose(y1, y2, atol=1e-6) + + def test_siso_no_mimo_proj(self): + """SISO model should NOT have mimo_out_proj.""" + model = make_mamba3_simple(mimo_rank=0) + assert not hasattr(model, 'mimo_out_proj') + + +# ============================================================================ +# Test: Chunked kernel per-rank MIMO output +# ============================================================================ + +class TestChunkedMIMOPerRank: + """Test that chunked kernel returns per-rank output for MIMO.""" + + def test_mimo_output_shape(self): + """Chunked kernel should return (B, L, H, P, R) for MIMO.""" + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate, R = 2, 32, 4, 8, 16, 2 + X = torch.randn(batch, seqlen, nheads, headdim, R) + dt = torch.rand(batch, seqlen, nheads) * 0.1 + 0.01 + A = -torch.rand(nheads) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate, R) + C = torch.randn(batch, seqlen, nheads, dstate, R) + + Y = mamba3_ssd_chunked(X, dt, A, B, C, block_len=16) + assert Y.shape == (batch, seqlen, nheads, headdim, R) + + def test_siso_output_shape(self): + """Chunked kernel should return (B, L, H, P) for SISO.""" + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 32, 4, 8, 16 + X = torch.randn(batch, seqlen, nheads, headdim) + dt = torch.rand(batch, seqlen, nheads) * 0.1 + 0.01 + A = -torch.rand(nheads) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate) + C = torch.randn(batch, seqlen, nheads, dstate) + + Y = mamba3_ssd_chunked(X, dt, A, B, C, block_len=16) + assert Y.shape == (batch, seqlen, nheads, headdim) + + +# ============================================================================ +# Test: BC Bias initialization +# ============================================================================ + +class TestBCBias: + """Test BC bias is initialized to ones per paper Table 9a.""" + + def test_bc_bias_init_ones(self): + model = make_mamba3_simple() + torch.testing.assert_close(model.B_bias.data, torch.ones_like(model.B_bias.data)) + torch.testing.assert_close(model.C_bias.data, torch.ones_like(model.C_bias.data)) + + def test_bc_bias_shape(self): + model = make_mamba3_simple(d_state=16) + nheads = model.nheads + assert model.B_bias.shape == (nheads, 16) + assert model.C_bias.shape == (nheads, 16) + + +# ============================================================================ +# Test: No causal convolution +# ============================================================================ + +class TestNoConvolution: + """Verify Mamba-3 has no causal convolution (paper Section 3.4).""" + + def test_no_conv1d(self): + model = make_mamba3_simple() + for name, module in model.named_modules(): + assert not isinstance(module, nn.Conv1d), f"Found Conv1d: {name}" + + +# ============================================================================ +# Test: Final states returned correctly +# ============================================================================ + +class TestFinalStates: + """Test that final states are returned for use in generation.""" + + def test_chunked_final_states(self): + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 32, 4, 8, 16 + X = torch.randn(batch, seqlen, nheads, headdim) + dt = torch.rand(batch, seqlen, nheads) * 0.1 + 0.01 + A = -torch.rand(nheads) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate) + C = torch.randn(batch, seqlen, nheads, dstate) + + Y, final_state = mamba3_ssd_chunked( + X, dt, A, B, C, block_len=16, return_final_states=True, + ) + assert Y.shape == (batch, seqlen, nheads, headdim) + assert final_state.shape == (batch, nheads, headdim, dstate) + + def test_chunked_final_states_mimo(self): + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate, R = 2, 32, 4, 8, 16, 2 + X = torch.randn(batch, seqlen, nheads, headdim, R) + dt = torch.rand(batch, seqlen, nheads) * 0.1 + 0.01 + A = -torch.rand(nheads) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate, R) + C = torch.randn(batch, seqlen, nheads, dstate, R) + + Y, final_state = mamba3_ssd_chunked( + X, dt, A, B, C, block_len=16, return_final_states=True, + ) + assert Y.shape == (batch, seqlen, nheads, headdim, R) + assert final_state.shape == (batch, nheads, headdim, dstate) + + +# ============================================================================ +# Test: BCNorm correctness for MIMO +# ============================================================================ + +class TestBCNormMIMO: + """Verify BCNorm normalizes each rank's d_state vector independently.""" + + def test_bcnorm_mimo_normalizes_per_rank(self): + """Each rank's d_state vector should be independently normalized.""" + torch.manual_seed(42) + model = make_mamba3_simple(d_state=8, mimo_rank=2) + + # Create B with known values: rank 0 has large values, rank 1 has small + B = torch.zeros(1, 1, 1, 8, 2) # (b, l, g, d_state, mimo_rank) + B[..., 0] = 10.0 # rank 0: all 10s + B[..., 1] = 0.1 # rank 1: all 0.1s + + orig = B.shape + # Correct normalization: each rank independently + B_r0 = model.B_norm(B[..., 0].reshape(-1, 8)) + B_r1 = model.B_norm(B[..., 1].reshape(-1, 8)) + + # After RMSNorm, both should have similar magnitude (normalized) + rms_r0 = B_r0.float().pow(2).mean().sqrt() + rms_r1 = B_r1.float().pow(2).mean().sqrt() + # RMSNorm should bring both to ~1.0 (weight=1) + torch.testing.assert_close(rms_r0, rms_r1, atol=0.1, rtol=0.1) + + def test_bcnorm_mimo_vs_siso_consistency(self): + """BCNorm on MIMO with rank=1 should match SISO BCNorm.""" + torch.manual_seed(42) + d_state = 16 + model = make_mamba3_simple(d_state=d_state, mimo_rank=0) + + B_siso = torch.randn(2, 4, 1, d_state) # (b, l, g, d_state) + B_mimo = B_siso.unsqueeze(-1) # (b, l, g, d_state, 1) — rank=1 + + # Apply SISO norm + orig_s = B_siso.shape + B_siso_normed = model.B_norm(B_siso.reshape(-1, d_state)).reshape(orig_s) + + # Apply MIMO norm (correct path: movedim before reshape) + orig_m = B_mimo.shape + B_mimo_normed = model.B_norm( + B_mimo.movedim(-1, -2).reshape(-1, d_state) + ).reshape(*orig_m[:-2], orig_m[-1], orig_m[-2]).movedim(-1, -2) + + torch.testing.assert_close(B_siso_normed, B_mimo_normed.squeeze(-1)) + + +# ============================================================================ +# Test: Mamba3Simple MIMO + BCNorm + RoPE full integration +# ============================================================================ + +class TestMIMOFullIntegration: + """Test MIMO with all features enabled produces valid gradients and output.""" + + def test_mimo_bcnorm_rope_gradient(self): + """Full MIMO + BCNorm + RoPE + trapezoidal should have clean gradients.""" + torch.manual_seed(42) + model = make_mamba3_simple(mimo_rank=2, use_rope=True, use_trapezoidal=True) + u = make_input(2, 32, 32) + u.requires_grad_(True) + y = model(u) + loss = y.sum() + loss.backward() + assert u.grad is not None + assert torch.isfinite(u.grad).all() + # Check all model params got grads + for name, p in model.named_parameters(): + if p.requires_grad: + assert p.grad is not None, f"No grad for {name}" + assert torch.isfinite(p.grad).all(), f"Non-finite grad for {name}" + + def test_mimo_deterministic(self): + """Same input should produce same output (no stochastic ops).""" + torch.manual_seed(42) + model = make_mamba3_simple(mimo_rank=2) + u = make_input(1, 16, 32) + with torch.no_grad(): + y1 = model(u).clone() + y2 = model(u) + torch.testing.assert_close(y1, y2) + + def test_ngroups_greater_than_one(self): + """Test with ngroups > 1 (multi-value attention head structure).""" + # ngroups=2 means 2 groups sharing B,C, each expanded to nheads/2 heads + model = make_mamba3_simple( + d_model=64, d_state=16, expand=2, headdim=16, ngroups=2, + mimo_rank=0, use_rope=True, + ) + u = make_input(2, 32, 64) + y = model(u) + assert y.shape == u.shape + + def test_ngroups_mimo(self): + """Test MIMO with ngroups > 1.""" + model = make_mamba3_simple( + d_model=64, d_state=16, expand=2, headdim=16, ngroups=2, + mimo_rank=2, use_rope=True, + ) + u = make_input(2, 32, 64) + y = model(u) + assert y.shape == u.shape + + +# ============================================================================ +# Test: Chunked vs Recurrence with BCNorm + bias (end-to-end Mamba3Simple) +# ============================================================================ + +class TestEndToEndConsistency: + """Test that Mamba3Simple produces finite, reasonable outputs.""" + + def test_output_scale(self): + """Output should not explode or vanish for random input.""" + torch.manual_seed(42) + model = make_mamba3_simple() + u = torch.randn(2, 64, 32) + with torch.no_grad(): + y = model(u) + # Output should be roughly same scale as input (order of magnitude) + assert y.abs().mean() > 1e-4, "Output too small — possible vanishing" + assert y.abs().mean() < 100, "Output too large — possible explosion" + + def test_different_seqlens(self): + """Model should handle various sequence lengths (multiples of chunk_size not required).""" + model = make_mamba3_simple(chunk_size=16) + for seqlen in [16, 32, 48, 64]: + u = make_input(1, seqlen, 32) + y = model(u) + assert y.shape == (1, seqlen, 32) + + +# ============================================================================ +# Test: seq_idx (packed multi-document training) +# ============================================================================ + +class TestSeqIdx: + """Test that seq_idx correctly prevents cross-document information leakage.""" + + def test_seq_idx_simple_isolation(self): + """Two documents packed together should produce same output as running them separately.""" + torch.manual_seed(42) + model = make_mamba3_simple( + d_model=32, d_state=16, expand=2, headdim=16, + use_rope=False, use_trapezoidal=False, # simplify for clean comparison + use_bc_bias=False, use_bc_norm=False, + chunk_size=8, + ) + seqlen_a, seqlen_b = 8, 8 + + # Create two separate inputs + u_a = torch.randn(1, seqlen_a, 32) + u_b = torch.randn(1, seqlen_b, 32) + + # Run separately + with torch.no_grad(): + y_a_sep = model(u_a) + y_b_sep = model(u_b) + + # Pack together with seq_idx + u_packed = torch.cat([u_a, u_b], dim=1) # (1, 16, 32) + seq_idx = torch.cat([ + torch.zeros(1, seqlen_a, dtype=torch.long), + torch.ones(1, seqlen_b, dtype=torch.long), + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + # First document should match + torch.testing.assert_close(y_packed[:, :seqlen_a], y_a_sep, atol=1e-5, rtol=1e-5) + # Second document should match (starts fresh) + torch.testing.assert_close(y_packed[:, seqlen_a:], y_b_sep, atol=1e-5, rtol=1e-5) + + def test_seq_idx_with_trapezoidal(self): + """seq_idx should work correctly with trapezoidal discretization.""" + torch.manual_seed(42) + model = make_mamba3_simple( + d_model=32, d_state=16, expand=2, headdim=16, + use_rope=False, use_trapezoidal=True, + use_bc_bias=True, use_bc_norm=True, + chunk_size=8, + ) + seqlen_a, seqlen_b = 8, 8 + u_a = torch.randn(1, seqlen_a, 32) + u_b = torch.randn(1, seqlen_b, 32) + + with torch.no_grad(): + y_a_sep = model(u_a) + y_b_sep = model(u_b) + + u_packed = torch.cat([u_a, u_b], dim=1) + seq_idx = torch.cat([ + torch.zeros(1, seqlen_a, dtype=torch.long), + torch.ones(1, seqlen_b, dtype=torch.long), + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + torch.testing.assert_close(y_packed[:, :seqlen_a], y_a_sep, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(y_packed[:, seqlen_a:], y_b_sep, atol=1e-5, rtol=1e-5) + + def test_seq_idx_no_leakage_gradient(self): + """Gradient should not flow across document boundaries.""" + torch.manual_seed(42) + model = make_mamba3_simple( + d_model=32, d_state=16, expand=2, headdim=16, + use_rope=False, use_trapezoidal=False, + use_bc_bias=False, use_bc_norm=False, + chunk_size=8, + ) + u_packed = torch.randn(1, 16, 32, requires_grad=True) + seq_idx = torch.cat([ + torch.zeros(1, 8, dtype=torch.long), + torch.ones(1, 8, dtype=torch.long), + ], dim=1) + + y = model(u_packed, seq_idx=seq_idx) + # Backprop from second document only + loss = y[:, 8:].sum() + loss.backward() + + # Gradient on first document's input should be zero (no leakage) + grad_doc1 = u_packed.grad[:, :8] + assert grad_doc1.abs().max() < 1e-6, \ + f"Gradient leaked across docs: max={grad_doc1.abs().max()}" + + def test_seq_idx_mimo(self): + """seq_idx should work with MIMO.""" + torch.manual_seed(42) + model = make_mamba3_simple( + d_model=32, d_state=16, expand=2, headdim=16, + use_rope=False, use_trapezoidal=True, + mimo_rank=2, chunk_size=8, + ) + u_a = torch.randn(1, 8, 32) + u_b = torch.randn(1, 8, 32) + + with torch.no_grad(): + y_a_sep = model(u_a) + y_b_sep = model(u_b) + + u_packed = torch.cat([u_a, u_b], dim=1) + seq_idx = torch.cat([ + torch.zeros(1, 8, dtype=torch.long), + torch.ones(1, 8, dtype=torch.long), + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + torch.testing.assert_close(y_packed[:, :8], y_a_sep, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(y_packed[:, 8:], y_b_sep, atol=1e-5, rtol=1e-5) + + def test_seq_idx_three_docs(self): + """Test with three documents packed together.""" + torch.manual_seed(42) + model = make_mamba3_simple( + d_model=32, d_state=16, expand=2, headdim=16, + use_rope=False, use_trapezoidal=True, + use_bc_bias=True, chunk_size=8, + ) + lens = [8, 8, 8] + us = [torch.randn(1, l, 32) for l in lens] + + with torch.no_grad(): + ys_sep = [model(u) for u in us] + + u_packed = torch.cat(us, dim=1) + seq_idx = torch.cat([ + torch.full((1, l), i, dtype=torch.long) for i, l in enumerate(lens) + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + offset = 0 + for i, l in enumerate(lens): + torch.testing.assert_close( + y_packed[:, offset:offset + l], ys_sep[i], + atol=1e-5, rtol=1e-5, + msg=f"Doc {i} mismatch", + ) + offset += l + + def test_seq_idx_uneven_docs_cross_chunk(self): + """Documents that don't align with chunk boundaries.""" + torch.manual_seed(42) + model = make_mamba3_simple( + d_model=32, d_state=16, expand=2, headdim=16, + use_rope=False, use_trapezoidal=True, + chunk_size=8, + ) + # Doc1: 5 tokens, Doc2: 11 tokens — boundary falls mid-chunk + u_a = torch.randn(1, 5, 32) + u_b = torch.randn(1, 11, 32) + + with torch.no_grad(): + y_a_sep = model(u_a) + y_b_sep = model(u_b) + + u_packed = torch.cat([u_a, u_b], dim=1) # (1, 16, 32) + seq_idx = torch.cat([ + torch.zeros(1, 5, dtype=torch.long), + torch.ones(1, 11, dtype=torch.long), + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + torch.testing.assert_close(y_packed[:, :5], y_a_sep, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(y_packed[:, 5:], y_b_sep, atol=1e-5, rtol=1e-5) + + +# ============================================================================ +# Test: use_mem_eff_path (gradient checkpointing) +# ============================================================================ + +class TestMemEffPath: + """Test determinism and gradient correctness of Mamba3Simple.""" + + def test_deterministic_forward_backward(self): + """Two identical Mamba3Simple models should produce identical outputs and gradients.""" + torch.manual_seed(42) + import copy + model1 = make_mamba3_simple(d_model=32, d_state=16, expand=2, headdim=16) + model1.train() + model2 = copy.deepcopy(model1) + + u = torch.randn(1, 16, 32) + + u1 = u.clone().requires_grad_(True) + y1 = model1(u1) + y1.sum().backward() + + u2 = u.clone().requires_grad_(True) + y2 = model2(u2) + y2.sum().backward() + + torch.testing.assert_close(y1, y2, atol=1e-6, rtol=1e-6) + torch.testing.assert_close(u1.grad, u2.grad, atol=1e-6, rtol=1e-6) + for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()): + if p1.grad is not None: + torch.testing.assert_close(p1.grad, p2.grad, atol=1e-6, rtol=1e-6, + msg=f"Grad mismatch for {n1}") + + def test_recurrence_matches_chunked_gradient(self): + """Recurrence and chunked paths should produce consistent gradients.""" + torch.manual_seed(42) + d_model, d_state, headdim = 32, 16, 16 + model = make_mamba3_simple(d_model=d_model, d_state=d_state, headdim=headdim, + use_rope=True, use_trapezoidal=True) + model.train() + u = torch.randn(1, 16, d_model, requires_grad=True) + y = model(u) + y.sum().backward() + # All parameter gradients should be finite + for name, p in model.named_parameters(): + assert p.grad is not None, f"No gradient for {name}" + assert torch.isfinite(p.grad).all(), f"Non-finite gradient for {name}" + assert torch.isfinite(u.grad).all() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_mamba3_gpu.py b/tests/test_mamba3_gpu.py new file mode 100644 index 000000000..2b9826479 --- /dev/null +++ b/tests/test_mamba3_gpu.py @@ -0,0 +1,717 @@ +"""GPU tests for Mamba-3 implementation. + +Tests on CUDA: +1. Chunked SSD (Triton) vs step-by-step recurrence +2. Prefill (forward) → decode (step) consistency +3. Mamba3 full module: forward + step +4. Gradient flow with mixed precision (bf16) +5. MIMO + RoPE + trapezoidal end-to-end +6. MambaLMHeadModel integration +""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +DEVICE = "cuda" + + +# ============================================================================ +# 1. Chunked SSD vs Recurrence (GPU, real Triton kernels) +# ============================================================================ + +class TestChunkedVsRecurrenceGPU: + """Verify chunked parallel SSD matches step-by-step recurrence on GPU.""" + + def _reference_recurrence(self, X, dt, A, B, C, gamma, beta=None): + """Step-by-step reference.""" + is_mimo = X.dim() == 5 + batch, seqlen, nheads, headdim = X.shape[:4] + dstate = B.shape[-2] if is_mimo else B.shape[-1] + + alpha = torch.exp(dt.unsqueeze(-1) * A.view(1, 1, nheads, 1)) + h = torch.zeros(batch, nheads, headdim, dstate, device=X.device, dtype=torch.float32) + ys = [] + prev_Bx = None + + for t in range(seqlen): + x_t = X[:, t] + B_t = B[:, t] + C_t = C[:, t] + + if is_mimo: + Bx_t = torch.einsum("bhpr,bhnr->bhpn", x_t.float(), B_t.float()) + else: + Bx_t = torch.einsum("bhp,bhn->bhpn", x_t.float(), B_t.float()) + + alpha_t = alpha[:, t].unsqueeze(-1) + gamma_t = gamma[:, t].unsqueeze(-1).unsqueeze(-1) + h = alpha_t * h + gamma_t * Bx_t + + if beta is not None and prev_Bx is not None: + beta_t = beta[:, t].unsqueeze(-1).unsqueeze(-1) + h = h + beta_t * prev_Bx + + prev_Bx = Bx_t + + if is_mimo: + y_t = torch.einsum("bhpn,bhnr->bhpr", h.to(X.dtype), C_t) + else: + y_t = torch.einsum("bhpn,bhn->bhp", h.to(X.dtype), C_t) + ys.append(y_t) + + return torch.stack(ys, dim=1) + + def _run_comparison(self, use_rope=True, use_trapezoidal=True, mimo_rank=0, dtype=torch.float32): + from mamba_ssm.modules.mamba3 import apply_rotary_emb, compute_cumulative_rotary + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_ssd_chunked + + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 128, 8, 16, 32 + chunk_size = 32 + + X = torch.randn(batch, seqlen, nheads, headdim, device=DEVICE, dtype=dtype) + dt = torch.rand(batch, seqlen, nheads, device=DEVICE, dtype=dtype) * 0.1 + 0.01 + A = -torch.rand(nheads, device=DEVICE, dtype=dtype) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE, dtype=dtype) + C = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE, dtype=dtype) + + theta = None + if use_rope: + theta = torch.randn(batch, seqlen, nheads, dstate // 2, device=DEVICE, dtype=dtype) * 0.1 + + lam = None + gamma = dt.clone() + beta = None + if use_trapezoidal: + lam = torch.sigmoid(torch.randn(batch, seqlen, nheads, device=DEVICE, dtype=dtype)) + gamma = lam * dt + beta = (1 - lam) * dt * torch.exp(dt * A.view(1, 1, nheads)) + + if mimo_rank > 0: + R = mimo_rank + X = torch.randn(batch, seqlen, nheads, headdim, R, device=DEVICE, dtype=dtype) + B = torch.randn(batch, seqlen, nheads, dstate, R, device=DEVICE, dtype=dtype) + C = torch.randn(batch, seqlen, nheads, dstate, R, device=DEVICE, dtype=dtype) + + # Apply RoPE to B, C for both paths + if theta is not None: + theta_cumsum = torch.cumsum(theta, dim=1) + cos_t, sin_t = compute_cumulative_rotary(theta_cumsum, dstate) + if mimo_rank > 0: + B_parts = [apply_rotary_emb(B[..., r], cos_t, sin_t) for r in range(mimo_rank)] + C_parts = [apply_rotary_emb(C[..., r], cos_t, sin_t) for r in range(mimo_rank)] + B = torch.stack(B_parts, dim=-1) + C = torch.stack(C_parts, dim=-1) + else: + B = apply_rotary_emb(B, cos_t, sin_t) + C = apply_rotary_emb(C, cos_t, sin_t) + + # Chunked + Y_chunked = mamba3_ssd_chunked(X, dt, A, B, C, block_len=chunk_size, gamma=gamma, beta=beta) + + # Reference + Y_ref = self._reference_recurrence(X, dt, A, B, C, gamma, beta) + + atol = 1e-3 if dtype == torch.float32 else 5e-2 + rtol = 1e-3 if dtype == torch.float32 else 5e-2 + torch.testing.assert_close(Y_chunked.float(), Y_ref.float(), atol=atol, rtol=rtol) + + def test_siso_euler_fp32(self): + self._run_comparison(use_rope=False, use_trapezoidal=False) + + def test_siso_full_fp32(self): + self._run_comparison(use_rope=True, use_trapezoidal=True) + + def test_mimo_full_fp32(self): + self._run_comparison(use_rope=True, use_trapezoidal=True, mimo_rank=2) + + def test_siso_full_bf16(self): + self._run_comparison(use_rope=True, use_trapezoidal=True, dtype=torch.bfloat16) + + def test_mimo_full_bf16(self): + self._run_comparison(use_rope=True, use_trapezoidal=True, mimo_rank=2, dtype=torch.bfloat16) + + +# ============================================================================ +# 2. Mamba3 full module forward +# ============================================================================ + +class TestMamba3ModuleGPU: + """Test the full Mamba3 module on GPU.""" + + def _make_model(self, mimo_rank=0, use_rope=True, use_trapezoidal=True, dtype=torch.float32): + from mamba_ssm.modules.mamba3 import Mamba3 + return Mamba3( + d_model=128, d_state=32, expand=2, headdim=32, + ngroups=1, use_rope=use_rope, use_trapezoidal=use_trapezoidal, + use_bc_norm=True, use_bc_bias=True, mimo_rank=mimo_rank, + chunk_size=64, layer_idx=0, device=DEVICE, dtype=dtype, + ).eval() + + def test_siso_forward(self): + model = self._make_model() + u = torch.randn(2, 128, 128, device=DEVICE) + with torch.no_grad(): + y = model(u) + assert y.shape == u.shape + assert torch.isfinite(y).all() + + def test_mimo_forward(self): + model = self._make_model(mimo_rank=4) + u = torch.randn(2, 128, 128, device=DEVICE) + with torch.no_grad(): + y = model(u) + assert y.shape == u.shape + assert torch.isfinite(y).all() + + def test_siso_forward_bf16(self): + model = self._make_model(dtype=torch.bfloat16) + u = torch.randn(2, 128, 128, device=DEVICE, dtype=torch.bfloat16) + with torch.no_grad(): + y = model(u) + assert y.shape == u.shape + assert torch.isfinite(y).all() + + def test_mimo_forward_bf16(self): + model = self._make_model(mimo_rank=4, dtype=torch.bfloat16) + u = torch.randn(2, 128, 128, device=DEVICE, dtype=torch.bfloat16) + with torch.no_grad(): + y = model(u) + assert y.shape == u.shape + assert torch.isfinite(y).all() + + def test_siso_gradient_bf16(self): + model = self._make_model(dtype=torch.bfloat16).train() + u = torch.randn(2, 64, 128, device=DEVICE, dtype=torch.bfloat16, requires_grad=True) + y = model(u) + y.sum().backward() + assert u.grad is not None + assert torch.isfinite(u.grad).all() + + def test_mimo_gradient_bf16(self): + model = self._make_model(mimo_rank=2, dtype=torch.bfloat16).train() + u = torch.randn(2, 64, 128, device=DEVICE, dtype=torch.bfloat16, requires_grad=True) + y = model(u) + y.sum().backward() + assert u.grad is not None + assert torch.isfinite(u.grad).all() + + +# ============================================================================ +# 3. Prefill → Decode consistency +# ============================================================================ + +class TestPrefillDecodeConsistency: + """Test that single-step decode matches last position of prefill.""" + + def _make_model(self, mimo_rank=0, dtype=torch.float32): + from mamba_ssm.modules.mamba3 import Mamba3 + return Mamba3( + d_model=64, d_state=16, expand=2, headdim=16, + ngroups=1, use_rope=True, use_trapezoidal=True, + use_bc_norm=True, use_bc_bias=True, mimo_rank=mimo_rank, + chunk_size=32, layer_idx=0, device=DEVICE, dtype=dtype, + ).eval() + + def test_siso_prefill_then_decode(self): + """After prefill, decode step should produce reasonable output.""" + from mamba_ssm.utils.generation import InferenceParams + + model = self._make_model() + batch, seqlen = 2, 64 + + # Allocate inference cache + inference_params = InferenceParams(max_seqlen=seqlen + 10, max_batch_size=batch) + + # Prefill + u_prefill = torch.randn(batch, seqlen, 64, device=DEVICE) + with torch.no_grad(): + y_prefill = model(u_prefill, inference_params=inference_params) + inference_params.seqlen_offset = seqlen + + # Decode one token + u_decode = torch.randn(batch, 1, 64, device=DEVICE) + with torch.no_grad(): + y_decode = model(u_decode, inference_params=inference_params) + + assert y_decode.shape == (batch, 1, 64) + assert torch.isfinite(y_decode).all() + + def test_siso_multi_step_decode(self): + """Multiple decode steps should all produce finite output.""" + from mamba_ssm.utils.generation import InferenceParams + + model = self._make_model() + batch, seqlen = 1, 32 + + inference_params = InferenceParams(max_seqlen=seqlen + 20, max_batch_size=batch) + + # Prefill + u = torch.randn(batch, seqlen, 64, device=DEVICE) + with torch.no_grad(): + model(u, inference_params=inference_params) + inference_params.seqlen_offset = seqlen + + # Decode 10 tokens + for step in range(10): + u_step = torch.randn(batch, 1, 64, device=DEVICE) + with torch.no_grad(): + y = model(u_step, inference_params=inference_params) + assert torch.isfinite(y).all(), f"Non-finite at decode step {step}" + inference_params.seqlen_offset += 1 + + def test_siso_prefill_decode_numerical(self): + """Verify decode step matches what prefill would produce for same input.""" + from mamba_ssm.utils.generation import InferenceParams + torch.manual_seed(42) + + model = self._make_model() + batch = 1 + seqlen = 32 + + # Full sequence: prefill all at once + u_full = torch.randn(batch, seqlen + 1, 64, device=DEVICE) + with torch.no_grad(): + y_full = model(u_full) + + # Split: prefill first seqlen tokens, then decode last token + inference_params = InferenceParams(max_seqlen=seqlen + 10, max_batch_size=batch) + with torch.no_grad(): + y_prefill = model(u_full[:, :seqlen], inference_params=inference_params) + inference_params.seqlen_offset = seqlen + with torch.no_grad(): + y_decode = model(u_full[:, seqlen:seqlen+1], inference_params=inference_params) + + # The decode output should match the last position of full prefill + # Allow some tolerance due to chunked vs recurrent numerical diffs + torch.testing.assert_close( + y_decode[:, 0].float(), y_full[:, seqlen].float(), + atol=5e-3, rtol=5e-3 + ) + + +# ============================================================================ +# 4. MambaLMHeadModel with Mamba3 +# ============================================================================ + +class TestMambaLMHeadModelMamba3: + """Test full model integration.""" + + def test_mamba3_lm_model_forward(self): + from mamba_ssm.models.config_mamba import MambaConfig + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + config = MambaConfig( + d_model=128, n_layer=2, vocab_size=256, + ssm_cfg={"layer": "Mamba3", "d_state": 16, "headdim": 32, + "use_rope": True, "use_trapezoidal": True}, + rms_norm=True, fused_add_norm=False, + ) + model = MambaLMHeadModel(config, device=DEVICE, dtype=torch.bfloat16) + input_ids = torch.randint(0, 256, (2, 64), device=DEVICE) + output = model(input_ids) + assert output.logits.shape == (2, 64, 256) + assert torch.isfinite(output.logits).all() + + def test_mamba3_mimo_lm_model(self): + from mamba_ssm.models.config_mamba import MambaConfig + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + config = MambaConfig( + d_model=128, n_layer=2, vocab_size=256, + ssm_cfg={"layer": "Mamba3", "d_state": 16, "headdim": 32, + "mimo_rank": 2, "use_rope": True, "use_trapezoidal": True}, + rms_norm=True, fused_add_norm=False, + ) + model = MambaLMHeadModel(config, device=DEVICE, dtype=torch.bfloat16) + input_ids = torch.randint(0, 256, (2, 64), device=DEVICE) + output = model(input_ids) + assert output.logits.shape == (2, 64, 256) + assert torch.isfinite(output.logits).all() + + def test_mamba3_lm_model_gradient(self): + from mamba_ssm.models.config_mamba import MambaConfig + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + config = MambaConfig( + d_model=128, n_layer=2, vocab_size=256, + ssm_cfg={"layer": "Mamba3", "d_state": 16, "headdim": 32}, + rms_norm=True, fused_add_norm=False, + ) + model = MambaLMHeadModel(config, device=DEVICE, dtype=torch.bfloat16) + input_ids = torch.randint(0, 256, (2, 32), device=DEVICE) + output = model(input_ids) + loss = output.logits.float().sum() + loss.backward() + # Check at least some params got gradients + graded = sum(1 for p in model.parameters() if p.grad is not None) + assert graded > 0 + + +# ============================================================================ +# 5. Mamba3Simple on GPU +# ============================================================================ + +class TestMamba3SimpleGPU: + """Test Mamba3Simple on GPU (uses reference recurrence, no Triton).""" + + def _make_model(self, mimo_rank=0, dtype=torch.float32): + from mamba_ssm.modules.mamba3_simple import Mamba3Simple + return Mamba3Simple( + d_model=64, d_state=16, expand=2, headdim=16, + ngroups=1, use_rope=True, use_trapezoidal=True, + use_bc_norm=True, use_bc_bias=True, mimo_rank=mimo_rank, + chunk_size=32, device=DEVICE, dtype=dtype, + ).eval() + + def test_siso_forward_and_grad(self): + model = self._make_model().train() + u = torch.randn(2, 64, 64, device=DEVICE, requires_grad=True) + y = model(u) + assert y.shape == u.shape + y.sum().backward() + assert torch.isfinite(u.grad).all() + + def test_mimo_forward_and_grad(self): + model = self._make_model(mimo_rank=2).train() + u = torch.randn(2, 64, 64, device=DEVICE, requires_grad=True) + y = model(u) + assert y.shape == u.shape + y.sum().backward() + assert torch.isfinite(u.grad).all() + + +# ============================================================================ +# 6. seq_idx — packed multi-document training (GPU) +# ============================================================================ + +class TestSeqIdxGPU: + """Test seq_idx prevents cross-document leakage on GPU with real Triton kernels.""" + + def _make_model(self, **kwargs): + from mamba_ssm.modules.mamba3_simple import Mamba3Simple + defaults = dict( + d_model=64, d_state=16, expand=2, headdim=16, + ngroups=1, use_rope=True, use_trapezoidal=True, + use_bc_norm=True, use_bc_bias=True, mimo_rank=0, + chunk_size=32, device=DEVICE, dtype=torch.float32, + ) + defaults.update(kwargs) + return Mamba3Simple(**defaults).eval() + + def test_siso_two_docs_isolation(self): + """Packed docs should produce same output as separate runs.""" + torch.manual_seed(42) + model = self._make_model(use_rope=False) + u_a = torch.randn(1, 32, 64, device=DEVICE) + u_b = torch.randn(1, 32, 64, device=DEVICE) + + with torch.no_grad(): + y_a = model(u_a) + y_b = model(u_b) + + u_packed = torch.cat([u_a, u_b], dim=1) + seq_idx = torch.cat([ + torch.zeros(1, 32, dtype=torch.long, device=DEVICE), + torch.ones(1, 32, dtype=torch.long, device=DEVICE), + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + torch.testing.assert_close(y_packed[:, :32], y_a, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(y_packed[:, 32:], y_b, atol=1e-4, rtol=1e-4) + + def test_siso_trapezoidal_docs(self): + """seq_idx + trapezoidal discretization.""" + torch.manual_seed(42) + model = self._make_model(use_rope=False, use_trapezoidal=True) + u_a = torch.randn(1, 32, 64, device=DEVICE) + u_b = torch.randn(1, 32, 64, device=DEVICE) + + with torch.no_grad(): + y_a = model(u_a) + y_b = model(u_b) + + u_packed = torch.cat([u_a, u_b], dim=1) + seq_idx = torch.cat([ + torch.zeros(1, 32, dtype=torch.long, device=DEVICE), + torch.ones(1, 32, dtype=torch.long, device=DEVICE), + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + torch.testing.assert_close(y_packed[:, :32], y_a, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(y_packed[:, 32:], y_b, atol=1e-4, rtol=1e-4) + + def test_gradient_isolation(self): + """No gradient should flow across document boundaries.""" + torch.manual_seed(42) + model = self._make_model(use_rope=False).train() + u = torch.randn(1, 64, 64, device=DEVICE, requires_grad=True) + seq_idx = torch.cat([ + torch.zeros(1, 32, dtype=torch.long, device=DEVICE), + torch.ones(1, 32, dtype=torch.long, device=DEVICE), + ], dim=1) + + y = model(u, seq_idx=seq_idx) + y[:, 32:].sum().backward() + + assert u.grad[:, :32].abs().max() < 1e-5, "Gradient leaked across documents" + + def test_bf16_seq_idx(self): + """seq_idx should work correctly with bf16.""" + torch.manual_seed(42) + model = self._make_model(dtype=torch.bfloat16, use_rope=False) + u_a = torch.randn(1, 32, 64, device=DEVICE, dtype=torch.bfloat16) + u_b = torch.randn(1, 32, 64, device=DEVICE, dtype=torch.bfloat16) + + with torch.no_grad(): + y_a = model(u_a) + y_b = model(u_b) + + u_packed = torch.cat([u_a, u_b], dim=1) + seq_idx = torch.cat([ + torch.zeros(1, 32, dtype=torch.long, device=DEVICE), + torch.ones(1, 32, dtype=torch.long, device=DEVICE), + ], dim=1) + + with torch.no_grad(): + y_packed = model(u_packed, seq_idx=seq_idx) + + torch.testing.assert_close(y_packed[:, :32], y_a, atol=0.05, rtol=0.05) + torch.testing.assert_close(y_packed[:, 32:], y_b, atol=0.05, rtol=0.05) + + +# ============================================================================ +# 7. Triton decode kernel integration +# ============================================================================ + +class TestTritonDecodeKernel: + """Test that Triton decode kernel in step() matches PyTorch reference.""" + + def _make_model(self, **kwargs): + from mamba_ssm.modules.mamba3 import Mamba3 + defaults = dict( + d_model=64, d_state=16, expand=2, headdim=16, + ngroups=1, use_rope=True, use_trapezoidal=True, + use_bc_norm=True, use_bc_bias=True, mimo_rank=0, + use_mem_eff_path=False, + chunk_size=32, layer_idx=0, device=DEVICE, dtype=torch.float32, + ) + defaults.update(kwargs) + return Mamba3(**defaults).eval() + + def test_triton_decode_siso(self): + """SISO decode: Triton kernel should match prefill output at each position.""" + torch.manual_seed(42) + model = self._make_model() + batch, seqlen = 2, 16 + u = torch.randn(batch, seqlen, 64, device=DEVICE) + + # Prefill + from mamba_ssm.utils.generation import InferenceParams + inference_params = InferenceParams(max_seqlen=seqlen, max_batch_size=batch) + with torch.no_grad(): + y_prefill = model(u, inference_params=inference_params) + + # Now decode one more token + u_next = torch.randn(batch, 1, 64, device=DEVICE) + inference_params.seqlen_offset = seqlen + with torch.no_grad(): + y_decode = model(u_next, inference_params=inference_params) + + assert y_decode.shape == (batch, 1, 64) + assert torch.isfinite(y_decode).all() + + def test_triton_vs_pytorch_decode_consistency(self): + """SISO: Triton decode and PyTorch fallback should give same results.""" + torch.manual_seed(42) + model = self._make_model(use_rope=False) + batch, seqlen = 1, 32 + + import copy + from mamba_ssm.utils.generation import InferenceParams + import mamba_ssm.modules.mamba3 as mamba3_mod + + # Run with Triton kernel + torch.manual_seed(99) + u = torch.randn(batch, seqlen, 64, device=DEVICE) + ip1 = InferenceParams(max_seqlen=seqlen + 4, max_batch_size=batch) + with torch.no_grad(): + _ = model(u, inference_params=ip1) + decode_triton = [] + decode_inputs = [] + for t in range(4): + u_t = torch.randn(batch, 1, 64, device=DEVICE) + decode_inputs.append(u_t.clone()) + ip1.seqlen_offset = seqlen + t + with torch.no_grad(): + decode_triton.append(model(u_t, inference_params=ip1)) + + # Run with PyTorch fallback (patch out mamba3_state_update) + model2 = copy.deepcopy(model) + orig_fn = mamba3_mod.mamba3_state_update + torch.manual_seed(99) + u2 = torch.randn(batch, seqlen, 64, device=DEVICE) + ip2 = InferenceParams(max_seqlen=seqlen + 4, max_batch_size=batch) + with torch.no_grad(): + _ = model2(u2, inference_params=ip2) + decode_pytorch = [] + mamba3_mod.mamba3_state_update = None + try: + for t in range(4): + ip2.seqlen_offset = seqlen + t + with torch.no_grad(): + decode_pytorch.append(model2(decode_inputs[t], inference_params=ip2)) + finally: + mamba3_mod.mamba3_state_update = orig_fn + + # Compare Triton vs PyTorch outputs + for i in range(4): + torch.testing.assert_close( + decode_triton[i], decode_pytorch[i], atol=1e-4, rtol=1e-4, + msg=f"SISO decode step {i}: Triton vs PyTorch mismatch", + ) + + def test_triton_decode_mimo(self): + """MIMO decode: Triton kernel should produce finite outputs and correct shapes.""" + torch.manual_seed(42) + model = self._make_model(mimo_rank=2) + batch, seqlen = 2, 16 + u = torch.randn(batch, seqlen, 64, device=DEVICE) + + from mamba_ssm.utils.generation import InferenceParams + inference_params = InferenceParams(max_seqlen=seqlen + 4, max_batch_size=batch) + with torch.no_grad(): + y_prefill = model(u, inference_params=inference_params) + + assert y_prefill.shape == (batch, seqlen, 64) + assert torch.isfinite(y_prefill).all() + + # Decode 4 tokens + for t in range(4): + u_t = torch.randn(batch, 1, 64, device=DEVICE) + inference_params.seqlen_offset = seqlen + t + with torch.no_grad(): + y_t = model(u_t, inference_params=inference_params) + assert y_t.shape == (batch, 1, 64) + assert torch.isfinite(y_t).all(), f"MIMO decode step {t} has non-finite values" + + def test_triton_mimo_vs_pytorch_consistency(self): + """MIMO: Triton decode and PyTorch fallback should give same results.""" + torch.manual_seed(42) + model = self._make_model(mimo_rank=2, use_rope=False) + batch, seqlen = 1, 32 + + import copy + from mamba_ssm.utils.generation import InferenceParams + import mamba_ssm.modules.mamba3 as mamba3_mod + + # Run with Triton kernel + torch.manual_seed(99) + u = torch.randn(batch, seqlen, 64, device=DEVICE) + ip1 = InferenceParams(max_seqlen=seqlen + 2, max_batch_size=batch) + with torch.no_grad(): + _ = model(u, inference_params=ip1) + decode_triton = [] + decode_inputs = [] + for t in range(2): + u_t = torch.randn(batch, 1, 64, device=DEVICE) + decode_inputs.append(u_t.clone()) + ip1.seqlen_offset = seqlen + t + with torch.no_grad(): + decode_triton.append(model(u_t, inference_params=ip1)) + + # Run with PyTorch fallback (patch out mamba3_state_update) + model2 = copy.deepcopy(model) + orig_fn = mamba3_mod.mamba3_state_update + torch.manual_seed(99) + u2 = torch.randn(batch, seqlen, 64, device=DEVICE) + ip2 = InferenceParams(max_seqlen=seqlen + 2, max_batch_size=batch) + with torch.no_grad(): + _ = model2(u2, inference_params=ip2) + decode_pytorch = [] + mamba3_mod.mamba3_state_update = None # force PyTorch fallback + try: + for t in range(2): + ip2.seqlen_offset = seqlen + t + with torch.no_grad(): + decode_pytorch.append(model2(decode_inputs[t], inference_params=ip2)) + finally: + mamba3_mod.mamba3_state_update = orig_fn + + # Both paths should produce matching results + for i in range(2): + torch.testing.assert_close( + decode_triton[i], decode_pytorch[i], atol=1e-4, rtol=1e-4, + msg=f"MIMO decode step {i}: Triton vs PyTorch mismatch", + ) + + +# ============================================================================ +# 8. use_mem_eff_path (gradient checkpointing) +# ============================================================================ + +class TestMemEffPathGPU: + """Test gradient checkpointing produces same results as normal forward.""" + + def _make_model(self, use_mem_eff_path=True, **kwargs): + from mamba_ssm.modules.mamba3 import Mamba3 + defaults = dict( + d_model=64, d_state=16, expand=2, headdim=16, + ngroups=1, use_rope=True, use_trapezoidal=True, + use_bc_norm=True, use_bc_bias=True, mimo_rank=0, + use_mem_eff_path=use_mem_eff_path, + chunk_size=32, device=DEVICE, dtype=torch.float32, + ) + defaults.update(kwargs) + return Mamba3(**defaults) + + def test_checkpointed_matches_normal(self): + """Checkpointed and normal forward should produce same outputs and gradients.""" + torch.manual_seed(42) + model_ckpt = self._make_model(use_mem_eff_path=True).train() + + import copy + model_plain = copy.deepcopy(model_ckpt) + model_plain.use_mem_eff_path = False + + u = torch.randn(2, 32, 64, device=DEVICE) + + # Checkpointed + u1 = u.clone().requires_grad_(True) + y1 = model_ckpt(u1) + y1.sum().backward() + + # Normal + u2 = u.clone().requires_grad_(True) + y2 = model_plain(u2) + y2.sum().backward() + + torch.testing.assert_close(y1, y2, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(u1.grad, u2.grad, atol=1e-5, rtol=1e-5) + # Check parameter gradients match + for (n1, p1), (n2, p2) in zip( + model_ckpt.named_parameters(), model_plain.named_parameters() + ): + if p1.grad is not None: + torch.testing.assert_close(p1.grad, p2.grad, atol=1e-5, rtol=1e-5, + msg=f"Grad mismatch for {n1}") + + def test_checkpointed_bf16(self): + """Gradient checkpointing should work with bf16.""" + torch.manual_seed(42) + model = self._make_model(use_mem_eff_path=True, dtype=torch.bfloat16).train() + u = torch.randn(2, 32, 64, device=DEVICE, dtype=torch.bfloat16, requires_grad=True) + y = model(u) + y.sum().backward() + assert torch.isfinite(u.grad).all() + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_mamba3_triton.py b/tests/test_mamba3_triton.py new file mode 100644 index 000000000..ace9060c3 --- /dev/null +++ b/tests/test_mamba3_triton.py @@ -0,0 +1,1022 @@ +"""Tests for Mamba-3 Triton training kernels. + +Verifies that the chunked SSD forward and backward paths produce correct results +by comparing against a step-by-step recurrence reference. Also tests the Triton +decode kernel (mamba3_state_update) against its PyTorch reference implementation. + +Requirements: NVIDIA GPU with Triton support. +""" + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +# Skip all tests if no CUDA GPU +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA GPU required", +) + +DEVICE = "cuda" + + +# ===== Fixtures ===== + +@pytest.fixture(params=[torch.float32, torch.bfloat16]) +def dtype(request): + return request.param + + +@pytest.fixture(params=[64, 128, 256]) +def seqlen(request): + return request.param + + +@pytest.fixture(params=[1, 4]) +def nheads(request): + return request.param + + +@pytest.fixture(params=[1, 2]) +def ngroups(request): + return request.param + + +# ===== Helper functions ===== + +def make_inputs(batch=2, seqlen=128, nheads=4, headdim=16, ngroups=1, + d_state=16, chunk_size=64, dtype=torch.float32, + has_trapezoidal=False, has_rope=False, has_D=True, has_z=True, + has_seq_idx=False, has_initial_states=False, + has_initial_prev_Bx=False, device=DEVICE): + """Generate random test inputs matching mamba3_chunk_scan_combined signature.""" + factory = dict(device=device, dtype=dtype) + + x = torch.randn(batch, seqlen, nheads, headdim, **factory, requires_grad=True) + dt = torch.randn(batch, seqlen, nheads, **factory, requires_grad=True) + A = (-torch.rand(nheads, device=device, dtype=torch.float32)).detach().requires_grad_(True) + B = torch.randn(batch, seqlen, ngroups, d_state, **factory, requires_grad=True) + C = torch.randn(batch, seqlen, ngroups, d_state, **factory, requires_grad=True) + + D = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) if has_D else None + z = torch.randn(batch, seqlen, nheads, headdim, **factory, requires_grad=True) if has_z else None + dt_bias = torch.randn(nheads, device=device, dtype=torch.float32, requires_grad=True) + + gamma = torch.randn(batch, seqlen, nheads, **factory, requires_grad=True) if has_trapezoidal else None + beta = torch.randn(batch, seqlen, nheads, **factory, requires_grad=True) if has_trapezoidal else None + theta = torch.randn(batch, seqlen, nheads, d_state // 2, **factory, requires_grad=True) if has_rope else None + + initial_states = (torch.randn(batch, nheads, headdim, d_state, **factory, requires_grad=True) + if has_initial_states else None) + initial_prev_Bx = (torch.randn(batch, nheads, headdim, d_state, **factory, requires_grad=True) + if has_initial_prev_Bx else None) + + seq_idx = None + if has_seq_idx: + seq_idx = torch.zeros(batch, seqlen, device=device, dtype=torch.long) + for b in range(batch): + mid = seqlen // 2 + seq_idx[b, mid:] = 1 + + return dict( + x=x, dt=dt, A=A, B=B, C=C, chunk_size=chunk_size, + D=D, z=z, dt_bias=dt_bias, + initial_states=initial_states, seq_idx=seq_idx, + dt_softplus=True, dt_limit=(0.0, float("inf")), + return_final_states=True, + gamma=gamma, beta=beta, theta=theta, + initial_prev_Bx=initial_prev_Bx, + ngroups=ngroups, + ) + + +def clone_inputs(inputs): + """Deep-clone inputs so we can run two independent forward passes.""" + cloned = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + c = v.detach().clone() + if v.requires_grad: + c.requires_grad_(True) + cloned[k] = c + else: + cloned[k] = v + return cloned + + +def assert_close(a, b, rtol=None, atol=None, dtype=torch.float32): + """Assert tensors are close with appropriate tolerances per dtype.""" + if a is None and b is None: + return + if rtol is None: + rtol = 1e-3 if dtype == torch.bfloat16 else 1e-5 + if atol is None: + atol = 5e-2 if dtype == torch.bfloat16 else 1e-4 + torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol) + + +def _reference_recurrence(X, dt, A, B, C, gamma, beta=None, D=None, z=None): + """Step-by-step reference recurrence for SISO. + + Args: + X: (batch, seqlen, nheads, headdim) -- float32 + dt: (batch, seqlen, nheads) -- processed dt (after softplus/bias/clamp) + A: (nheads,) -- negative + B: (batch, seqlen, nheads, dstate) -- at head level + C: (batch, seqlen, nheads, dstate) -- at head level + gamma: (batch, seqlen, nheads) + beta: (batch, seqlen, nheads) or None + D: (nheads,) or None + z: (batch, seqlen, nheads, headdim) or None + Returns: + Y: (batch, seqlen, nheads, headdim) + final_state: (batch, nheads, headdim, dstate) + """ + batch, seqlen, nheads, headdim = X.shape + dstate = B.shape[-1] + + alpha = torch.exp(dt.unsqueeze(-1) * A.float().view(1, 1, nheads, 1)) + h = torch.zeros(batch, nheads, headdim, dstate, device=X.device, dtype=torch.float32) + ys = [] + prev_Bx = None + + for t in range(seqlen): + x_t = X[:, t].float() + B_t = B[:, t].float() + C_t = C[:, t].float() + + Bx_t = torch.einsum("bhp,bhn->bhpn", x_t, B_t) + + alpha_t = alpha[:, t].unsqueeze(-1) + gamma_t = gamma[:, t].float().unsqueeze(-1).unsqueeze(-1) + h = alpha_t * h + gamma_t * Bx_t + + if beta is not None and prev_Bx is not None: + beta_t = beta[:, t].float().unsqueeze(-1).unsqueeze(-1) + h = h + beta_t * prev_Bx + + prev_Bx = Bx_t + + y_t = torch.einsum("bhpn,bhn->bhp", h, C_t) + + if D is not None: + y_t = y_t + X[:, t].float() * D.float().view(1, nheads, 1) + if z is not None: + y_t = y_t * F.silu(z[:, t].float()) + + ys.append(y_t) + + Y = torch.stack(ys, dim=1) + return Y, h + + +# ===== Group 1: Forward Correctness ===== + +class TestChunkedForwardCorrectness: + """Verify mamba3_chunk_scan_combined forward matches step-by-step recurrence.""" + + def _run_forward_check(self, dtype=torch.float32, has_trapezoidal=False, + has_rope=False, has_D=True, has_z=True, + has_seq_idx=False, has_initial_states=False, + has_initial_prev_Bx=False, ngroups=1, + seqlen=128, chunk_size=64, nheads=4): + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=seqlen, nheads=nheads, headdim=16, + ngroups=ngroups, d_state=16, chunk_size=chunk_size, + dtype=dtype, has_trapezoidal=has_trapezoidal, has_rope=has_rope, + has_D=has_D, has_z=has_z, has_seq_idx=has_seq_idx, + has_initial_states=has_initial_states, + has_initial_prev_Bx=has_initial_prev_Bx, + ) + + with torch.no_grad(): + out, final_states = mamba3_chunk_scan_combined(**inputs) + + assert out.shape == inputs["x"].shape + assert final_states.shape == (2, nheads, 16, 16) + assert torch.isfinite(out).all(), "Non-finite values in output" + assert torch.isfinite(final_states).all(), "Non-finite values in final states" + + return out, final_states + + def test_euler_mode_fp32(self): + """Basic Euler mode (no trapezoidal) in fp32.""" + self._run_forward_check(dtype=torch.float32, has_trapezoidal=False, + has_rope=False, has_D=False, has_z=False) + + def test_euler_mode_bf16(self): + """Euler mode in bf16.""" + self._run_forward_check(dtype=torch.bfloat16, has_trapezoidal=False, + has_rope=False, has_D=False, has_z=False) + + def test_trapezoidal_mode_fp32(self): + """Trapezoidal discretization in fp32.""" + self._run_forward_check(dtype=torch.float32, has_trapezoidal=True) + + def test_trapezoidal_mode_bf16(self): + """Trapezoidal in bf16.""" + self._run_forward_check(dtype=torch.bfloat16, has_trapezoidal=True) + + def test_with_rope(self): + """Forward with RoPE on B, C.""" + self._run_forward_check(has_rope=True) + + def test_trapezoidal_with_rope(self): + """Trapezoidal + RoPE combined.""" + self._run_forward_check(has_trapezoidal=True, has_rope=True) + + def test_with_D_skip(self): + """Forward with D skip connection.""" + self._run_forward_check(has_D=True, has_z=False) + + def test_with_z_gating(self): + """Forward with z gating (SiLU).""" + self._run_forward_check(has_D=False, has_z=True) + + def test_with_initial_states(self): + """Forward with non-zero initial states.""" + self._run_forward_check(has_initial_states=True) + + def test_with_initial_prev_Bx(self): + """Forward with initial_prev_Bx (trapezoidal lookback init).""" + self._run_forward_check(has_trapezoidal=True, has_initial_prev_Bx=True) + + def test_with_seq_idx(self): + """Forward with document boundaries.""" + self._run_forward_check(has_seq_idx=True) + + def test_trapezoidal_seq_idx(self): + """Trapezoidal + seq_idx: shifted tensors masked at boundaries.""" + self._run_forward_check(has_trapezoidal=True, has_seq_idx=True) + + def test_ngroups_gt_1(self): + """Forward with ngroups > 1 (groups < nheads).""" + self._run_forward_check(ngroups=2, nheads=4) + + def test_various_seqlens(self, seqlen): + """Test different sequence lengths (multiples of chunk_size).""" + self._run_forward_check(seqlen=seqlen, chunk_size=64) + + def test_various_chunk_sizes(self): + """Test chunk_size=32, 64, 128.""" + for cs in [32, 64, 128]: + self._run_forward_check(seqlen=128, chunk_size=cs) + + def test_final_states_match_recurrence(self): + """Verify final states match between chunked SSD and step-by-step recurrence.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 64, 4, 16, 16 + chunk_size = 32 + + inputs = make_inputs( + batch=batch, seqlen=seqlen, nheads=nheads, headdim=headdim, + ngroups=1, d_state=dstate, chunk_size=chunk_size, + dtype=torch.float32, has_trapezoidal=False, has_rope=False, + has_D=False, has_z=False, + ) + + with torch.no_grad(): + out_chunked, final_states_chunked = mamba3_chunk_scan_combined(**inputs) + + # Build reference: process dt the same way as mamba3_chunk_scan_combined + dt_proc = inputs["dt"] + inputs["dt_bias"].view(1, 1, nheads) + dt_proc = F.softplus(dt_proc) + + # Expand B, C from groups to heads (ngroups=1 here, so just identity) + B_exp = inputs["B"][:, :, :1].expand(-1, -1, nheads, -1) # ngroups=1 + C_exp = inputs["C"][:, :, :1].expand(-1, -1, nheads, -1) + + with torch.no_grad(): + _, final_ref = _reference_recurrence( + inputs["x"], dt_proc, inputs["A"], B_exp, C_exp, + gamma=dt_proc, # Euler: gamma = dt + ) + + assert_close(final_states_chunked, final_ref, dtype=torch.float32, + atol=5e-3, rtol=5e-3) + + def test_full_config(self): + """All features enabled: trapezoidal + RoPE + D + z + seq_idx + initial_states.""" + self._run_forward_check( + has_trapezoidal=True, has_rope=True, has_D=True, has_z=True, + has_seq_idx=True, has_initial_states=True, + ) + + +# ===== Group 2: Backward Correctness (Gradient comparison) ===== + +class TestChunkedBackwardCorrectness: + """Verify gradients from the chunked SSD path are correct.""" + + def _run_grad_check(self, param_name, dtype=torch.float32, + has_trapezoidal=False, has_rope=False, + has_D=True, has_z=True): + """Run forward+backward and verify gradient of a specific parameter is finite and non-zero.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=64, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=32, + dtype=dtype, has_trapezoidal=has_trapezoidal, has_rope=has_rope, + has_D=has_D, has_z=has_z, + ) + + out, _ = mamba3_chunk_scan_combined(**inputs) + loss = out.float().sum() + loss.backward() + + param = inputs[param_name] + assert param is not None, f"Parameter {param_name} is None" + assert param.grad is not None, f"No gradient for {param_name}" + assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {param_name}" + assert param.grad.abs().max() > 0, f"Zero gradient for {param_name}" + + return param.grad + + def test_grad_x(self): + """Gradient w.r.t. x.""" + self._run_grad_check("x") + + def test_grad_dt(self): + """Gradient w.r.t. dt.""" + self._run_grad_check("dt") + + def test_grad_A(self): + """Gradient w.r.t. A.""" + self._run_grad_check("A") + + def test_grad_B(self): + """Gradient w.r.t. B.""" + self._run_grad_check("B") + + def test_grad_C(self): + """Gradient w.r.t. C.""" + self._run_grad_check("C") + + def test_grad_D(self): + """Gradient w.r.t. D.""" + self._run_grad_check("D", has_D=True, has_z=False) + + def test_grad_z(self): + """Gradient w.r.t. z.""" + self._run_grad_check("z", has_D=False, has_z=True) + + def test_grad_gamma(self): + """Gradient w.r.t. gamma (trapezoidal).""" + self._run_grad_check("gamma", has_trapezoidal=True) + + def test_grad_beta(self): + """Gradient w.r.t. beta (trapezoidal).""" + self._run_grad_check("beta", has_trapezoidal=True) + + def test_all_grads_euler(self): + """All gradients correct in Euler mode.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=64, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=32, + dtype=torch.float32, has_trapezoidal=False, has_rope=False, + has_D=True, has_z=True, + ) + + out, _ = mamba3_chunk_scan_combined(**inputs) + loss = out.float().sum() + loss.backward() + + for name in ["x", "dt", "A", "B", "C", "D", "z", "dt_bias"]: + param = inputs[name] + if param is not None and param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" + + def test_all_grads_trapezoidal(self): + """All gradients correct in trapezoidal mode.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=64, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=32, + dtype=torch.float32, has_trapezoidal=True, has_rope=False, + has_D=True, has_z=True, + ) + + out, _ = mamba3_chunk_scan_combined(**inputs) + loss = out.float().sum() + loss.backward() + + for name in ["x", "dt", "A", "B", "C", "D", "z", "dt_bias", "gamma", "beta"]: + param = inputs[name] + if param is not None and param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" + + def test_all_grads_full_config(self): + """All gradients correct with all features enabled.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=64, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=32, + dtype=torch.float32, has_trapezoidal=True, has_rope=True, + has_D=True, has_z=True, has_initial_states=True, + ) + + out, _ = mamba3_chunk_scan_combined(**inputs) + loss = out.float().sum() + loss.backward() + + for name in ["x", "dt", "A", "B", "C", "D", "z", "dt_bias", + "gamma", "beta", "theta", "initial_states"]: + param = inputs[name] + if param is not None and param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" + + def test_grad_bf16(self): + """Gradient correctness in bf16 (relaxed tolerance).""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=64, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=32, + dtype=torch.bfloat16, has_trapezoidal=True, has_rope=True, + has_D=True, has_z=True, + ) + + out, _ = mamba3_chunk_scan_combined(**inputs) + loss = out.float().sum() + loss.backward() + + for name in ["x", "dt", "B", "C", "gamma", "beta"]: + param = inputs[name] + if param is not None and param.requires_grad: + assert param.grad is not None, f"No gradient for {name} in bf16" + assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name} in bf16" + + +# ===== Group 3: Triton Decode Kernel Tests ===== + +class TestTritonDecodeKernel: + """Test mamba3_state_update Triton kernel against PyTorch reference.""" + + def _make_decode_inputs(self, batch=2, nheads=4, dim=16, dstate=16, + ngroups=1, dtype=torch.float32, is_mimo=False, + mimo_rank=2, has_D=False, has_z=False, + has_trapezoidal=False): + """Create inputs for a single decode step.""" + factory = dict(device=DEVICE, dtype=dtype) + state = torch.randn(batch, nheads, dim, dstate, **factory) + + if is_mimo: + x = torch.randn(batch, nheads, dim, mimo_rank, **factory) + B = torch.randn(batch, ngroups, dstate, mimo_rank, **factory) + C = torch.randn(batch, ngroups, dstate, mimo_rank, **factory) + else: + x = torch.randn(batch, nheads, dim, **factory) + B = torch.randn(batch, ngroups, dstate, **factory) + C = torch.randn(batch, ngroups, dstate, **factory) + + dt = torch.randn(batch, nheads, **factory) + A = -torch.rand(nheads, device=DEVICE, dtype=torch.float32) + dt_bias = torch.randn(nheads, device=DEVICE, dtype=torch.float32) + + D = torch.randn(nheads, device=DEVICE, dtype=torch.float32) if has_D else None + z = torch.randn(batch, nheads, dim, **factory) if (has_z and not is_mimo) else None + + prev_Bx = torch.randn(batch, nheads, dim, dstate, **factory) if has_trapezoidal else None + beta_val = torch.randn(batch, nheads, **factory) if has_trapezoidal else None + gamma_val = torch.randn(batch, nheads, **factory) if has_trapezoidal else None + + return dict( + state=state, x=x, dt=dt, A=A, B=B, C=C, + D=D, z=z, dt_bias=dt_bias, dt_softplus=True, + prev_Bx=prev_Bx, beta=beta_val, gamma=gamma_val, + ) + + def _run_triton_vs_ref(self, **kwargs): + """Run both Triton and PyTorch reference decode, compare outputs.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_state_update, _mamba3_state_update_ref + + torch.manual_seed(42) + inputs = self._make_decode_inputs(**kwargs) + dtype = inputs["x"].dtype + + # Clone state for reference (state is modified in-place) + import copy + inputs_ref = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + inputs_ref[k] = v.clone() + else: + inputs_ref[k] = v + state_ref = inputs_ref["state"] + prev_Bx_ref = inputs_ref["prev_Bx"].clone() if inputs_ref["prev_Bx"] is not None else None + inputs_ref["state"] = state_ref + inputs_ref["prev_Bx"] = prev_Bx_ref + + # Triton path + out_triton = mamba3_state_update(**inputs) + + # Reference path + out_ref = _mamba3_state_update_ref(**inputs_ref) + + tol_rtol = 5e-3 if dtype == torch.bfloat16 else 1e-4 + tol_atol = 1e-1 if dtype == torch.bfloat16 else 1e-3 + torch.testing.assert_close(out_triton.float(), out_ref.float(), + rtol=tol_rtol, atol=tol_atol) + + # Also compare updated states + torch.testing.assert_close(inputs["state"].float(), state_ref.float(), + rtol=tol_rtol, atol=tol_atol) + + def test_siso_euler_fp32(self): + """SISO Euler decode in fp32.""" + self._run_triton_vs_ref(dtype=torch.float32) + + def test_siso_euler_bf16(self): + """SISO Euler decode in bf16.""" + self._run_triton_vs_ref(dtype=torch.bfloat16) + + def test_siso_with_D(self): + """SISO with D skip connection.""" + self._run_triton_vs_ref(has_D=True) + + def test_siso_with_z(self): + """SISO with z gating.""" + self._run_triton_vs_ref(has_z=True) + + def test_siso_with_D_and_z(self): + """SISO with both D and z.""" + self._run_triton_vs_ref(has_D=True, has_z=True) + + def test_siso_trapezoidal(self): + """SISO trapezoidal decode (gamma, beta, prev_Bx).""" + self._run_triton_vs_ref(has_trapezoidal=True) + + def test_siso_trapezoidal_D_z(self): + """SISO trapezoidal with D and z.""" + self._run_triton_vs_ref(has_trapezoidal=True, has_D=True, has_z=True) + + def test_mimo_euler(self): + """MIMO Euler decode.""" + self._run_triton_vs_ref(is_mimo=True, mimo_rank=2) + + def test_mimo_trapezoidal(self): + """MIMO trapezoidal decode.""" + self._run_triton_vs_ref(is_mimo=True, mimo_rank=2, has_trapezoidal=True) + + def test_mimo_rank4(self): + """MIMO with rank 4.""" + self._run_triton_vs_ref(is_mimo=True, mimo_rank=4) + + def test_ngroups_gt_1(self): + """Decode with ngroups > 1.""" + self._run_triton_vs_ref(nheads=4, ngroups=2) + + def test_prev_Bx_updated_correctly(self): + """Verify prev_Bx buffer is updated with current unscaled Bx after decode step.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_state_update, _mamba3_state_update_ref + + torch.manual_seed(42) + inputs = self._make_decode_inputs(has_trapezoidal=True) + import copy + inputs_ref = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + inputs_ref[k] = v.clone() + else: + inputs_ref[k] = v + + _ = mamba3_state_update(**inputs) + _ = _mamba3_state_update_ref(**inputs_ref) + + # prev_Bx should be updated in both + torch.testing.assert_close( + inputs["prev_Bx"].float(), inputs_ref["prev_Bx"].float(), + rtol=1e-4, atol=1e-3, + ) + + +# ===== Group 4: Chunked SSD Component Tests ===== + +class TestChunkedSSDComponents: + """Test chunked SSD building blocks: state computation, output assembly.""" + + def test_euler_state_accumulation(self): + """Chunk state accumulation matches recurrence in Euler mode.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_ssd_chunked + + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 64, 4, 16, 16 + chunk_size = 32 + + X = torch.randn(batch, seqlen, nheads, headdim, device=DEVICE) + dt = torch.rand(batch, seqlen, nheads, device=DEVICE) * 0.1 + 0.01 + A = -torch.rand(nheads, device=DEVICE) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + C = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + + Y_chunked, final_state = mamba3_ssd_chunked( + X, dt, A, B, C, block_len=chunk_size, return_final_states=True, + ) + + # Recurrence + Y_ref, h_ref = _reference_recurrence(X, dt, A, B, C, gamma=dt) + + torch.testing.assert_close(Y_chunked.float(), Y_ref.float(), atol=1e-3, rtol=1e-3) + torch.testing.assert_close(final_state.float(), h_ref.float(), atol=5e-3, rtol=5e-3) + + def test_trapezoidal_state_accumulation(self): + """Chunk state matches recurrence with trapezoidal discretization.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_ssd_chunked + + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 64, 4, 16, 16 + chunk_size = 32 + + X = torch.randn(batch, seqlen, nheads, headdim, device=DEVICE) + dt = torch.rand(batch, seqlen, nheads, device=DEVICE) * 0.1 + 0.01 + A = -torch.rand(nheads, device=DEVICE) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + C = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + + lam = torch.sigmoid(torch.randn(batch, seqlen, nheads, device=DEVICE)) + gamma = lam * dt + beta = (1 - lam) * dt * torch.exp(dt * A.view(1, 1, nheads)) + + Y_chunked = mamba3_ssd_chunked( + X, dt, A, B, C, block_len=chunk_size, gamma=gamma, beta=beta, + ) + Y_ref, _ = _reference_recurrence(X, dt, A, B, C, gamma=gamma, beta=beta) + + torch.testing.assert_close(Y_chunked.float(), Y_ref.float(), atol=1e-3, rtol=1e-3) + + def test_seq_idx_state_reset(self): + """Chunk state correctly resets at document boundaries.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_ssd_chunked + + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 1, 64, 2, 8, 8 + chunk_size = 32 + + X = torch.randn(batch, seqlen, nheads, headdim, device=DEVICE) + dt = torch.rand(batch, seqlen, nheads, device=DEVICE) * 0.1 + 0.01 + A = -torch.rand(nheads, device=DEVICE) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + C = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + + # Two documents + seq_idx = torch.zeros(batch, seqlen, device=DEVICE, dtype=torch.long) + seq_idx[:, seqlen // 2:] = 1 + + Y_packed = mamba3_ssd_chunked( + X, dt, A, B, C, block_len=chunk_size, seq_idx=seq_idx, + ) + + # Run second half separately (should match packed result) + half = seqlen // 2 + Y_doc2_sep = mamba3_ssd_chunked( + X[:, half:], dt[:, half:], A, B[:, half:], C[:, half:], + block_len=chunk_size, + ) + + torch.testing.assert_close( + Y_packed[:, half:].float(), Y_doc2_sep.float(), + atol=1e-4, rtol=1e-4, + ) + + def test_D_skip_connection(self): + """D skip connection adds x * D to output.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_ssd_chunked + + torch.manual_seed(42) + batch, seqlen, nheads, headdim, dstate = 2, 64, 4, 16, 16 + chunk_size = 32 + + X = torch.randn(batch, seqlen, nheads, headdim, device=DEVICE) + dt = torch.rand(batch, seqlen, nheads, device=DEVICE) * 0.1 + 0.01 + A = -torch.rand(nheads, device=DEVICE) * 5 - 1 + B = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + C = torch.randn(batch, seqlen, nheads, dstate, device=DEVICE) + D = torch.randn(nheads, device=DEVICE) + + Y_no_D = mamba3_ssd_chunked(X, dt, A, B, C, block_len=chunk_size) + Y_with_D = mamba3_ssd_chunked(X, dt, A, B, C, block_len=chunk_size, D=D) + + # Difference should be X * D + diff = (Y_with_D - Y_no_D).float() + expected_diff = (X.float() * D.float().view(1, 1, nheads, 1)) + + torch.testing.assert_close(diff, expected_diff, atol=1e-4, rtol=1e-4) + + +# ===== Group 5: Integration Tests ===== + +class TestModuleIntegration: + """Test Triton paths through the full Mamba3 module.""" + + def test_module_forward_finite(self): + """Mamba3 module forward produces finite output on GPU.""" + from mamba_ssm.modules.mamba3 import Mamba3 + + torch.manual_seed(42) + model = Mamba3( + d_model=128, d_state=16, expand=2, headdim=32, + ngroups=1, use_rope=True, use_trapezoidal=True, + use_bc_norm=True, use_bc_bias=True, mimo_rank=0, + chunk_size=64, layer_idx=0, device=DEVICE, dtype=torch.float32, + ).eval() + + u = torch.randn(2, 128, 128, device=DEVICE) + with torch.no_grad(): + y = model(u) + assert y.shape == u.shape + assert torch.isfinite(y).all() + + def test_module_backward_finite(self): + """Mamba3 module backward produces finite gradients.""" + from mamba_ssm.modules.mamba3 import Mamba3 + + torch.manual_seed(42) + model = Mamba3( + d_model=64, d_state=16, expand=2, headdim=16, + ngroups=1, use_rope=True, use_trapezoidal=True, + use_bc_norm=True, use_bc_bias=True, mimo_rank=0, + chunk_size=32, layer_idx=0, device=DEVICE, dtype=torch.bfloat16, + ).train() + + u = torch.randn(2, 64, 64, device=DEVICE, dtype=torch.bfloat16, requires_grad=True) + y = model(u) + y.sum().backward() + + assert u.grad is not None + assert torch.isfinite(u.grad).all() + graded = sum(1 for p in model.parameters() if p.grad is not None) + assert graded > 0 + + def test_lm_model_forward(self): + """Full LM model with Mamba-3 layers produces valid logits.""" + from mamba_ssm.models.config_mamba import MambaConfig + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + config = MambaConfig( + d_model=128, n_layer=2, vocab_size=256, + ssm_cfg={"layer": "Mamba3", "d_state": 16, "headdim": 32, + "use_rope": True, "use_trapezoidal": True}, + rms_norm=True, fused_add_norm=False, + ) + model = MambaLMHeadModel(config, device=DEVICE, dtype=torch.bfloat16) + input_ids = torch.randint(0, 256, (2, 64), device=DEVICE) + output = model(input_ids) + assert output.logits.shape == (2, 64, 256) + assert torch.isfinite(output.logits).all() + + def test_lm_model_gradient(self): + """Full LM model gradient flows through Mamba-3 layers.""" + from mamba_ssm.models.config_mamba import MambaConfig + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + + config = MambaConfig( + d_model=128, n_layer=2, vocab_size=256, + ssm_cfg={"layer": "Mamba3", "d_state": 16, "headdim": 32}, + rms_norm=True, fused_add_norm=False, + ) + model = MambaLMHeadModel(config, device=DEVICE, dtype=torch.bfloat16) + input_ids = torch.randint(0, 256, (2, 32), device=DEVICE) + output = model(input_ids) + loss = output.logits.float().sum() + loss.backward() + graded = sum(1 for p in model.parameters() if p.grad is not None) + assert graded > 0 + + +# ===== Group 6: Fallback Tests ===== + +class TestTritonFallback: + """Test graceful fallback to PyTorch when Triton is unavailable.""" + + def test_cpu_uses_pytorch_path(self): + """CPU inputs use PyTorch reference path for decode.""" + from mamba_ssm.ops.triton.mamba3_ssd import _mamba3_state_update_ref + + batch, nheads, dim, dstate = 2, 4, 16, 16 + state = torch.randn(batch, nheads, dim, dstate) + x = torch.randn(batch, nheads, dim) + dt = torch.randn(batch, nheads) + A = -torch.rand(nheads) + B = torch.randn(batch, 1, dstate) + C = torch.randn(batch, 1, dstate) + + # Reference should work on CPU without error + out = _mamba3_state_update_ref(state, x, dt, A, B, C) + assert out.shape == (batch, nheads, dim) + assert torch.isfinite(out).all() + + def test_mimo_decode_works(self): + """MIMO decode works through the Triton kernel.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_state_update + + batch, nheads, dim, dstate, mr = 2, 4, 16, 16, 2 + state = torch.randn(batch, nheads, dim, dstate, device=DEVICE) + x = torch.randn(batch, nheads, dim, mr, device=DEVICE) + dt = torch.randn(batch, nheads, device=DEVICE) + A = -torch.rand(nheads, device=DEVICE) + B = torch.randn(batch, 1, dstate, mr, device=DEVICE) + C = torch.randn(batch, 1, dstate, mr, device=DEVICE) + + out = mamba3_state_update(state, x, dt, A, B, C) + assert out.shape == (batch, nheads, dim, mr) + assert torch.isfinite(out).all() + + def test_chunk_scan_combined_pads_seqlen(self): + """mamba3_chunk_scan_combined handles non-divisible sequence lengths via padding.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + torch.manual_seed(42) + # seqlen=100 is not divisible by chunk_size=64 + inputs = make_inputs( + batch=1, seqlen=100, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=64, + dtype=torch.float32, has_trapezoidal=False, has_rope=False, + has_D=False, has_z=False, + ) + + with torch.no_grad(): + out, final_states = mamba3_chunk_scan_combined(**inputs) + + assert out.shape == (1, 100, 4, 16), f"Expected (1,100,4,16), got {out.shape}" + assert torch.isfinite(out).all() + + +# ===== Group 7: Consistency between two independent runs ===== + +class TestDeterminism: + """Verify that repeated runs produce identical results.""" + + def test_chunked_deterministic(self): + """Two runs with same seed produce identical output.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + + results = [] + for _ in range(2): + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=64, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=32, + dtype=torch.float32, has_trapezoidal=True, has_rope=True, + has_D=True, has_z=True, + ) + with torch.no_grad(): + out, states = mamba3_chunk_scan_combined(**inputs) + results.append((out.clone(), states.clone())) + + torch.testing.assert_close(results[0][0], results[1][0]) + torch.testing.assert_close(results[0][1], results[1][1]) + + def test_decode_deterministic(self): + """Triton decode kernel is deterministic.""" + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_state_update + + results = [] + for _ in range(2): + torch.manual_seed(42) + batch, nheads, dim, dstate = 2, 4, 16, 16 + state = torch.randn(batch, nheads, dim, dstate, device=DEVICE) + x = torch.randn(batch, nheads, dim, device=DEVICE) + dt = torch.randn(batch, nheads, device=DEVICE) + A = -torch.rand(nheads, device=DEVICE) + B = torch.randn(batch, 1, dstate, device=DEVICE) + C = torch.randn(batch, 1, dstate, device=DEVICE) + + out = mamba3_state_update(state, x, dt, A, B, C) + results.append(out.clone()) + + torch.testing.assert_close(results[0], results[1]) + + +# ===== Group 8: Triton Combined Forward/Backward (mamba3_chunk_scan_combined_triton) ===== + +class TestTritonCombined: + """Test mamba3_chunk_scan_combined_triton against the PyTorch reference.""" + + def _run_triton_vs_ref(self, dtype=torch.float32, has_trapezoidal=False, + has_rope=False, has_D=False, has_z=False, + has_seq_idx=False, has_initial_states=False, + has_initial_prev_Bx=False, ngroups=1, + seqlen=128, chunk_size=64, nheads=4): + """Run both Triton and reference forward, compare outputs. + + The Triton pipeline accumulates numerical differences from multiple kernel + stages (chunk_cumsum, bmm, chunk_state, state_passing, chunk_scan) which + each use reduced-precision dot products. Tolerances are therefore relaxed + compared to single-kernel tests. + """ + from mamba_ssm.ops.triton.mamba3_ssd import mamba3_chunk_scan_combined + from mamba_ssm.ops.triton.mamba3_combined import mamba3_chunk_scan_combined_triton + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=seqlen, nheads=nheads, headdim=16, + ngroups=ngroups, d_state=16, chunk_size=chunk_size, + dtype=dtype, has_trapezoidal=has_trapezoidal, has_rope=has_rope, + has_D=has_D, has_z=has_z, has_seq_idx=has_seq_idx, + has_initial_states=has_initial_states, + has_initial_prev_Bx=has_initial_prev_Bx, + ) + inputs_ref = clone_inputs(inputs) + + with torch.no_grad(): + out_triton, fs_triton = mamba3_chunk_scan_combined_triton(**inputs) + out_ref, fs_ref = mamba3_chunk_scan_combined(**inputs_ref) + + # Triton kernels use reduced-precision dot products (tf32/bf16 accumulators) + # across multiple pipeline stages, so numerical differences are expected. + # fp32: rtol=5e-3, atol=0.1 allows ~0.5% relative and 0.1 absolute error. + # bf16: rtol=1e-2, atol=0.2 allows larger differences from bf16 accumulation. + tol_rtol = 1e-2 if dtype == torch.bfloat16 else 5e-3 + tol_atol = 2e-1 if dtype == torch.bfloat16 else 1e-1 + + torch.testing.assert_close(out_triton.float(), out_ref.float(), + rtol=tol_rtol, atol=tol_atol) + if fs_triton is not None and fs_ref is not None: + torch.testing.assert_close(fs_triton.float(), fs_ref.float(), + rtol=tol_rtol, atol=tol_atol) + + def test_euler_fp32(self): + """Euler mode in fp32.""" + self._run_triton_vs_ref(dtype=torch.float32, has_trapezoidal=False) + + def test_euler_bf16(self): + """Euler mode in bf16.""" + self._run_triton_vs_ref(dtype=torch.bfloat16, has_trapezoidal=False) + + def test_trapezoidal_fp32(self): + """Trapezoidal mode in fp32.""" + self._run_triton_vs_ref(dtype=torch.float32, has_trapezoidal=True) + + def test_trapezoidal_bf16(self): + """Trapezoidal mode in bf16.""" + self._run_triton_vs_ref(dtype=torch.bfloat16, has_trapezoidal=True) + + def test_with_rope(self): + """Forward with RoPE.""" + self._run_triton_vs_ref(has_rope=True) + + def test_trapezoidal_rope(self): + """Trapezoidal + RoPE.""" + self._run_triton_vs_ref(has_trapezoidal=True, has_rope=True) + + def test_with_D(self): + """Forward with D skip.""" + self._run_triton_vs_ref(has_D=True) + + def test_with_initial_states(self): + """Forward with initial states.""" + self._run_triton_vs_ref(has_initial_states=True) + + def test_with_seq_idx(self): + """Forward with document boundaries.""" + self._run_triton_vs_ref(has_seq_idx=True) + + def test_ngroups_gt_1(self): + """Forward with ngroups > 1.""" + self._run_triton_vs_ref(ngroups=2, nheads=4) + + def test_triton_gradient(self): + """Verify gradients through the Triton combined function.""" + from mamba_ssm.ops.triton.mamba3_combined import mamba3_chunk_scan_combined_triton + + torch.manual_seed(42) + inputs = make_inputs( + batch=2, seqlen=64, nheads=4, headdim=16, + ngroups=1, d_state=16, chunk_size=32, + dtype=torch.float32, has_trapezoidal=True, has_rope=True, + has_D=True, has_z=False, + ) + + out, _ = mamba3_chunk_scan_combined_triton(**inputs) + loss = out.float().sum() + loss.backward() + + for name in ["x", "dt", "A", "B", "C", "D", "dt_bias", "gamma", "beta"]: + param = inputs[name] + if param is not None and param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + assert torch.isfinite(param.grad).all(), f"Non-finite gradient for {name}" + assert param.grad.abs().max() > 0, f"Zero gradient for {name}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_mamba3_triton_bwd.py b/tests/test_mamba3_triton_bwd.py new file mode 100644 index 000000000..ba5370c1f --- /dev/null +++ b/tests/test_mamba3_triton_bwd.py @@ -0,0 +1,837 @@ +"""GPU tests for Mamba-3 Triton backward pass correctness. + +Compares gradients from the Triton backward pipeline (mamba3_chunk_scan_combined_triton +from mamba3_combined.py) against the PyTorch reference backward (mamba3_chunk_scan_combined +from mamba3_ssd.py) using manual gradient comparison. + +Test strategy: +1. Create identical inputs for both paths (with requires_grad=True). +2. Run forward + backward on both. +3. Compare all gradients (dx, ddt, dA, dB, dC, dgamma, dbeta, dD, dz, dtheta, etc.) + with appropriate tolerances per dtype. + +All tests require an NVIDIA CUDA GPU. +""" + +import pytest +import torch +import torch.nn.functional as F + +DEVICE = "cuda" + +# Skip entire module if CUDA is not available. +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required" +) + +# Tolerances per dtype. +FP32_ATOL = 5e-3 +FP32_RTOL = 5e-2 +BF16_ATOL = 5e-3 +BF16_RTOL = 5e-2 + + +def _tols(dtype): + if dtype == torch.bfloat16: + return BF16_ATOL, BF16_RTOL + return FP32_ATOL, FP32_RTOL + + +# --------------------------------------------------------------------------- +# Input generation +# --------------------------------------------------------------------------- + +def _make_inputs( + batch=2, + seqlen=128, + nheads=4, + headdim=32, + dstate=16, + ngroups=None, + dtype=torch.float32, + with_gamma_beta=True, + with_theta=False, + with_z=False, + with_D=False, + D_2d=False, + with_dt_bias=False, + dt_softplus=False, + with_seq_idx=False, + with_initial_states=False, + return_final_states=False, + with_initial_prev_Bx=False, + seed=42, +): + """Create test inputs. Returns a dict of tensors (no requires_grad yet).""" + if ngroups is None: + ngroups = nheads + + torch.manual_seed(seed) + device = DEVICE + + x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype) + dt = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype) * 0.1 + 0.01 + A = -torch.rand(nheads, device=device, dtype=dtype) * 3 - 0.5 + B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype) * 0.1 + C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype) * 0.1 + + gamma = beta = None + if with_gamma_beta: + gamma = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype) * 0.1 + 0.01 + beta = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype) * 0.1 + 0.01 + + theta = None + if with_theta: + theta = torch.randn(batch, seqlen, nheads, dstate // 2, device=device, dtype=dtype) * 0.05 + + z = None + if with_z: + z = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype) + + D = None + if with_D: + if D_2d: + D = torch.randn(nheads, headdim, device=device, dtype=dtype) * 0.1 + else: + D = torch.randn(nheads, device=device, dtype=dtype) * 0.1 + + dt_bias = None + if with_dt_bias: + dt_bias = torch.rand(nheads, device=device, dtype=dtype) * 0.005 + + seq_idx = None + if with_seq_idx: + seq_idx = torch.zeros(batch, seqlen, dtype=torch.long, device=device) + seq_idx[:, seqlen // 2:] = 1 + + initial_states = None + if with_initial_states: + initial_states = torch.randn(batch, nheads, headdim, dstate, device=device, dtype=dtype) * 0.1 + + initial_prev_Bx = None + if with_initial_prev_Bx and with_gamma_beta: + initial_prev_Bx = torch.randn(batch, nheads, headdim, dstate, device=device, dtype=dtype) * 0.1 + + return dict( + x=x, dt=dt, A=A, B=B, C=C, + gamma=gamma, beta=beta, theta=theta, + z=z, D=D, dt_bias=dt_bias, + seq_idx=seq_idx, + initial_states=initial_states, + initial_prev_Bx=initial_prev_Bx, + ngroups=ngroups, + dt_softplus=dt_softplus, + return_final_states=return_final_states, + ) + + +# Names of all tensor inputs that are potentially differentiable. +_DIFF_NAMES = [ + "x", "dt", "A", "B", "C", + "gamma", "beta", "theta", + "z", "D", "dt_bias", + "initial_states", "initial_prev_Bx", +] + + +def _clone_with_grad(inputs): + """Clone all tensors; set requires_grad on differentiable ones.""" + cloned = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor) and k in _DIFF_NAMES: + c = v.detach().clone().requires_grad_(True) + cloned[k] = c + elif isinstance(v, torch.Tensor): + cloned[k] = v.detach().clone() # seq_idx etc. + else: + cloned[k] = v + return cloned + + +# --------------------------------------------------------------------------- +# Core comparison helper +# --------------------------------------------------------------------------- + +def _compare_gradients(chunk_size=64, atol_override=None, rtol_override=None, **input_kwargs): + """Run Triton and PyTorch reference backward, compare all gradients. + + ``input_kwargs`` are forwarded to ``_make_inputs``. + """ + from mamba_ssm.ops.triton.mamba3_combined import mamba3_chunk_scan_combined_triton + from mamba_ssm.ops.triton.mamba3_ssd import ( + mamba3_chunk_scan_combined as mamba3_ref, + ) + + raw = _make_inputs(**input_kwargs) + dtype = input_kwargs.get("dtype", torch.float32) + atol, rtol = _tols(dtype) + if atol_override is not None: + atol = atol_override + if rtol_override is not None: + rtol = rtol_override + return_final_states = raw.get("return_final_states", False) + + # --- Triton path --- + tri = _clone_with_grad(raw) + tri_result = mamba3_chunk_scan_combined_triton( + tri["x"], tri["dt"], tri["A"], tri["B"], tri["C"], chunk_size, + gamma=tri["gamma"], beta=tri["beta"], theta=tri["theta"], + D=tri["D"], z=tri["z"], dt_bias=tri["dt_bias"], + dt_softplus=tri.get("dt_softplus", False), + initial_states=tri["initial_states"], + initial_prev_Bx=tri["initial_prev_Bx"], + return_final_states=return_final_states, + ngroups=tri["ngroups"], + seq_idx=tri.get("seq_idx"), + ) + if return_final_states: + tri_out, tri_final = tri_result + tri_loss = tri_out.float().sum() + tri_final.float().sum() + else: + tri_out = tri_result + tri_loss = tri_out.float().sum() + tri_loss.backward() + + # --- PyTorch reference path --- + ref = _clone_with_grad(raw) + ref_result = mamba3_ref( + ref["x"], ref["dt"], ref["A"], ref["B"], ref["C"], chunk_size, + gamma=ref["gamma"], beta=ref["beta"], theta=ref["theta"], + D=ref["D"], z=ref["z"], dt_bias=ref["dt_bias"], + dt_softplus=ref.get("dt_softplus", False), + initial_states=ref["initial_states"], + initial_prev_Bx=ref["initial_prev_Bx"], + return_final_states=return_final_states, + ngroups=ref["ngroups"], + seq_idx=ref.get("seq_idx"), + ) + if return_final_states: + ref_out, ref_final = ref_result + ref_loss = ref_out.float().sum() + ref_final.float().sum() + else: + ref_out = ref_result + ref_loss = ref_out.float().sum() + ref_loss.backward() + + # --- Forward comparison --- + torch.testing.assert_close( + tri_out.float(), ref_out.float(), atol=atol, rtol=rtol, + msg="Forward output mismatch", + ) + if return_final_states: + torch.testing.assert_close( + tri_final.float(), ref_final.float(), atol=atol, rtol=rtol, + msg="Final state mismatch", + ) + + # --- Gradient comparison --- + for name in _DIFF_NAMES: + t_tensor = tri.get(name) + r_tensor = ref.get(name) + if t_tensor is None or r_tensor is None: + continue + t_grad = t_tensor.grad + r_grad = r_tensor.grad + if t_grad is None and r_grad is None: + continue + assert t_grad is not None, f"Triton grad for '{name}' is None but ref grad exists" + assert r_grad is not None, f"Ref grad for '{name}' is None but Triton grad exists" + assert torch.isfinite(t_grad).all(), f"Non-finite Triton grad for '{name}'" + assert torch.isfinite(r_grad).all(), f"Non-finite ref grad for '{name}'" + torch.testing.assert_close( + t_grad.float(), r_grad.float(), atol=atol, rtol=rtol, + msg=f"Gradient mismatch for '{name}'", + ) + + +# ============================================================================ +# 1. Basic SISO (ngroups=nheads, rank=1) +# ============================================================================ + +class TestBasicSISO: + """Basic SISO backward: trapezoidal (Mamba-3) and Euler (Mamba-2 fallback).""" + + def test_trapezoidal_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + + def test_trapezoidal_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=True, + ) + + def test_euler_fp32(self): + """Mamba-2 fallback: no gamma/beta.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, + ) + + def test_euler_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=False, + ) + + +# ============================================================================ +# 2. With RoPE (theta parameter) +# ============================================================================ + +class TestWithRoPE: + """Backward with RoPE rotary embeddings on B and C.""" + + def test_rope_trapezoidal_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_theta=True, + ) + + def test_rope_trapezoidal_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=True, with_theta=True, + ) + + def test_rope_euler_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, with_theta=True, + ) + + def test_rope_ngroups_lt_nheads_fp32(self): + """RoPE with ngroups < nheads (group-to-head expansion).""" + _compare_gradients( + batch=2, seqlen=128, nheads=8, headdim=32, dstate=16, + ngroups=2, dtype=torch.float32, + with_gamma_beta=True, with_theta=True, + ) + + +# ============================================================================ +# 3. With z gating (SiLU) +# ============================================================================ + +class TestWithZGating: + + def test_z_trapezoidal_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_z=True, + ) + + def test_z_trapezoidal_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=True, with_z=True, + ) + + def test_z_euler_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, with_z=True, + ) + + +# ============================================================================ +# 4. With D skip connection +# ============================================================================ + +class TestWithDSkip: + + def test_D_1d_trapezoidal_fp32(self): + """D as (nheads,) vector.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_D=True, D_2d=False, + ) + + def test_D_2d_trapezoidal_fp32(self): + """D as (nheads, headdim) matrix.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_D=True, D_2d=True, + ) + + def test_D_1d_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=True, with_D=True, D_2d=False, + ) + + def test_D_with_z_fp32(self): + """D + z gating combined.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_D=True, with_z=True, + ) + + +# ============================================================================ +# 5. With seq_idx (packed multi-document sequences) +# ============================================================================ + +class TestWithSeqIdx: + + def test_seq_idx_trapezoidal_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_seq_idx=True, + ) + + def test_seq_idx_trapezoidal_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=True, with_seq_idx=True, + ) + + def test_seq_idx_euler_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, with_seq_idx=True, + ) + + def test_seq_idx_with_rope_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, + with_gamma_beta=True, with_theta=True, with_seq_idx=True, + ) + + +# ============================================================================ +# 6. With initial_states / return_final_states +# ============================================================================ + +class TestWithInitialStates: + + def test_initial_states_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_initial_states=True, return_final_states=True, + ) + + def test_initial_states_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=True, + with_initial_states=True, return_final_states=True, + ) + + def test_return_final_states_only_fp32(self): + """return_final_states without providing initial_states.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_initial_states=False, return_final_states=True, + ) + + def test_initial_prev_Bx_fp32(self): + """initial_prev_Bx for trapezoidal lookback at t=0.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_initial_prev_Bx=True, + ) + + def test_initial_states_and_prev_Bx_fp32(self): + """Both initial_states and initial_prev_Bx.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_initial_states=True, with_initial_prev_Bx=True, + return_final_states=True, + ) + + +# ============================================================================ +# 7. Trapezoidal mode (gamma+beta) -- core Mamba-3 +# ============================================================================ + +class TestTrapezoidal: + """Dedicated trapezoidal-mode tests at various configurations.""" + + def test_trapezoidal_small_chunk_fp32(self): + _compare_gradients( + chunk_size=32, + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + + def test_trapezoidal_large_dstate_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=64, + dtype=torch.float32, with_gamma_beta=True, + ) + + def test_trapezoidal_with_dt_bias_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_dt_bias=True, + ) + + def test_trapezoidal_with_dt_softplus_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_dt_bias=True, dt_softplus=True, + ) + + +# ============================================================================ +# 8. Mamba-2 fallback (no gamma/beta) +# ============================================================================ + +class TestMamba2Fallback: + """Euler discretization (Mamba-2 compatible) backward tests.""" + + def test_euler_with_D_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, with_D=True, + ) + + def test_euler_with_z_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, with_z=True, + ) + + def test_euler_with_initial_states_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, + with_initial_states=True, return_final_states=True, + ) + + def test_euler_with_rope_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=False, with_theta=True, + ) + + +# ============================================================================ +# 9. Different shapes +# ============================================================================ + +class TestDifferentShapes: + + def test_seqlen_256_nheads_4_fp32(self): + _compare_gradients( + batch=2, seqlen=256, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + + def test_nheads_8_headdim_64_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=8, headdim=64, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + + def test_dstate_64_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=64, + dtype=torch.float32, with_gamma_beta=True, + ) + + def test_ngroups_1_nheads_8_fp32(self): + """Large head-to-group ratio.""" + _compare_gradients( + batch=2, seqlen=128, nheads=8, headdim=32, dstate=16, + ngroups=1, dtype=torch.float32, with_gamma_beta=True, + ) + + def test_ngroups_2_nheads_8_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=8, headdim=32, dstate=16, + ngroups=2, dtype=torch.float32, with_gamma_beta=True, + ) + + def test_small_shapes_fp32(self): + """Minimal configuration.""" + _compare_gradients( + chunk_size=32, + batch=1, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + + def test_large_shapes_bf16(self): + """Larger configuration in bf16 — wider tolerance for bf16 accumulation noise.""" + _compare_gradients( + batch=2, seqlen=256, nheads=8, headdim=64, dstate=64, + dtype=torch.bfloat16, with_gamma_beta=True, + atol_override=0.05, rtol_override=0.1, + ) + + def test_headdim_64_dstate_64_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=64, dstate=64, + dtype=torch.float32, with_gamma_beta=True, + ) + + +# ============================================================================ +# 10. dt processing (bias + softplus) +# ============================================================================ + +class TestDtProcessing: + + def test_dt_bias_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_dt_bias=True, + ) + + def test_dt_softplus_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + dt_softplus=True, + ) + + def test_dt_bias_and_softplus_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_dt_bias=True, dt_softplus=True, + ) + + def test_dt_bias_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, with_gamma_beta=True, + with_dt_bias=True, + ) + + +# ============================================================================ +# 11. Combined features +# ============================================================================ + +class TestCombinedFeatures: + """Tests combining multiple optional features at once.""" + + def test_all_features_fp32(self): + """gamma/beta + RoPE + z + D + dt_bias + softplus + initial_states + final_states.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, + with_gamma_beta=True, with_theta=True, + with_z=True, with_D=True, with_dt_bias=True, + dt_softplus=True, + with_initial_states=True, return_final_states=True, + ) + + def test_all_features_bf16(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.bfloat16, + with_gamma_beta=True, with_theta=True, + with_z=True, with_D=True, with_dt_bias=True, + dt_softplus=True, + with_initial_states=True, return_final_states=True, + ) + + def test_trapezoidal_rope_seq_idx_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, + with_gamma_beta=True, with_theta=True, with_seq_idx=True, + ) + + def test_trapezoidal_D_z_initial_prev_Bx_fp32(self): + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, + with_gamma_beta=True, with_D=True, with_z=True, + with_initial_states=True, with_initial_prev_Bx=True, + return_final_states=True, + ) + + def test_euler_all_extras_fp32(self): + """Euler mode with all other optional features.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, + with_gamma_beta=False, with_theta=True, + with_z=True, with_D=True, with_dt_bias=True, + dt_softplus=True, + with_initial_states=True, return_final_states=True, + ) + + def test_ngroups_with_all_features_fp32(self): + """ngroups < nheads with all features.""" + _compare_gradients( + batch=2, seqlen=128, nheads=8, headdim=32, dstate=16, + ngroups=2, dtype=torch.float32, + with_gamma_beta=True, with_theta=True, + with_z=True, with_D=True, + with_initial_states=True, return_final_states=True, + ) + + def test_seq_idx_initial_states_rope_trapezoidal_bf16(self): + """Everything together in bf16 — wider tolerance for bf16 accumulation noise.""" + _compare_gradients( + batch=2, seqlen=128, nheads=4, headdim=32, dstate=16, + ngroups=2, dtype=torch.bfloat16, + with_gamma_beta=True, with_theta=True, + with_seq_idx=True, with_D=True, + with_initial_states=True, return_final_states=True, + atol_override=0.02, rtol_override=0.1, + ) + + +# ============================================================================ +# 12. Individual gradient sanity checks +# ============================================================================ + +class TestIndividualGradients: + """Verify specific gradient components in isolation (useful for debugging).""" + + def _get_grad(self, name, **kwargs): + """Run both paths and return (triton_grad, ref_grad) for a named parameter.""" + from mamba_ssm.ops.triton.mamba3_combined import mamba3_chunk_scan_combined_triton + from mamba_ssm.ops.triton.mamba3_ssd import ( + mamba3_chunk_scan_combined as mamba3_ref, + ) + + raw = _make_inputs(**kwargs) + chunk_size = 64 + + results = {} + for label, fn in [("tri", mamba3_chunk_scan_combined_triton), ("ref", mamba3_ref)]: + inp = _clone_with_grad(raw) + out = fn( + inp["x"], inp["dt"], inp["A"], inp["B"], inp["C"], chunk_size, + gamma=inp["gamma"], beta=inp["beta"], theta=inp["theta"], + D=inp["D"], z=inp["z"], dt_bias=inp["dt_bias"], + dt_softplus=inp.get("dt_softplus", False), + initial_states=inp["initial_states"], + initial_prev_Bx=inp["initial_prev_Bx"], + return_final_states=inp.get("return_final_states", False), + ngroups=inp["ngroups"], + seq_idx=inp.get("seq_idx"), + ) + if isinstance(out, tuple): + loss = out[0].float().sum() + out[1].float().sum() + else: + loss = out.float().sum() + loss.backward() + results[label] = inp[name].grad if inp[name] is not None else None + + return results["tri"], results["ref"] + + def test_dx_trapezoidal(self): + tri_g, ref_g = self._get_grad( + "x", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_ddt_euler(self): + tri_g, ref_g = self._get_grad( + "dt", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=False, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dA(self): + tri_g, ref_g = self._get_grad( + "A", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dB_trapezoidal(self): + tri_g, ref_g = self._get_grad( + "B", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dC_trapezoidal(self): + tri_g, ref_g = self._get_grad( + "C", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dgamma(self): + tri_g, ref_g = self._get_grad( + "gamma", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dbeta(self): + tri_g, ref_g = self._get_grad( + "beta", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dD(self): + tri_g, ref_g = self._get_grad( + "D", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_D=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dtheta(self): + tri_g, ref_g = self._get_grad( + "theta", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_theta=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_dz(self): + tri_g, ref_g = self._get_grad( + "z", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_z=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_d_initial_states(self): + tri_g, ref_g = self._get_grad( + "initial_states", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, + with_initial_states=True, return_final_states=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_d_initial_prev_Bx(self): + tri_g, ref_g = self._get_grad( + "initial_prev_Bx", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_initial_prev_Bx=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + def test_d_dt_bias(self): + tri_g, ref_g = self._get_grad( + "dt_bias", batch=2, seqlen=64, nheads=4, headdim=32, dstate=16, + dtype=torch.float32, with_gamma_beta=True, with_dt_bias=True, + ) + assert tri_g is not None and ref_g is not None + torch.testing.assert_close(tri_g.float(), ref_g.float(), atol=FP32_ATOL, rtol=FP32_RTOL) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])