Skip to content

Commit 8f928b7

Browse files
authored
Add chunk_gated_delta_rule triton kernel for CUDA backend (#18138)
Registers FLA's chunk_gated_delta_rule as a @triton_op, following the same pattern as the existing SDPA triton kernel. Six FLA triton kernels are launched via wrap_triton() so AOTInductor compiles them directly into the generated .so — no C++ shim needed.

Key trick: FLA kernels use @triton.heuristics, which wrap_triton doesn't support. We unwrap via kernel.fn to get the inner @triton.autotune kernel and pass heuristic values (USE_G, IS_VARLEN, etc.) explicitly.

Requires: pip install flash-linear-attention

Performance numbers: https://gist.github.com/mergennachin/36bca139a04257b881e80bde6798da24

```
C++ executor_runner performance (A100 80GB, 100 executions each, bf16)
┌─────────────────────────────┬───────────────┐
│ Config                      │ Per-exec (ms) │
├─────────────────────────────┼───────────────┤
│ B=1 T=128 H=4 K=64 V=64     │ 0.458         │
├─────────────────────────────┼───────────────┤
│ B=1 T=256 H=4 K=64 V=64     │ 0.501         │
├─────────────────────────────┼───────────────┤
│ B=1 T=512 H=4 K=64 V=64     │ 0.686         │
├─────────────────────────────┼───────────────┤
│ B=1 T=1024 H=4 K=64 V=64    │ 0.884         │
├─────────────────────────────┼───────────────┤
│ B=1 T=2048 H=4 K=64 V=64    │ 1.409         │
├─────────────────────────────┼───────────────┤
│ B=1 T=128 H=8 K=64 V=64     │ 0.535         │
├─────────────────────────────┼───────────────┤
│ B=1 T=128 H=16 K=64 V=64    │ 0.705         │
├─────────────────────────────┼───────────────┤
│ B=1 T=128 H=32 K=64 V=64    │ 0.987         │
├─────────────────────────────┼───────────────┤
│ B=1 T=128 H=4 K=128 V=128   │ 0.587         │
├─────────────────────────────┼───────────────┤
│ B=1 T=128 H=4 K=256 V=256   │ 0.943         │
├─────────────────────────────┼───────────────┤
│ B=2 T=128 H=4 K=64 V=64     │ 0.534         │
├─────────────────────────────┼───────────────┤
│ B=4 T=128 H=4 K=64 V=64     │ 0.861         │
├─────────────────────────────┼───────────────┤
│ B=8 T=128 H=4 K=64 V=64     │ 1.127         │
├─────────────────────────────┼───────────────┤
│ B=1 T=1024 H=16 K=128 V=128 │ 4.727         │
└─────────────────────────────┴───────────────┘
```
1 parent 3906b58 commit 8f928b7

5 files changed

Lines changed: 616 additions & 1 deletion

File tree

.github/workflows/cuda.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,13 @@ jobs:
126126
cmake --workflow --preset default
127127
popd
128128
129-
# Run CUDA backend Python tests, overrides addopts so that we don't run all tests in pytest.ini
129+
# Install flash-linear-attention for chunk_gated_delta_rule triton kernel tests
130+
pip install "flash-linear-attention==0.4.2"
131+
132+
# Build executor_runner (needed by CUDA backend e2e tests)
133+
cmake --build cmake-out --target executor_runner
134+
135+
# Run all CUDA backend Python tests (including chunk_gated_delta e2e)
130136
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
131137
132138
export-model-cuda-artifact:
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Export and validate chunk_gated_delta_rule triton kernel on CUDA backend.
9+
10+
Requires: pip install flash-linear-attention
11+
12+
Usage:
13+
python -m pytest backends/cuda/tests/test_chunk_gated_delta_rule.py -v
14+
15+
# Standalone export (produces .pte + .ptd):
16+
python backends/cuda/tests/test_chunk_gated_delta_rule.py --output-dir /tmp/exports
17+
"""
18+
19+
import argparse
20+
import os
21+
import subprocess
22+
import sys
23+
import tempfile
24+
import unittest
25+
26+
import executorch.backends.cuda.triton.kernels.chunk_gated_delta_rule # noqa: F401
27+
28+
import fla # noqa: F401
29+
30+
import numpy as np
31+
import torch
32+
import torch.nn.functional as F
33+
34+
from executorch.backends.cuda.cuda_backend import CudaBackend
35+
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
36+
from executorch.exir import (
37+
EdgeCompileConfig,
38+
ExecutorchBackendConfig,
39+
to_edge_transform_and_lower,
40+
)
41+
from executorch.exir.passes import MemoryPlanningPass
42+
from torch.export import export
43+
44+
45+
# Dimensions shared by every test in this file. The input builders below use
# them as q/k: (B, T, H, K), v: (B, T, H, V), g/beta: (B, T, H) and
# initial state: (B, H, K, V).
B, T, H, K, V = 1, 128, 4, 64, 64

# Three directories above this file; the C++ runner is looked up in the
# "cmake-out" directory beneath it.
EXECUTORCH_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
RUNNER_PATH = os.path.join(EXECUTORCH_ROOT, "cmake-out", "executor_runner")

# Test configurations adapted from FLA's test_gated_delta.py test_chunk().
# Each entry is (seed, gate_logit_normalizer, mask_p, nonzero_h0, description);
# the description doubles as the subTest label and the per-run directory name.
FLA_TEST_CONFIGS = [
    # --- dense gates, zero initial state, varying gate normalizer ---
    (42, 1.0, 0.0, False, "basic_norm1"),
    (123, 0.1, 0.0, False, "strong_gate"),
    (7, 10.0, 0.0, False, "weak_gate"),
    # --- same normalizers, but with a random (non-zero) initial state ---
    (42, 1.0, 0.0, True, "nonzero_h0_norm1"),
    (99, 0.1, 0.0, True, "nonzero_h0_strong"),
    (55, 10.0, 0.0, True, "nonzero_h0_weak"),
    # --- sparse gating: mask_p=0.5 zeroes roughly half of the gate values ---
    (42, 1.0, 0.5, False, "sparse_gate_50pct"),
    (77, 0.1, 0.5, True, "sparse_strong_h0"),
    # --- additional seeds / normalizers for coverage ---
    (0, 1.0, 0.0, False, "seed0"),
    (100, 1.0, 0.0, True, "seed100_h0"),
    (2024, 0.5, 0.0, False, "norm0.5"),
    (999, 5.0, 0.3, True, "norm5_sparse30_h0"),
    # --- values near the extremes of the parameter ranges ---
    (13, 0.01, 0.0, False, "very_strong_gate"),
    (31, 100.0, 0.0, False, "very_weak_gate"),
    (64, 1.0, 0.9, True, "sparse_90pct_h0"),
]
74+
75+
76+
class ChunkGatedDeltaModel(torch.nn.Module):
    """Wrapper module: L2-normalize q/k, then call the registered triton op."""

    def forward(self, q, k, v, g, beta, initial_state):
        """Return ``(o, final_state)`` from ``torch.ops.triton.chunk_gated_delta_rule``.

        ``q`` and ``k`` are L2-normalized along their last dimension before the
        call; ``v``, ``g``, ``beta`` and ``initial_state`` are passed through
        unchanged.
        """
        q_unit = F.normalize(q, p=2, dim=-1)
        k_unit = F.normalize(k, p=2, dim=-1)
        o, final_state = torch.ops.triton.chunk_gated_delta_rule(
            q_unit, k_unit, v, g, beta, initial_state
        )
        return o, final_state
84+
85+
86+
def _make_inputs_from_fla(
    seed,
    gate_logit_normalizer,
    mask_p=0.0,
    nonzero_h0=False,
    dtype=torch.bfloat16,
    device="cuda",
):
    """Build seeded ``(q, k, v, g, beta, h0)`` inputs in the style of FLA's test_chunk().

    Args:
        seed: value passed to ``torch.manual_seed`` before any tensor is drawn.
        gate_logit_normalizer: divisor applied to the log-sigmoid gate values.
        mask_p: when > 0, gate entries whose uniform draw is <= ``mask_p`` are
            multiplied by zero.
        nonzero_h0: draw the initial state from ``randn`` instead of zeros.
        dtype: dtype of the returned tensors.
        device: device on which all tensors are created.

    Returns:
        Tuple ``(q, k, v, g, beta, h0)``.
    """
    torch.manual_seed(seed)

    qk_shape = (B, T, H, K)
    value_shape = (B, T, H, V)
    gate_shape = (B, T, H)
    state_shape = (B, H, K, V)

    # The order of the random draws below determines the tensors produced for
    # a given seed, so it must not be rearranged.
    q = torch.rand(qk_shape, dtype=dtype, device=device)
    k = torch.rand(qk_shape, dtype=dtype, device=device)
    v = torch.rand(value_shape, dtype=dtype, device=device)

    beta_fp32 = torch.rand(gate_shape, dtype=torch.float32, device=device)
    beta = beta_fp32.sigmoid().to(dtype)

    gate_logits = torch.rand(gate_shape, dtype=torch.float32, device=device)
    g = F.logsigmoid(gate_logits)
    g = (g / gate_logit_normalizer).to(dtype)

    if mask_p > 0:
        keep_mask = torch.rand(gate_shape, dtype=dtype, device=device) > mask_p
        g = g * keep_mask

    h0 = (
        torch.randn(state_shape, dtype=dtype, device=device)
        if nonzero_h0
        else torch.zeros(state_shape, dtype=dtype, device=device)
    )
    return q, k, v, g, beta, h0
109+
110+
111+
def _make_inputs(dtype=torch.bfloat16, device="cuda"):
    """Build unseeded random ``(q, k, v, g, beta, initial_state)`` inputs.

    All tensors are created directly in ``dtype`` on ``device``. Callers that
    need reproducibility must seed the RNG themselves.
    """
    common = {"dtype": dtype, "device": device}

    # Draw order is kept fixed so that a caller-side seed yields the same tensors.
    q = torch.randn(B, T, H, K, **common)
    k = torch.randn(B, T, H, K, **common)
    v = torch.randn(B, T, H, V, **common)
    gate_logits = torch.randn(B, T, H, **common)
    g = F.logsigmoid(gate_logits)
    beta = torch.rand(B, T, H, **common).sigmoid()
    initial_state = torch.randn(B, H, K, V, **common)
    return q, k, v, g, beta, initial_state
119+
120+
121+
def _save_tensor(t, path):
122+
t_cpu = t.cpu().contiguous()
123+
with open(path, "wb") as f:
124+
f.write(bytes(t_cpu.untyped_storage()))
125+
126+
127+
def _load_output(path, shape, dtype):
128+
data = np.fromfile(path, dtype=np.uint8)
129+
return torch.frombuffer(bytearray(data), dtype=dtype).reshape(shape)
130+
131+
132+
def export_chunk_gated_delta(output_dir):
    """Export ``ChunkGatedDeltaModel`` for the CUDA backend into ``output_dir``.

    Runs the model once in eager mode (printing the output shapes), exports it
    with ``torch.export``, lowers it through ``CudaPartitioner`` and writes
    ``chunk_gated_delta.pte`` to ``output_dir``. If the resulting program
    carries tensor data, that is written to ``output_dir`` as well.

    Args:
        output_dir: directory to write into; created if it does not exist.

    Returns:
        Path of the written ``.pte`` file.
    """
    module = ChunkGatedDeltaModel().eval()
    example_inputs = _make_inputs()

    # Eager sanity run before export.
    with torch.no_grad():
        ref_o, ref_s = module(*example_inputs)
    print(f"Eager output shape: {ref_o.shape}, final_state shape: {ref_s.shape}")

    with torch.no_grad():
        exported = export(module, example_inputs, strict=True)
    print("Export OK")

    os.makedirs(output_dir, exist_ok=True)

    compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
    partitioners = [CudaPartitioner(compile_specs)]
    edge_config = EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True)
    edge_program = to_edge_transform_and_lower(
        exported,
        partitioner=partitioners,
        compile_config=edge_config,
    )

    backend_config = ExecutorchBackendConfig(
        extract_delegate_segments=True,
        do_quant_fusion_and_const_prop=True,
        memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
    )
    et_program = edge_program.to_executorch(config=backend_config)

    pte_path = os.path.join(output_dir, "chunk_gated_delta.pte")
    with open(pte_path, "wb") as pte_file:
        et_program.write_to_file(pte_file)

    # Only emit the external tensor data when the program actually has some.
    if getattr(et_program, "_tensor_data", None):
        et_program.write_tensor_data_to_file(output_dir)

    size_kb = os.path.getsize(pte_path) / 1024
    print(f"Saved to {pte_path} ({size_kb:.0f} KB)")
    return pte_path
171+
172+
173+
def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base):
174+
"""Run executor_runner and return subprocess result."""
175+
cmd = [
176+
runner_path,
177+
f"--model_path={pte_path}",
178+
f"--data_path={ptd_path}",
179+
f"--inputs={','.join(input_files)}",
180+
f"--output_file={output_base}",
181+
]
182+
result = subprocess.run(cmd, capture_output=True, text=True)
183+
return result
184+
185+
186+
class TestChunkGatedDeltaRule(unittest.TestCase):
    """Eager, export and end-to-end tests for the chunk_gated_delta_rule triton op."""

    def setUp(self):
        """Skip every test in this class when no CUDA device is available."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def test_eager(self):
        """Eager run produces outputs with the expected shapes and dtypes."""
        model = ChunkGatedDeltaModel().eval()
        inputs = _make_inputs()
        with torch.no_grad():
            o, s = model(*inputs)
        self.assertEqual(o.shape, torch.Size([B, T, H, V]))
        self.assertEqual(s.shape, torch.Size([B, H, K, V]))
        self.assertEqual(o.dtype, torch.bfloat16)
        self.assertEqual(s.dtype, torch.float32)

    def test_eager_fla_configs(self):
        """Run FLA-style test configurations and verify against naive reference."""
        from fla.ops.gated_delta_rule.naive import naive_recurrent_gated_delta_rule

        model = ChunkGatedDeltaModel().eval()
        for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS:
            with self.subTest(desc=desc):
                inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0)
                q, k, v, g, beta, h0 = inputs

                with torch.no_grad():
                    o_ours, s_ours = model(q, k, v, g, beta, h0)

                # The model normalizes q/k internally, so the reference gets
                # the same normalization applied explicitly.
                o_ref, s_ref = naive_recurrent_gated_delta_rule(
                    q=F.normalize(q, p=2, dim=-1),
                    k=F.normalize(k, p=2, dim=-1),
                    v=v,
                    beta=beta,
                    g=g,
                    initial_state=h0,
                    output_final_state=True,
                )

                o_diff = (o_ours.float() - o_ref.float()).abs().max().item()
                s_diff = (s_ours.float() - s_ref.float()).abs().max().item()
                self.assertLess(o_diff, 0.01, f"{desc}: output diff {o_diff}")
                self.assertLess(s_diff, 0.01, f"{desc}: state diff {s_diff}")

    def test_eager_matches_fla(self):
        """Registered op output matches FLA's own chunk_gated_delta_rule."""
        from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_impl

        torch.manual_seed(42)
        inputs = _make_inputs()
        q, k, v, g, beta, h0 = inputs

        # Our op receives pre-normalized q/k; the FLA call below gets the raw
        # q/k together with use_qk_l2norm_in_kernel=True.
        q_norm = F.normalize(q, p=2, dim=-1)
        k_norm = F.normalize(k, p=2, dim=-1)
        with torch.no_grad():
            o_ours, _ = torch.ops.triton.chunk_gated_delta_rule(
                q_norm, k_norm, v, g, beta, h0
            )
            o_ref, _ = fla_impl(
                q,
                k,
                v,
                g,
                beta,
                initial_state=h0,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
            )

        self.assertLess((o_ours.float() - o_ref.float()).abs().max().item(), 0.01)

    def test_export_cuda(self):
        """Export succeeds and writes a non-empty .pte file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            pte_path = export_chunk_gated_delta(tmpdir)
            self.assertTrue(os.path.exists(pte_path))
            self.assertGreater(os.path.getsize(pte_path), 0)

    def test_e2e_cpp_runner(self):
        """Export, run executor_runner with FLA test inputs, compare with eager."""
        # A missing runner is a failure, not a skip.
        self.assertTrue(
            os.path.exists(RUNNER_PATH),
            f"executor_runner not found at {RUNNER_PATH}. "
            "Build with: cmake --build cmake-out --target executor_runner",
        )
        model = ChunkGatedDeltaModel().eval()

        with tempfile.TemporaryDirectory() as tmpdir:
            # Export once; every configuration reuses the same program.
            export_dir = os.path.join(tmpdir, "export")
            pte_path = export_chunk_gated_delta(export_dir)
            ptd_path = os.path.join(export_dir, "aoti_cuda_blob.ptd")

            for seed, norm, mask_p, nonzero_h0, desc in FLA_TEST_CONFIGS:
                with self.subTest(desc=desc):
                    inputs = _make_inputs_from_fla(seed, norm, mask_p, nonzero_h0)
                    q, k, v, g, beta, h0 = inputs

                    with torch.no_grad():
                        ref_o, ref_s = model(q, k, v, g, beta, h0)

                    # Each configuration gets its own directory for inputs/outputs.
                    run_dir = os.path.join(tmpdir, f"run_{desc}")
                    os.makedirs(run_dir)

                    input_files = []
                    for i, tensor in enumerate(inputs):
                        path = os.path.join(run_dir, f"{i}.bin")
                        _save_tensor(tensor, path)
                        input_files.append(path)

                    output_base = os.path.join(run_dir, "output")
                    result = _run_cpp_runner(
                        RUNNER_PATH, pte_path, ptd_path, input_files, output_base
                    )
                    self.assertEqual(
                        result.returncode,
                        0,
                        f"{desc}: executor_runner failed:\n{result.stderr}",
                    )

                    # Outputs are read back from "<output_base>-<index>.bin".
                    cpp_o = _load_output(
                        f"{output_base}-0.bin",
                        (B, T, H, V),
                        torch.bfloat16,
                    )
                    cpp_s = _load_output(
                        f"{output_base}-1.bin",
                        (B, H, K, V),
                        torch.float32,
                    )

                    o_diff = (cpp_o.float() - ref_o.cpu().float()).abs().max().item()
                    s_diff = (cpp_s.float() - ref_s.cpu().float()).abs().max().item()
                    # The state tolerance here (0.1) is looser than the output's (0.01).
                    self.assertLess(o_diff, 0.01, f"{desc}: output diff {o_diff}")
                    self.assertLess(s_diff, 0.1, f"{desc}: state diff {s_diff}")
317+
318+
319+
if __name__ == "__main__":
    # With --output-dir: perform a standalone export. Without it: run the
    # unittest suite, forwarding any arguments argparse did not recognize.
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", default=None)
    known, passthrough = cli.parse_known_args()

    if not known.output_dir:
        sys.argv = [sys.argv[0], *passthrough]
        unittest.main()
    else:
        export_chunk_gated_delta(known.output_dir)

backends/cuda/triton/kernels/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,12 @@
99
__all__ = [
1010
"sdpa",
1111
]
12+
13+
# chunk_gated_delta_rule is exported only when its module can be imported;
# an ImportError leaves the package usable with the remaining kernels.
try:
    from executorch.backends.cuda.triton.kernels.chunk_gated_delta_rule import (  # noqa: F401
        chunk_gated_delta_rule,
    )
except ImportError:
    pass
else:
    __all__.append("chunk_gated_delta_rule")

0 commit comments

Comments
 (0)