Skip to content

Commit 6968475

Browse files
Gasoonjiagasoonjia
andauthored
Add fused GatedDeltaNet decode Triton kernel (#18865)
Fuse Q/K/V split, L2 normalization, head repeat, gating computation, and delta-rule recurrent state update into a single Triton kernel for decode (T=1). Replaces ~6 small AOTI-generated kernels with one, reducing GatedDeltaNet kernel time by ~62%. Config | performance (compare with last optimization) -- | -- p=128 d=128 | 156.7 (+6.8) p=128 d=512 | 160.8 (+7.1) p=256 d=128 | 156.1 (+6.5) p=256 d=512 | 160.8 (+7.3) p=512 d=128 | 156.0 (+6.9) p=512 d=512 | 160.9 (+8.0) p=1024 d=128 | 156.3 (+7.7) p=1024 d=512 | 160.0 (+6.6) p=2048 d=128 | 154.8 (+6.4) p=2048 d=512 | 160.6 (+7.5) Average | 158.3 (+7.1) --------- Co-authored-by: gasoonjia <gasoonjia@fb.com>
1 parent 1014985 commit 6968475

5 files changed

Lines changed: 753 additions & 46 deletions

File tree

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Correctness test: fully-fused Triton GatedDeltaNet decode kernel vs PyTorch reference.
9+
10+
Verifies that torch.ops.triton.fused_deltanet_decode produces the same output
11+
and state as the original GatedDeltaNet T=1 recurrence with manual Q/K/V split,
12+
L2 norm, head repeat, and gating.
13+
"""
14+
15+
import unittest
16+
17+
import torch
18+
import torch.nn.functional as F
19+
20+
21+
def _skip_if_no_cuda():
22+
if not torch.cuda.is_available():
23+
raise unittest.SkipTest("CUDA not available")
24+
if not torch.cuda.is_bf16_supported():
25+
raise unittest.SkipTest("BF16 not supported on this GPU")
26+
27+
28+
def _import_fused_deltanet_decode():
29+
from executorch.backends.cuda.triton.kernels.fused_deltanet_decode import (
30+
fused_deltanet_decode, # noqa: F401 — registers torch.ops.triton.*
31+
)
32+
33+
return fused_deltanet_decode
34+
35+
36+
def _max_abs_error(a, b):
37+
return (a.float() - b.float()).abs().max().item()
38+
39+
40+
# bf16 kernel vs fp32 reference tolerance.
41+
MAX_ABS_TOL = 0.05
42+
MULTISTEP_TOL = 0.1
43+
44+
45+
def _reference_deltanet_decode(
46+
qkv_conv,
47+
alpha,
48+
beta_raw,
49+
neg_A_exp,
50+
dt_bias,
51+
state,
52+
num_k_heads,
53+
num_v_heads,
54+
head_k_dim,
55+
head_v_dim,
56+
):
57+
"""Reference PyTorch implementation matching model.py's original T=1 path.
58+
59+
Does Q/K/V split, L2 norm, head repeat, gating, then recurrent update.
60+
"""
61+
B = qkv_conv.shape[0]
62+
key_dim = num_k_heads * head_k_dim
63+
64+
q = qkv_conv[:, :key_dim].reshape(B, num_k_heads, head_k_dim)
65+
k = qkv_conv[:, key_dim : 2 * key_dim].reshape(B, num_k_heads, head_k_dim)
66+
v = qkv_conv[:, 2 * key_dim :].reshape(B, num_v_heads, head_v_dim)
67+
68+
q = F.normalize(q.float(), p=2, dim=-1)
69+
k = F.normalize(k.float(), p=2, dim=-1)
70+
v = v.float()
71+
72+
head_repeat = num_v_heads // num_k_heads
73+
if head_repeat > 1:
74+
q = q.repeat_interleave(head_repeat, dim=1)
75+
k = k.repeat_interleave(head_repeat, dim=1)
76+
77+
beta = torch.sigmoid(beta_raw.float())
78+
g = neg_A_exp.float() * F.softplus(alpha.float() + dt_bias.float())
79+
80+
scale = head_k_dim**-0.5
81+
state_f32 = state.float()
82+
83+
decay = torch.exp(g).unsqueeze(-1).unsqueeze(-1)
84+
state_f32 = state_f32 * decay
85+
86+
Sk = torch.einsum("bhkv,bhk->bhv", state_f32, k)
87+
delta = beta.unsqueeze(-1) * (v - Sk)
88+
state_f32 = state_f32 + torch.einsum("bhk,bhv->bhkv", k, delta)
89+
90+
output = torch.einsum("bhkv,bhk->bhv", state_f32, q) * scale
91+
92+
new_state = state_f32.to(state.dtype)
93+
return output, new_state
94+
95+
96+
# Qwen3.5 MoE dimensions (used across tests)
97+
NUM_K_HEADS = 16
98+
NUM_V_HEADS = 32
99+
HEAD_K_DIM = 128
100+
HEAD_V_DIM = 128
101+
KEY_DIM = NUM_K_HEADS * HEAD_K_DIM # 2048
102+
VALUE_DIM = NUM_V_HEADS * HEAD_V_DIM # 4096
103+
CONV_DIM = 2 * KEY_DIM + VALUE_DIM # 8192
104+
105+
106+
class TestFusedDeltanetDecode(unittest.TestCase):
107+
"""Test fused GatedDeltaNet decode kernel correctness against PyTorch reference."""
108+
109+
@classmethod
110+
def setUpClass(cls):
111+
_skip_if_no_cuda()
112+
cls.fused_fn = _import_fused_deltanet_decode()
113+
torch.manual_seed(42)
114+
115+
cls.A_log = torch.log(torch.empty(NUM_V_HEADS, device="cuda").uniform_(0.5, 8))
116+
cls.neg_A_exp = -torch.exp(cls.A_log).float()
117+
cls.dt_bias = torch.ones(NUM_V_HEADS, device="cuda", dtype=torch.float32)
118+
119+
def _run_fused(self, qkv, alpha, beta_raw, state):
120+
"""Run fused kernel and return (output, new_state)."""
121+
output, new_state = torch.ops.triton.fused_deltanet_decode(
122+
qkv,
123+
alpha,
124+
beta_raw,
125+
self.A_log,
126+
self.dt_bias,
127+
state,
128+
)
129+
return output, new_state
130+
131+
def _run_reference(self, qkv, alpha, beta_raw, state):
132+
"""Run reference and return (output, new_state)."""
133+
return _reference_deltanet_decode(
134+
qkv,
135+
alpha,
136+
beta_raw,
137+
self.neg_A_exp,
138+
self.dt_bias,
139+
state,
140+
NUM_K_HEADS,
141+
NUM_V_HEADS,
142+
HEAD_K_DIM,
143+
HEAD_V_DIM,
144+
)
145+
146+
# ------------------------------------------------------------------
147+
# Correctness
148+
# ------------------------------------------------------------------
149+
150+
def test_basic(self):
151+
"""Single batch, Qwen3.5 MoE dimensions."""
152+
B = 1
153+
torch.manual_seed(42)
154+
qkv = torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1
155+
alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32)
156+
beta_raw = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32)
157+
state = (
158+
torch.randn(
159+
B,
160+
NUM_V_HEADS,
161+
HEAD_K_DIM,
162+
HEAD_V_DIM,
163+
device="cuda",
164+
dtype=torch.bfloat16,
165+
)
166+
* 0.1
167+
)
168+
169+
ref_out, ref_state = self._run_reference(
170+
qkv.clone(),
171+
alpha.clone(),
172+
beta_raw.clone(),
173+
state.clone(),
174+
)
175+
fused_out, fused_state = self._run_fused(
176+
qkv.clone(),
177+
alpha.clone(),
178+
beta_raw.clone(),
179+
state.clone(),
180+
)
181+
182+
self.assertLess(
183+
_max_abs_error(fused_out, ref_out), MAX_ABS_TOL, "output mismatch"
184+
)
185+
self.assertLess(
186+
_max_abs_error(fused_state, ref_state), MAX_ABS_TOL, "state mismatch"
187+
)
188+
189+
def test_batch(self):
190+
"""Batch size > 1."""
191+
for B in [2, 4]:
192+
with self.subTest(B=B):
193+
torch.manual_seed(42)
194+
qkv = (
195+
torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1
196+
)
197+
alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32)
198+
beta_raw = torch.randn(
199+
B, NUM_V_HEADS, device="cuda", dtype=torch.float32
200+
)
201+
state = (
202+
torch.randn(
203+
B,
204+
NUM_V_HEADS,
205+
HEAD_K_DIM,
206+
HEAD_V_DIM,
207+
device="cuda",
208+
dtype=torch.bfloat16,
209+
)
210+
* 0.1
211+
)
212+
213+
ref_out, ref_state = self._run_reference(
214+
qkv.clone(),
215+
alpha.clone(),
216+
beta_raw.clone(),
217+
state.clone(),
218+
)
219+
fused_out, fused_state = self._run_fused(
220+
qkv.clone(),
221+
alpha.clone(),
222+
beta_raw.clone(),
223+
state.clone(),
224+
)
225+
226+
self.assertLess(
227+
_max_abs_error(fused_out, ref_out),
228+
MAX_ABS_TOL,
229+
f"B={B} output mismatch",
230+
)
231+
self.assertLess(
232+
_max_abs_error(fused_state, ref_state),
233+
MAX_ABS_TOL,
234+
f"B={B} state mismatch",
235+
)
236+
237+
def test_multistep(self):
238+
"""10-step sequential decode checks accumulation drift."""
239+
torch.manual_seed(42)
240+
state_ref = (
241+
torch.randn(
242+
1,
243+
NUM_V_HEADS,
244+
HEAD_K_DIM,
245+
HEAD_V_DIM,
246+
device="cuda",
247+
dtype=torch.bfloat16,
248+
)
249+
* 0.01
250+
)
251+
state_fused = state_ref.clone()
252+
253+
for _ in range(10):
254+
qkv = torch.randn(1, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1
255+
alpha = torch.randn(1, NUM_V_HEADS, device="cuda", dtype=torch.float32)
256+
beta_raw = torch.randn(1, NUM_V_HEADS, device="cuda", dtype=torch.float32)
257+
258+
ref_out, state_ref = self._run_reference(
259+
qkv.clone(),
260+
alpha.clone(),
261+
beta_raw.clone(),
262+
state_ref,
263+
)
264+
fused_out, state_fused = self._run_fused(
265+
qkv.clone(),
266+
alpha.clone(),
267+
beta_raw.clone(),
268+
state_fused,
269+
)
270+
271+
self.assertLess(
272+
_max_abs_error(fused_out, ref_out),
273+
MULTISTEP_TOL,
274+
"multi-step output drift",
275+
)
276+
self.assertLess(
277+
_max_abs_error(state_fused, state_ref),
278+
MULTISTEP_TOL,
279+
"multi-step state drift",
280+
)
281+
282+
def test_state_not_mutated(self):
283+
"""Kernel must not mutate the input state tensor."""
284+
B = 1
285+
torch.manual_seed(42)
286+
qkv = torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1
287+
alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32)
288+
beta_raw = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32)
289+
state = (
290+
torch.randn(
291+
B,
292+
NUM_V_HEADS,
293+
HEAD_K_DIM,
294+
HEAD_V_DIM,
295+
device="cuda",
296+
dtype=torch.bfloat16,
297+
)
298+
* 0.1
299+
)
300+
state_copy = state.clone()
301+
302+
_, _ = self._run_fused(qkv, alpha, beta_raw, state)
303+
304+
self.assertTrue(torch.equal(state, state_copy), "input state was mutated")
305+
306+
# ------------------------------------------------------------------
307+
# CUDA Graph compatibility
308+
# ------------------------------------------------------------------
309+
310+
def test_cuda_graph(self):
311+
"""Kernel must be capturable in a CUDA graph."""
312+
B = 1
313+
torch.manual_seed(42)
314+
qkv = torch.randn(B, CONV_DIM, device="cuda", dtype=torch.bfloat16) * 0.1
315+
alpha = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32)
316+
beta_raw = torch.randn(B, NUM_V_HEADS, device="cuda", dtype=torch.float32)
317+
state = (
318+
torch.randn(
319+
B,
320+
NUM_V_HEADS,
321+
HEAD_K_DIM,
322+
HEAD_V_DIM,
323+
device="cuda",
324+
dtype=torch.bfloat16,
325+
)
326+
* 0.1
327+
)
328+
329+
# Warmup
330+
for _ in range(3):
331+
_ = self._run_fused(qkv, alpha, beta_raw, state)
332+
333+
# Capture
334+
graph = torch.cuda.CUDAGraph()
335+
with torch.cuda.graph(graph):
336+
out_cg, state_cg = self._run_fused(qkv, alpha, beta_raw, state)
337+
338+
# Replay
339+
graph.replay()
340+
341+
# Compare with reference
342+
ref_out, _ = self._run_reference(
343+
qkv.clone(),
344+
alpha.clone(),
345+
beta_raw.clone(),
346+
state.clone(),
347+
)
348+
self.assertFalse(torch.isnan(out_cg).any(), "NaN in CUDA graph output")
349+
self.assertLess(
350+
_max_abs_error(out_cg, ref_out),
351+
MAX_ABS_TOL,
352+
"CUDA graph output mismatch",
353+
)
354+
355+
356+
if __name__ == "__main__":
357+
unittest.main()

backends/cuda/triton/kernels/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,12 @@
3838
__all__.append("tq4_sdpa")
3939
except ImportError:
4040
pass
41+
42+
try:
43+
from executorch.backends.cuda.triton.kernels.fused_deltanet_decode import ( # noqa: F401
44+
fused_deltanet_decode,
45+
)
46+
47+
__all__.append("fused_deltanet_decode")
48+
except ImportError:
49+
pass

0 commit comments

Comments
 (0)