Skip to content

Commit c19d43e

Browse files
committed
Add fused GatedDeltaNet decode Triton kernel
Fuse Q/K/V split, L2 normalization, head repeat, gating computation, and delta-rule recurrent state update into a single Triton kernel for decode (T=1). Replaces ~6 small AOTI-generated kernels with one, reducing GatedDeltaNet kernel time by ~62% and improving end-to-end decode throughput by ~2% (106 -> 108.5 tok/s on A100).
1 parent 9042f36 commit c19d43e

5 files changed

Lines changed: 643 additions & 46 deletions

File tree

backends/cuda/triton/kernels/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,12 @@
2929
__all__.append("tq4_sdpa")
3030
except ImportError:
3131
pass
32+
33+
try:
34+
from executorch.backends.cuda.triton.kernels.fused_deltanet_decode import ( # noqa: F401
35+
fused_deltanet_decode,
36+
)
37+
38+
__all__.append("fused_deltanet_decode")
39+
except ImportError:
40+
pass
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Fully-fused SRAM-resident GatedDeltaNet recurrent kernel for decode (T=1).
9+
10+
Fuses post-projection (Q/K/V split from conv1d output, L2 normalization,
11+
head repeat, gating computation) AND the recurrent state update into a
12+
single Triton kernel per layer.
13+
14+
This eliminates intermediate HBM reads/writes for q, k, v, g, beta tensors
15+
and removes multiple small kernel launches (normalize, repeat_interleave,
16+
sigmoid, softplus, exp) that the previous partial-fusion approach required.
17+
18+
For each (batch, v_head):
19+
k_head = v_head // V_PER_K # shared K head
20+
q, k = L2_normalize(qkv_conv[Q/K]) # split + normalize
21+
v = qkv_conv[V] # split
22+
decay = exp(-exp(A_log) * softplus(alpha + dt_bias))
23+
beta = sigmoid(beta_raw)
24+
state = state * decay # decay
25+
Sk = state @ k # [V]
26+
delta = beta * (v - Sk) # [V]
27+
state = state + outer(k, delta) # rank-1 update
28+
output = state @ (q * scale) # [V]
29+
30+
The kernel tiles over the V dimension in blocks of BLOCK_V.
31+
For each V-tile, it streams through K in blocks of BLOCK_K.
32+
33+
Registered as torch.ops.triton.fused_deltanet_decode for AOTI compilation.
34+
"""
35+
36+
import torch
37+
import triton
38+
import triton.language as tl
39+
from torch.library import triton_op, wrap_triton
40+
41+
42+
@triton.autotune(
43+
configs=[
44+
triton.Config({"BLOCK_K": 32, "BLOCK_V": 32}),
45+
triton.Config({"BLOCK_K": 64, "BLOCK_V": 64}),
46+
triton.Config({"BLOCK_K": 128, "BLOCK_V": 128}),
47+
triton.Config({"BLOCK_K": 128, "BLOCK_V": 64}),
48+
triton.Config({"BLOCK_K": 64, "BLOCK_V": 128}),
49+
],
50+
key=["K", "V_DIM"],
51+
)
52+
@triton.jit
53+
def _fused_deltanet_decode_kernel(
54+
# Tensor pointers
55+
QKV_ptr, # [B, conv_dim] post-conv1d+silu output
56+
Alpha_ptr, # [B, H] raw gating input (a)
57+
BetaRaw_ptr, # [B, H] raw write strength (b, pre-sigmoid)
58+
NegAExp_ptr, # [H] -exp(A_log), precomputed
59+
DtBias_ptr, # [H] dt_bias parameter
60+
S_in_ptr, # [B, H, K, V] recurrent state input (read-only)
61+
S_out_ptr, # [B, H, K, V] recurrent state output (write-only)
62+
O_ptr, # [B, H, V] output
63+
# Dimension constants
64+
K: tl.constexpr, # head_k_dim (128)
65+
V_DIM: tl.constexpr, # head_v_dim (128)
66+
KEY_DIM: tl.constexpr, # num_k_heads * K (2048)
67+
V_PER_K: tl.constexpr, # num_v_heads // num_k_heads (2)
68+
SCALE: tl.constexpr, # K^(-0.5)
69+
L2_EPS: tl.constexpr, # 1e-6
70+
# Strides
71+
stride_qkv_b, # qkv stride for batch dim
72+
stride_ab, # alpha stride for batch dim
73+
stride_bb, # beta_raw stride for batch dim
74+
stride_s_b, # state stride: batch
75+
stride_s_h, # state stride: head
76+
stride_s_k, # state stride: K dim
77+
stride_s_v, # state stride: V dim
78+
stride_ob, # output stride: batch
79+
stride_oh, # output stride: head
80+
stride_ov, # output stride: V dim
81+
# Block sizes (autotuned)
82+
BLOCK_K: tl.constexpr,
83+
BLOCK_V: tl.constexpr,
84+
):
85+
"""One program per (batch, v_head, v_block)."""
86+
pid_bh = tl.program_id(0) # batch * num_v_heads index
87+
pid_v = tl.program_id(1) # V-tile index
88+
89+
# Decompose pid_bh into batch and v_head
90+
H: tl.constexpr = KEY_DIM // K * V_PER_K # num_v_heads
91+
bid = pid_bh // H
92+
h = pid_bh % H
93+
k_head = h // V_PER_K # corresponding K head
94+
95+
# V-tile range
96+
v_start = pid_v * BLOCK_V
97+
v_offs = v_start + tl.arange(0, BLOCK_V)
98+
v_mask = v_offs < V_DIM
99+
100+
# ====== Phase 1: Load V slice from qkv_conv ======
101+
# Layout: qkv_conv = [Q(KEY_DIM) | K(KEY_DIM) | V(H * V_DIM)]
102+
qkv_base = QKV_ptr + bid * stride_qkv_b
103+
v_base = qkv_base + 2 * KEY_DIM + h * V_DIM
104+
v_vals = tl.load(v_base + v_offs, mask=v_mask, other=0.0).to(tl.float32)
105+
106+
# ====== Phase 2: Compute gating and beta ======
107+
alpha_h = tl.load(Alpha_ptr + bid * stride_ab + h).to(tl.float32)
108+
neg_a_exp_h = tl.load(NegAExp_ptr + h).to(tl.float32)
109+
dt_bias_h = tl.load(DtBias_ptr + h).to(tl.float32)
110+
111+
# softplus with numerical stability
112+
sp_input = alpha_h + dt_bias_h
113+
sp = tl.where(sp_input > 20.0, sp_input, tl.log(1.0 + tl.exp(sp_input)))
114+
gate = neg_a_exp_h * sp # always negative
115+
decay = tl.exp(gate)
116+
117+
beta_raw_h = tl.load(BetaRaw_ptr + bid * stride_bb + h).to(tl.float32)
118+
beta = tl.sigmoid(beta_raw_h)
119+
120+
# ====== Phase 3: Compute K and Q L2 norms (full-vector reduction) ======
121+
# Each v_block program needs the full K-vector norms, so we compute them here.
122+
# This is redundant across v_blocks for the same (batch, head) but avoids
123+
# a separate kernel launch or shared memory coordination.
124+
q_base = qkv_base + k_head * K
125+
k_base = qkv_base + KEY_DIM + k_head * K
126+
127+
q_sq_sum = tl.zeros([], dtype=tl.float32)
128+
k_sq_sum = tl.zeros([], dtype=tl.float32)
129+
for kk in range(0, K, BLOCK_K):
130+
kk_offs = kk + tl.arange(0, BLOCK_K)
131+
kk_mask = kk_offs < K
132+
q_chunk = tl.load(q_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32)
133+
k_chunk = tl.load(k_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32)
134+
q_sq_sum += tl.sum(q_chunk * q_chunk)
135+
k_sq_sum += tl.sum(k_chunk * k_chunk)
136+
137+
q_norm = tl.maximum(tl.sqrt(q_sq_sum), L2_EPS)
138+
k_norm = tl.maximum(tl.sqrt(k_sq_sum), L2_EPS)
139+
140+
# ====== Phase 4: Recurrent state update ======
141+
s_in_base = S_in_ptr + bid * stride_s_b + h * stride_s_h
142+
s_out_base = S_out_ptr + bid * stride_s_b + h * stride_s_h
143+
144+
# --- Pass 1: Decay state, compute Sk = (decay*S)^T @ k_normalized ---
145+
sk_acc = tl.zeros([BLOCK_V], dtype=tl.float32)
146+
for kk in range(0, K, BLOCK_K):
147+
kk_offs = kk + tl.arange(0, BLOCK_K)
148+
kk_mask = kk_offs < K
149+
150+
# Load normalized k slice
151+
k_vals = (
152+
tl.load(k_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32) / k_norm
153+
)
154+
155+
# Load state tile [BLOCK_K, BLOCK_V]
156+
tile_offs = kk_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
157+
tile_mask = kk_mask[:, None] & v_mask[None, :]
158+
s_tile = tl.load(s_in_base + tile_offs, mask=tile_mask, other=0.0).to(
159+
tl.float32
160+
)
161+
162+
# Decay
163+
s_tile = s_tile * decay
164+
165+
# Sk[v] += sum_k(state[k,v] * k_normalized[k])
166+
sk_acc += tl.sum(s_tile * k_vals[:, None], axis=0)
167+
168+
# delta = beta * (v - Sk)
169+
delta_v = beta * (v_vals - sk_acc)
170+
171+
# --- Pass 2: Re-read input, decay + rank-1 update, write output state, compute output ---
172+
out_acc = tl.zeros([BLOCK_V], dtype=tl.float32)
173+
for kk in range(0, K, BLOCK_K):
174+
kk_offs = kk + tl.arange(0, BLOCK_K)
175+
kk_mask = kk_offs < K
176+
177+
# Load normalized k and q slices
178+
k_vals = (
179+
tl.load(k_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32) / k_norm
180+
)
181+
q_vals = (
182+
tl.load(q_base + kk_offs, mask=kk_mask, other=0.0).to(tl.float32)
183+
/ q_norm
184+
* SCALE
185+
)
186+
187+
# Re-read input state and decay
188+
tile_offs = kk_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
189+
tile_mask = kk_mask[:, None] & v_mask[None, :]
190+
s_tile = tl.load(s_in_base + tile_offs, mask=tile_mask, other=0.0).to(
191+
tl.float32
192+
)
193+
s_tile = s_tile * decay
194+
195+
# Rank-1 update: S += k ⊗ delta
196+
s_tile = s_tile + k_vals[:, None] * delta_v[None, :]
197+
198+
# Store updated state
199+
tl.store(
200+
s_out_base + tile_offs,
201+
s_tile.to(S_out_ptr.dtype.element_ty),
202+
mask=tile_mask,
203+
)
204+
205+
# Output: out[v] += sum_k(S_new[k,v] * q_scaled[k])
206+
out_acc += tl.sum(s_tile * q_vals[:, None], axis=0)
207+
208+
# Store output
209+
o_offs = O_ptr + bid * stride_ob + h * stride_oh + v_offs * stride_ov
210+
tl.store(o_offs, out_acc.to(O_ptr.dtype.element_ty), mask=v_mask)
211+
212+
213+
@triton_op("triton::fused_deltanet_decode", mutates_args={})
214+
def fused_deltanet_decode(
215+
qkv: torch.Tensor,
216+
alpha: torch.Tensor,
217+
beta_raw: torch.Tensor,
218+
A_log: torch.Tensor,
219+
dt_bias: torch.Tensor,
220+
state: torch.Tensor,
221+
) -> tuple[torch.Tensor, torch.Tensor]:
222+
"""
223+
Fully-fused GatedDeltaNet decode (T=1) recurrent step.
224+
225+
Fuses Q/K/V split, L2 normalization, head repeat, gating, and delta rule
226+
recurrence into a single kernel.
227+
228+
Args:
229+
qkv: [B, conv_dim] post-conv1d+silu output (Q|K|V concatenated)
230+
alpha: [B, num_v_heads] raw gating input (pre-softplus)
231+
beta_raw: [B, num_v_heads] raw write strength (pre-sigmoid)
232+
A_log: [num_v_heads] log(A) parameter (negated exp computed inside)
233+
dt_bias: [num_v_heads] gating bias parameter
234+
state: [B, num_v_heads, K, V] recurrent state (read-only, not mutated)
235+
236+
Returns:
237+
tuple of (output, new_state):
238+
output: [B, num_v_heads, V] decode output (same dtype as state)
239+
new_state: [B, num_v_heads, K, V] updated state (same dtype as state)
240+
"""
241+
B = qkv.shape[0]
242+
H, K, V_DIM = state.shape[1], state.shape[2], state.shape[3]
243+
244+
# Derive layout constants from tensor shapes
245+
# conv_dim = 2 * KEY_DIM + H * V_DIM, KEY_DIM = num_k_heads * K
246+
value_dim = H * V_DIM
247+
KEY_DIM = (qkv.shape[1] - value_dim) // 2
248+
num_k_heads = KEY_DIM // K
249+
V_PER_K = H // num_k_heads
250+
251+
output = torch.empty(B, H, V_DIM, dtype=state.dtype, device=qkv.device)
252+
253+
# Compute neg_A_exp from A_log parameter
254+
neg_A_exp = -torch.exp(A_log.float())
255+
256+
# Separate input/output state buffers for autotuning safety
257+
# (autotuner may re-run the kernel; reading from a buffer we also write
258+
# would produce wrong results on the second run)
259+
state_in = state.float().contiguous()
260+
state_out = torch.empty_like(state_in)
261+
262+
def grid(meta):
263+
return (B * H, triton.cdiv(V_DIM, meta["BLOCK_V"]))
264+
265+
wrap_triton(_fused_deltanet_decode_kernel)[grid](
266+
qkv,
267+
alpha,
268+
beta_raw,
269+
neg_A_exp,
270+
dt_bias,
271+
state_in,
272+
state_out,
273+
output,
274+
# Dimensions
275+
K=K,
276+
V_DIM=V_DIM,
277+
KEY_DIM=KEY_DIM,
278+
V_PER_K=V_PER_K,
279+
SCALE=K**-0.5,
280+
L2_EPS=1e-6,
281+
# Strides
282+
stride_qkv_b=qkv.stride(0),
283+
stride_ab=alpha.stride(0),
284+
stride_bb=beta_raw.stride(0),
285+
stride_s_b=state_in.stride(0),
286+
stride_s_h=state_in.stride(1),
287+
stride_s_k=state_in.stride(2),
288+
stride_s_v=state_in.stride(3),
289+
stride_ob=output.stride(0),
290+
stride_oh=output.stride(1),
291+
stride_ov=output.stride(2),
292+
)
293+
294+
return output, state_out.to(state.dtype)
295+
296+
297+
@fused_deltanet_decode.register_fake
298+
def _fused_deltanet_decode_fake(
299+
qkv: torch.Tensor,
300+
alpha: torch.Tensor,
301+
beta_raw: torch.Tensor,
302+
A_log: torch.Tensor,
303+
dt_bias: torch.Tensor,
304+
state: torch.Tensor,
305+
) -> tuple[torch.Tensor, torch.Tensor]:
306+
B = qkv.shape[0]
307+
H, K_DIM, V_DIM = state.shape[1], state.shape[2], state.shape[3]
308+
output = torch.empty(B, H, V_DIM, dtype=state.dtype, device=qkv.device)
309+
new_state = torch.empty(B, H, K_DIM, V_DIM, dtype=state.dtype, device=qkv.device)
310+
return output, new_state

0 commit comments

Comments
 (0)