Skip to content

Commit ce3e9ca

Browse files
committed
lint fix - 2
1 parent c90a8e8 commit ce3e9ca

File tree

13 files changed

+45953
-40
lines changed

13 files changed

+45953
-40
lines changed

backends/cuda/tests/test_chunk_gated_delta_rule.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ def _make_inputs_from_fla(
9898
q = torch.rand(B, seq_len, H, K, dtype=dtype, device=device)
9999
k = torch.rand(B, seq_len, H, K, dtype=dtype, device=device)
100100
v = torch.rand(B, seq_len, H, V, dtype=dtype, device=device)
101-
beta = torch.rand(B, seq_len, H, dtype=torch.float32, device=device).sigmoid().to(dtype)
101+
beta = (
102+
torch.rand(B, seq_len, H, dtype=torch.float32, device=device)
103+
.sigmoid()
104+
.to(dtype)
105+
)
102106
g = F.logsigmoid(torch.rand(B, seq_len, H, dtype=torch.float32, device=device))
103107
g = (g / gate_logit_normalizer).to(dtype)
104108
if mask_p > 0:
@@ -261,7 +265,9 @@ def test_recurrent_t1(self):
261265
model = ChunkGatedDeltaModel().eval()
262266
for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS:
263267
with self.subTest(desc=desc):
264-
inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0, seq_len=1)
268+
inputs = _make_inputs_from_fla(
269+
seed, norm, mask_p, nonzero_h0, seq_len=1
270+
)
265271
q, k, v, g, beta, h0 = inputs
266272

267273
with torch.no_grad():
@@ -312,12 +318,8 @@ def test_dispatch_multiple_seq_lengths(self):
312318
self.assertEqual(s_ours.shape, torch.Size([B, H, K, V]))
313319
o_diff = (o_ours.float() - o_ref.float()).abs().max().item()
314320
s_diff = (s_ours.float() - s_ref.float()).abs().max().item()
315-
self.assertLess(
316-
o_diff, 0.02, f"T={seq_len}: output diff {o_diff}"
317-
)
318-
self.assertLess(
319-
s_diff, 0.02, f"T={seq_len}: state diff {s_diff}"
320-
)
321+
self.assertLess(o_diff, 0.02, f"T={seq_len}: output diff {o_diff}")
322+
self.assertLess(s_diff, 0.02, f"T={seq_len}: state diff {s_diff}")
321323

322324
def test_export_cuda(self):
323325
with tempfile.TemporaryDirectory() as tmpdir:

backends/cuda/triton/kernels/chunk_gated_delta_rule.py

