# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

| 7 | +""" |
| 8 | +Fully-fused SRAM-resident GatedDeltaNet recurrent kernel for decode (T=1). |
| 9 | +
|
| 10 | +Fuses post-projection (Q/K/V split from conv1d output, L2 normalization, |
| 11 | +head repeat, gating computation) AND the recurrent state update into a |
| 12 | +single Triton kernel per layer. |
| 13 | +
|
| 14 | +This eliminates intermediate HBM reads/writes for q, k, v, g, beta tensors |
| 15 | +and removes multiple small kernel launches (normalize, repeat_interleave, |
| 16 | +sigmoid, softplus, exp) that the previous partial-fusion approach required. |
| 17 | +
|
| 18 | +For each (batch, v_head): |
| 19 | + k_head = v_head // V_PER_K # shared K head |
| 20 | + q, k = L2_normalize(qkv_conv[Q/K]) # split + normalize |
| 21 | + v = qkv_conv[V] # split |
| 22 | + decay = exp(-exp(A_log) * softplus(alpha + dt_bias)) |
| 23 | + beta = sigmoid(beta_raw) |
| 24 | + state = state * decay # decay |
| 25 | + Sk = state @ k # [V] |
| 26 | + delta = beta * (v - Sk) # [V] |
| 27 | + state = state + outer(k, delta) # rank-1 update |
| 28 | + output = state @ (q * scale) # [V] |
| 29 | +
|
| 30 | +The kernel tiles over the V dimension in blocks of BLOCK_V. |
| 31 | +For each V-tile, it streams through K in blocks of BLOCK_K. |
| 32 | +
|
| 33 | +Registered as torch.ops.triton.fused_deltanet_decode for AOTI compilation. |
| 34 | +""" |

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


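# For clarity, a minimal pure-PyTorch sketch of the single-step recurrence the
# kernel below fuses. Illustrative only: `_reference_deltanet_step` is a
# hypothetical helper for one (batch, v_head) pair, not part of the op's API.
def _reference_deltanet_step(q, k, v, decay, beta, state, scale):
    """q, k: [K] (L2-normalized); v: [V]; decay, beta: scalars; state: [K, V]."""
    state = state * decay                  # gated decay
    sk = state.T @ k                       # [V] read-out along K
    delta = beta * (v - sk)                # [V] correction, scaled by write strength
    state = state + torch.outer(k, delta)  # rank-1 delta-rule update
    output = state.T @ (q * scale)         # [V]
    return output, state

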
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": 32, "BLOCK_V": 32}),
        triton.Config({"BLOCK_K": 64, "BLOCK_V": 64}),
        triton.Config({"BLOCK_K": 128, "BLOCK_V": 128}),
        triton.Config({"BLOCK_K": 128, "BLOCK_V": 64}),
        triton.Config({"BLOCK_K": 64, "BLOCK_V": 128}),
    ],
    key=["K", "V_DIM"],
)
@triton.jit
def _fused_deltanet_decode_kernel(
    # Tensor pointers
    QKV_ptr,  # [B, conv_dim] post-conv1d+silu output
    Alpha_ptr,  # [B, H] raw gating input (a)
    BetaRaw_ptr,  # [B, H] raw write strength (b, pre-sigmoid)
    NegAExp_ptr,  # [H] -exp(A_log), precomputed
    DtBias_ptr,  # [H] dt_bias parameter
    S_in_ptr,  # [B, H, K, V] recurrent state input (read-only)
    S_out_ptr,  # [B, H, K, V] recurrent state output (write-only)
    O_ptr,  # [B, H, V] output
    # Dimension constants
    K: tl.constexpr,  # head_k_dim (128)
    V_DIM: tl.constexpr,  # head_v_dim (128)
    KEY_DIM: tl.constexpr,  # num_k_heads * K (2048)
    V_PER_K: tl.constexpr,  # num_v_heads // num_k_heads (2)
    SCALE: tl.constexpr,  # K^(-0.5)
    L2_EPS: tl.constexpr,  # 1e-6
    # Strides
    stride_qkv_b,  # qkv stride for batch dim
    stride_ab,  # alpha stride for batch dim
    stride_bb,  # beta_raw stride for batch dim
    stride_s_b,  # state stride: batch
    stride_s_h,  # state stride: head
    stride_s_k,  # state stride: K dim
    stride_s_v,  # state stride: V dim
    stride_ob,  # output stride: batch
    stride_oh,  # output stride: head
    stride_ov,  # output stride: V dim
    # Block sizes (autotuned)
    BLOCK_K: tl.constexpr,
    BLOCK_V: tl.constexpr,
):
    """One program per (batch, v_head, v_block)."""
    pid_bh = tl.program_id(0)  # batch * num_v_heads index
    pid_v = tl.program_id(1)  # V-tile index

    # Decompose pid_bh into batch and v_head
    H: tl.constexpr = KEY_DIM // K * V_PER_K  # num_v_heads
    bid = pid_bh // H
    h = pid_bh % H
    k_head = h // V_PER_K  # corresponding K head

    # V-tile range
    v_start = pid_v * BLOCK_V
    v_offs = v_start + tl.arange(0, BLOCK_V)
    v_mask = v_offs < V_DIM

    # ====== Phase 1: Load V slice from qkv_conv ======
    # Layout: qkv_conv = [Q(KEY_DIM) | K(KEY_DIM) | V(H * V_DIM)]
    qkv_base = QKV_ptr + bid * stride_qkv_b
    v_base = qkv_base + 2 * KEY_DIM + h * V_DIM
    v_vals = tl.load(v_base + v_offs, mask=v_mask, other=0.0).to(tl.float32)

    # ====== Phase 2: Compute gating and beta ======
    alpha_h = tl.load(Alpha_ptr + bid * stride_ab + h).to(tl.float32)
    neg_a_exp_h = tl.load(NegAExp_ptr + h).to(tl.float32)
    dt_bias_h = tl.load(DtBias_ptr + h).to(tl.float32)

    # softplus with numerical stability
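    # Note: tl.where evaluates both operands, so for sp_input > 20 the
    # log(1 + exp(x)) branch may overflow to +inf, but that result is
    # discarded; past the threshold softplus(x) equals x to within fp32
    # precision, so returning sp_input directly is exact.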
    sp_input = alpha_h + dt_bias_h
    sp = tl.where(sp_input > 20.0, sp_input, tl.log(1.0 + tl.exp(sp_input)))
    gate = neg_a_exp_h * sp  # always negative
    decay = tl.exp(gate)

    beta_raw_h = tl.load(BetaRaw_ptr + bid * stride_bb + h).to(tl.float32)
    beta = tl.sigmoid(beta_raw_h)

    # ====== Phase 3: Compute K and Q L2 norms (full-vector reduction) ======
    # Each v_block program needs the full K-vector norms, so we compute them here.
    # This is redundant across v_blocks for the same (batch, head) but avoids
    # a separate kernel launch or shared memory coordination.
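    # (With the shapes in the signature comments, e.g. V_DIM = 128 and
    # BLOCK_V = 64, only two programs per (batch, head) duplicate this
    # reduction, so the extra work is modest.)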
    q_base = qkv_base + k_head * K
    k_base = qkv_base + KEY_DIM + k_head * K

    q_sq_sum = tl.zeros([], dtype=tl.float32)
    k_sq_sum = tl.zeros([], dtype=tl.float32)
    for kk in range(0, K, BLOCK_K):
        kk_offs = kk + tl.arange(0, BLOCK_K)
        kk_mask = kk_offs < K
        q_chunk = tl.load(q_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32)
        k_chunk = tl.load(k_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32)
        q_sq_sum += tl.sum(q_chunk * q_chunk)
        k_sq_sum += tl.sum(k_chunk * k_chunk)

    q_norm = tl.maximum(tl.sqrt(q_sq_sum), L2_EPS)
    k_norm = tl.maximum(tl.sqrt(k_sq_sum), L2_EPS)

    # ====== Phase 4: Recurrent state update ======
    s_in_base = S_in_ptr + bid * stride_s_b + h * stride_s_h
    s_out_base = S_out_ptr + bid * stride_s_b + h * stride_s_h

    # --- Pass 1: Decay state, compute Sk = (decay*S)^T @ k_normalized ---
    sk_acc = tl.zeros([BLOCK_V], dtype=tl.float32)
    for kk in range(0, K, BLOCK_K):
        kk_offs = kk + tl.arange(0, BLOCK_K)
        kk_mask = kk_offs < K

        # Load normalized k slice
        k_vals = (
            tl.load(k_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32) / k_norm
        )

        # Load state tile [BLOCK_K, BLOCK_V]
        tile_offs = kk_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
        tile_mask = kk_mask[:, None] & v_mask[None, :]
        s_tile = tl.load(s_in_base + tile_offs, mask=tile_mask, other=0.0).to(
            tl.float32
        )

        # Decay
        s_tile = s_tile * decay

        # Sk[v] += sum_k(state[k,v] * k_normalized[k])
        sk_acc += tl.sum(s_tile * k_vals[:, None], axis=0)

    # delta = beta * (v - Sk)
    delta_v = beta * (v_vals - sk_acc)

    # --- Pass 2: Re-read input, decay + rank-1 update, write output state, compute output ---
    out_acc = tl.zeros([BLOCK_V], dtype=tl.float32)
    for kk in range(0, K, BLOCK_K):
        kk_offs = kk + tl.arange(0, BLOCK_K)
        kk_mask = kk_offs < K

        # Load normalized k and q slices
        k_vals = (
            tl.load(k_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32) / k_norm
        )
        q_vals = (
            tl.load(q_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32)
            / q_norm
            * SCALE
        )

        # Re-read input state and decay
        tile_offs = kk_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
        tile_mask = kk_mask[:, None] & v_mask[None, :]
        s_tile = tl.load(s_in_base + tile_offs, mask=tile_mask, other=0.0).to(
            tl.float32
        )
        s_tile = s_tile * decay

        # Rank-1 update: S += k ⊗ delta
        s_tile = s_tile + k_vals[:, None] * delta_v[None, :]

        # Store updated state
        tl.store(
            s_out_base + tile_offs,
            s_tile.to(S_out_ptr.dtype.element_ty),
            mask=tile_mask,
        )

        # Output: out[v] += sum_k(S_new[k,v] * q_scaled[k])
        out_acc += tl.sum(s_tile * q_vals[:, None], axis=0)

    # Store output
    o_offs = O_ptr + bid * stride_ob + h * stride_oh + v_offs * stride_ov
    tl.store(o_offs, out_acc.to(O_ptr.dtype.element_ty), mask=v_mask)


@triton_op("triton::fused_deltanet_decode", mutates_args={})
def fused_deltanet_decode(
    qkv: torch.Tensor,
    alpha: torch.Tensor,
    beta_raw: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fully-fused GatedDeltaNet decode (T=1) recurrent step.

    Fuses Q/K/V split, L2 normalization, head repeat, gating, and delta rule
    recurrence into a single kernel.

    Args:
        qkv: [B, conv_dim] post-conv1d+silu output (Q|K|V concatenated)
        alpha: [B, num_v_heads] raw gating input (pre-softplus)
        beta_raw: [B, num_v_heads] raw write strength (pre-sigmoid)
        A_log: [num_v_heads] log(A) parameter (negated exp computed inside)
        dt_bias: [num_v_heads] gating bias parameter
        state: [B, num_v_heads, K, V] recurrent state (read-only, not mutated)

    Returns:
        tuple of (output, new_state):
            output: [B, num_v_heads, V] decode output (same dtype as state)
            new_state: [B, num_v_heads, K, V] updated state (same dtype as state)
    """
    B = qkv.shape[0]
    H, K, V_DIM = state.shape[1], state.shape[2], state.shape[3]

    # Derive layout constants from tensor shapes:
    # conv_dim = 2 * KEY_DIM + H * V_DIM, where KEY_DIM = num_k_heads * K
    value_dim = H * V_DIM
    KEY_DIM = (qkv.shape[1] - value_dim) // 2
    num_k_heads = KEY_DIM // K
    V_PER_K = H // num_k_heads
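    # Worked example using the shapes noted in the kernel signature comments
    # (illustrative only): K = 128 and KEY_DIM = 2048 give num_k_heads = 16;
    # with H = 32 v-heads, V_PER_K = 32 // 16 = 2 and
    # conv_dim = 2 * 2048 + 32 * 128 = 8192.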

    output = torch.empty(B, H, V_DIM, dtype=state.dtype, device=qkv.device)

    # Compute neg_A_exp from A_log parameter
    neg_A_exp = -torch.exp(A_log.float())

    # Separate input/output state buffers for autotuning safety
    # (autotuner may re-run the kernel; reading from a buffer we also write
    # would produce wrong results on the second run)
    state_in = state.float().contiguous()
    state_out = torch.empty_like(state_in)

    def grid(meta):
        return (B * H, triton.cdiv(V_DIM, meta["BLOCK_V"]))

    wrap_triton(_fused_deltanet_decode_kernel)[grid](
        qkv,
        alpha,
        beta_raw,
        neg_A_exp,
        dt_bias,
        state_in,
        state_out,
        output,
        # Dimensions
        K=K,
        V_DIM=V_DIM,
        KEY_DIM=KEY_DIM,
        V_PER_K=V_PER_K,
        SCALE=K**-0.5,
        L2_EPS=1e-6,
        # Strides
        stride_qkv_b=qkv.stride(0),
        stride_ab=alpha.stride(0),
        stride_bb=beta_raw.stride(0),
        stride_s_b=state_in.stride(0),
        stride_s_h=state_in.stride(1),
        stride_s_k=state_in.stride(2),
        stride_s_v=state_in.stride(3),
        stride_ob=output.stride(0),
        stride_oh=output.stride(1),
        stride_ov=output.stride(2),
    )

    return output, state_out.to(state.dtype)


@fused_deltanet_decode.register_fake
def _fused_deltanet_decode_fake(
    qkv: torch.Tensor,
    alpha: torch.Tensor,
    beta_raw: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    state: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    B = qkv.shape[0]
    H, K_DIM, V_DIM = state.shape[1], state.shape[2], state.shape[3]
    output = torch.empty(B, H, V_DIM, dtype=state.dtype, device=qkv.device)
    new_state = torch.empty(B, H, K_DIM, V_DIM, dtype=state.dtype, device=qkv.device)
    return output, new_state
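

# Minimal smoke-test sketch (shapes are illustrative assumptions matching the
# signature comments above: 16 K heads and 32 V heads of head dim 128, B = 2).
# Requires a CUDA device; not part of the op registration itself.
if __name__ == "__main__":
    B, H, K_DIM, V_DIM = 2, 32, 128, 128
    KEY_DIM = (H // 2) * K_DIM  # num_k_heads = 16 -> V_PER_K = 2
    conv_dim = 2 * KEY_DIM + H * V_DIM  # 8192
    dev = "cuda"

    qkv = torch.randn(B, conv_dim, device=dev, dtype=torch.bfloat16)
    alpha = torch.randn(B, H, device=dev, dtype=torch.bfloat16)
    beta_raw = torch.randn(B, H, device=dev, dtype=torch.bfloat16)
    A_log = torch.randn(H, device=dev, dtype=torch.float32)
    dt_bias = torch.randn(H, device=dev, dtype=torch.float32)
    state = torch.zeros(B, H, K_DIM, V_DIM, device=dev, dtype=torch.bfloat16)

    out, new_state = torch.ops.triton.fused_deltanet_decode(
        qkv, alpha, beta_raw, A_log, dt_bias, state
    )
    print(out.shape, new_state.shape)  # (B, H, V_DIM) and (B, H, K_DIM, V_DIM)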