Lines changed: 110 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ def _unwrap(kernel):
6868
@triton.jit
6969
def _recurrent_gated_delta_rule_kernel(
7070
# Pointers — all inputs [B, 1, H, *] squeezed to [B, H, *]
71-
q_ptr, # [B, H, K]
72-
k_ptr, # [B, H, K]
73-
v_ptr, # [B, H, V]
74-
g_ptr, # [B, H]
75-
beta_ptr, # [B, H]
71+
q_ptr, # [B, H, K]
72+
k_ptr, # [B, H, K]
73+
v_ptr, # [B, H, V]
74+
g_ptr, # [B, H]
75+
beta_ptr, # [B, H]
7676
state_ptr, # [B, H, K, V] input state (read)
77-
o_ptr, # [B, H, V] output
77+
o_ptr, # [B, H, V] output
7878
new_state_ptr, # [B, H, K, V] output state (write)
7979
# Dims
8080
K: tl.constexpr,
@@ -137,7 +137,9 @@ def _recurrent_gated_delta_rule_kernel(
137137
o_tile = tl.sum(state_tile * q_vec[:, None], axis=0) * scale
138138

139139
# Store output tile
140-
tl.store(o_ptr + v_base + v_range, o_tile.to(o_ptr.dtype.element_ty), mask=v_mask)
140+
tl.store(
141+
o_ptr + v_base + v_range, o_tile.to(o_ptr.dtype.element_ty), mask=v_mask
142+
)
141143

142144
# Store new state tile
143145
tl.store(
@@ -212,34 +214,74 @@ def _launch_chunked(q, k, v, g, beta, initial_state, scale):
212214
# 1. chunk_local_cumsum
213215
g_cumsum = torch.empty(B, T, H, dtype=torch.float32, device=q.device)
214216
wrap_triton(_unwrap(chunk_local_cumsum_scalar_kernel))[(NT, B * H)](
215-
s=g, o=g_cumsum, scale=0, cu_seqlens=0, chunk_indices=0,
216-
T=T, B=B, H=H, BT=BT,
217-
HEAD_FIRST=False, REVERSE=False, HAS_SCALE=False, IS_VARLEN=False,
217+
s=g,
218+
o=g_cumsum,
219+
scale=0,
220+
cu_seqlens=0,
221+
chunk_indices=0,
222+
T=T,
223+
B=B,
224+
H=H,
225+
BT=BT,
226+
HEAD_FIRST=False,
227+
REVERSE=False,
228+
HAS_SCALE=False,
229+
IS_VARLEN=False,
218230
)
219231

220232
# 2. chunk_scaled_dot_kkt
221233
A = torch.empty(B, T, H, BT, device=q.device, dtype=torch.float32)
222234
wrap_triton(_unwrap(chunk_scaled_dot_kkt_fwd_kernel))[(NT, B * H)](
223-
k=k, g=g_cumsum, beta=beta, A=A,
224-
cu_seqlens=0, chunk_indices=0,
225-
T=T, H=H, K=K, BT=BT, USE_G=True, IS_VARLEN=False,
235+
k=k,
236+
g=g_cumsum,
237+
beta=beta,
238+
A=A,
239+
cu_seqlens=0,
240+
chunk_indices=0,
241+
T=T,
242+
H=H,
243+
K=K,
244+
BT=BT,
245+
USE_G=True,
246+
IS_VARLEN=False,
226247
)
227248

228249
# 3. solve_tril
229250
Ai = torch.zeros_like(A, dtype=k.dtype)
230251
wrap_triton(_unwrap(merge_16x16_to_64x64_inverse_kernel))[NT, B * H](
231-
A=A, Ai=Ai, cu_seqlens=0, chunk_indices=0,
232-
T=T, H=H, BT=BT, USE_TMA=IS_TMA_SUPPORTED, IS_VARLEN=False,
252+
A=A,
253+
Ai=Ai,
254+
cu_seqlens=0,
255+
chunk_indices=0,
256+
T=T,
257+
H=H,
258+
BT=BT,
259+
USE_TMA=IS_TMA_SUPPORTED,
260+
IS_VARLEN=False,
233261
)
234262

235263
# 4. recompute_w_u
236264
w = torch.empty_like(k)
237265
u = torch.empty_like(v)
238266
wrap_triton(_unwrap(recompute_w_u_fwd_kernel))[(NT, B * H)](
239-
k=k, v=v, beta=beta, w=w, u=u, A=Ai, g=g_cumsum,
240-
cu_seqlens=0, chunk_indices=0,
241-
T=T, H=H, K=K, V=V, BT=BT, BK=64, BV=64,
242-
USE_G=True, IS_VARLEN=False,
267+
k=k,
268+
v=v,
269+
beta=beta,
270+
w=w,
271+
u=u,
272+
A=Ai,
273+
g=g_cumsum,
274+
cu_seqlens=0,
275+
chunk_indices=0,
276+
T=T,
277+
H=H,
278+
K=K,
279+
V=V,
280+
BT=BT,
281+
BK=64,
282+
BV=64,
283+
USE_G=True,
284+
IS_VARLEN=False,
243285
)
244286

245287
# 5. chunk_gated_delta_rule_fwd_h
@@ -251,13 +293,30 @@ def grid_h(meta):
251293
return (triton.cdiv(V, meta["BV"]), B * H)
252294

253295
wrap_triton(_unwrap(chunk_gated_delta_rule_fwd_kernel_h_blockdim64))[grid_h](
254-
k=k, v=u, w=w, v_new=v_new, g=g_cumsum, gk=0,
255-
h=h, h0=initial_state, ht=final_state,
256-
cu_seqlens=0, chunk_offsets=0,
257-
T=T, H=H, K=K, V=V, BT=BT,
258-
USE_EXP2=False, TRANSPOSE_STATE=False, USE_G=True, USE_GK=False,
259-
USE_INITIAL_STATE=True, STORE_FINAL_STATE=True,
260-
SAVE_NEW_VALUE=True, IS_VARLEN=False,
296+
k=k,
297+
v=u,
298+
w=w,
299+
v_new=v_new,
300+
g=g_cumsum,
301+
gk=0,
302+
h=h,
303+
h0=initial_state,
304+
ht=final_state,
305+
cu_seqlens=0,
306+
chunk_offsets=0,
307+
T=T,
308+
H=H,
309+
K=K,
310+
V=V,
311+
BT=BT,
312+
USE_EXP2=False,
313+
TRANSPOSE_STATE=False,
314+
USE_G=True,
315+
USE_GK=False,
316+
USE_INITIAL_STATE=True,
317+
STORE_FINAL_STATE=True,
318+
SAVE_NEW_VALUE=True,
319+
IS_VARLEN=False,
261320
)
262321

263322
# 6. chunk_fwd_o
@@ -267,10 +326,25 @@ def grid_o(meta):
267326
return (triton.cdiv(V, meta["BV"]), NT, B * H)
268327

269328
wrap_triton(_unwrap(chunk_fwd_kernel_o))[grid_o](
270-
q=q, k=k, v=v_new, h=h, g=g_cumsum, g_gamma=0, o=o,
271-
cu_seqlens=0, chunk_indices=0, scale=scale,
272-
T=T, H=H, K=K, V=V, BT=BT,
273-
TRANSPOSE_STATE=False, USE_G=True, USE_G_GAMMA=False, IS_VARLEN=False,
329+
q=q,
330+
k=k,
331+
v=v_new,
332+
h=h,
333+
g=g_cumsum,
334+
g_gamma=0,
335+
o=o,
336+
cu_seqlens=0,
337+
chunk_indices=0,
338+
scale=scale,
339+
T=T,
340+
H=H,
341+
K=K,
342+
V=V,
343+
BT=BT,
344+
TRANSPOSE_STATE=False,
345+
USE_G=True,
346+
USE_G_GAMMA=False,
347+
IS_VARLEN=False,
274348
)
275349

276350
return o, final_state
@@ -299,8 +373,12 @@ def _validate_inputs(q, k, v, g, beta, initial_state):
299373
if not (q.dtype == k.dtype == v.dtype):
300374
raise ValueError("q, k, v must have the same dtype")
301375
if not (
302-
q.device == k.device == v.device
303-
== g.device == beta.device == initial_state.device
376+
q.device
377+
== k.device
378+
== v.device
379+
== g.device
380+
== beta.device
381+
== initial_state.device
304382
):
305383
raise ValueError("All tensors must be on the same device")
306384
if K > 256:
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Benchmark recurrent vs chunked FLA in full model decode with torch.compile.
2+
3+
Usage:
4+
# Recurrent (current code):
5+
python bench_fla.py --prequantized ~/models/Qwen3.5-35B-A3B-HQQ-INT4-local --mode recurrent
6+
# Chunked (original FLA triton kernels):
7+
python bench_fla.py --prequantized ~/models/Qwen3.5-35B-A3B-HQQ-INT4-local --mode chunked
8+
"""
9+
import argparse
10+
import time
11+
import torch
12+
13+
14+
def patch_chunked():
    """Monkey-patch GatedDeltaNet.forward to use the chunked FLA triton kernels.

    Must be called BEFORE the model is constructed so every GatedDeltaNet
    instance picks up the replacement forward. The patch is module-global and
    is not undone (the original forward is not kept — this is a benchmark
    script, restoring it is never needed).
    """
    import executorch.examples.models.qwen3_5_moe.model as mod

    def chunked_forward(self, x, input_pos):
        """GatedDeltaNet.forward using chunked FLA triton kernels.

        Args:
            x: input hidden states; assumed (B, T, hidden) — unpacked below.
            input_pos: absolute token positions; position 0 resets the caches.

        Returns:
            Output projection of the gated-delta attention, shape (B, T, -1).
        """
        import torch.nn.functional as F

        B, T, _ = x.size()

        # Zero the persistent conv/recurrent caches when a new sequence starts
        # (input_pos[0] == 0); otherwise keep them (multiplier stays 1.0).
        reset = (input_pos[0] == 0).to(self.conv_state.dtype)
        keep = 1.0 - reset
        self.conv_state[:B].mul_(keep)
        self.recurrent_state[:B].mul_(keep)

        # Single fused in-projection, then split into qkv / gate z / beta / a.
        proj = self.in_proj(x)
        cd = self.conv_dim
        vd = self.value_dim
        nh = self.num_v_heads
        mixed_qkv = proj[..., :cd]
        z = proj[..., cd : cd + vd].reshape(B, T, self.num_v_heads, self.head_v_dim)
        b = proj[..., cd + vd : cd + vd + nh]
        a = proj[..., cd + vd + nh :]

        # Causal depthwise conv over [cached state | new tokens]; the cache is
        # updated in-place with the last conv_kernel_size columns.
        qkv_t = mixed_qkv.transpose(1, 2)
        conv_input = torch.cat([self.conv_state[:B], qkv_t], dim=-1)
        with torch.no_grad():
            self.conv_state[:B].copy_(conv_input[:, :, -self.conv_kernel_size :])
        w = self.conv1d.weight.squeeze(1).float()
        T_conv = conv_input.shape[-1] - self.conv_kernel_size + 1
        # Manual float32 sliding-window accumulation instead of F.conv1d —
        # presumably for export/compile friendliness; verify against the
        # recurrent implementation in the model module.
        acc = torch.zeros(
            B, conv_input.shape[1], T_conv,
            dtype=torch.float32, device=conv_input.device,
        )
        for k in range(self.conv_kernel_size):
            acc = acc + conv_input[:, :, k : k + T_conv].float() * w[:, k : k + 1]
        qkv_conv = F.silu(acc[:, :, -T:]).to(conv_input.dtype).transpose(1, 2)

        # Split the convolved stream back into per-head q, k, v.
        kd = self.key_dim
        q = qkv_conv[..., :kd].reshape(B, T, self.num_k_heads, self.head_k_dim)
        k = qkv_conv[..., kd : 2 * kd].reshape(B, T, self.num_k_heads, self.head_k_dim)
        v = qkv_conv[..., 2 * kd :].reshape(B, T, self.num_v_heads, self.head_v_dim)

        # L2-normalize q/k along the head dimension.
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        # Expand k/q heads to match the number of value heads (GQA-style).
        if self.head_repeat > 1:
            q = q.repeat_interleave(self.head_repeat, dim=2)
            k = k.repeat_interleave(self.head_repeat, dim=2)

        # Gating terms: beta in (0, 1); g = -exp(A_log) * softplus(a + dt_bias).
        beta = b.sigmoid()
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

        # Use chunked FLA triton kernels
        output, state = torch.ops.triton.chunk_gated_delta_rule(
            q, k, v, g, beta, self.recurrent_state[:B]
        )
        with torch.no_grad():
            self.recurrent_state[:B].copy_(state)

        # Gated RMS norm over flattened heads, then out-projection.
        output = output.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        output = self.norm(output, z)
        output = output.reshape(B, T, -1)

        return self.out_proj(output)

    mod.GatedDeltaNet.forward = chunked_forward
    print("Patched: using chunked FLA triton kernels")
85+
86+
87+
def main():
    """Load the (pre-quantized) model and time single-token decode steps.

    Prints tokens/sec and ms/step for the selected FLA mode so the two
    implementations (recurrent vs chunked) can be compared head-to-head.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--prequantized", required=True)
    ap.add_argument("--mode", choices=["recurrent", "chunked"], required=True)
    ap.add_argument("--steps", type=int, default=50)
    ap.add_argument("--warmup", type=int, default=30)
    ap.add_argument("--no-compile", action="store_true")
    args = ap.parse_args()

    # Patch BEFORE any model import if chunked
    if args.mode == "chunked":
        patch_chunked()

    import executorch.backends.cuda.triton.kernels  # register triton ops
    from executorch.examples.models.qwen3_5_moe.export import load_prequantized_model
    from executorch.examples.models.qwen3_5_moe.inference import _move_to_cuda

    print("Loading model...")
    model, config = load_prequantized_model(args.prequantized, max_seq_len=4096)
    _move_to_cuda(model, config)
    model.eval()

    if not args.no_compile:
        print("Compiling with torch.compile...")
        model = torch.compile(model, mode="default")

    def run_step(position):
        # One decode step: a single dummy token at the given absolute position.
        token = torch.tensor([[1]], dtype=torch.long, device="cuda")
        pos = torch.tensor([position], dtype=torch.long, device="cuda")
        model(token, pos)

    # Warm up (fills caches, triggers compilation) before timing anything.
    print(f"Warming up ({args.warmup} steps)...")
    with torch.no_grad():
        for step in range(args.warmup):
            run_step(step)
    torch.cuda.synchronize()

    # Timed region: positions continue where the warmup left off.
    print(f"Benchmarking ({args.steps} decode steps)...")
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for step in range(args.steps):
            run_step(args.warmup + step)
    torch.cuda.synchronize()
    wall = time.perf_counter() - start

    tok_s = args.steps / wall
    ms_per_step = wall / args.steps * 1000
    print(f"\nResult [{args.mode}]: {tok_s:.1f} tok/s ({ms_per_step:.2f} ms/step, {args.steps} steps)")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)