
Commit 476a7ef

Phineas1500, digantdesai, and nil-is-all authored
Add recurrent gated delta rule custom op for Qwen3.5 attention (pytorch#18088)
## Summary

This PR adds a fused `llama::recurrent_gated_delta_rule` custom op and wires Qwen3.5 GatedDeltaNet attention to use it instead of the Python per-token recurrence loop when the op is available. It also tightens local custom-op loading so we no longer implicitly scan repo-local `cmake-out*` directories, and adds coverage for recurrent-state correctness, chunked prefill behavior, and export graph selection.

## What changed

- added `llama::recurrent_gated_delta_rule` runtime and AOT registrations
- updated Qwen3.5 GatedDeltaNet attention to use the fused op with Python fallback preserved
- tightened `custom_ops_aot_lib` discovery:
  - default to package-local discovery
  - allow explicit override via `EXECUTORCH_CUSTOM_OPS_AOT_LIB`
  - removed implicit repo-local `cmake-out*` scanning
- added tests for:
  - recurrent op parity vs reference
  - `.out` variant behavior
  - chunked-state parity vs full-sequence execution
  - custom-op vs fallback attention parity
  - tiny Qwen3.5 export selecting `llama.recurrent_gated_delta_rule`

## Validation

### Linux CPU-only (aarch64)

Built `custom_ops_aot_lib` successfully and loaded it via `EXECUTORCH_CUSTOM_OPS_AOT_LIB`. Passed:

- `pytest extension/llm/custom_ops/test_update_cache.py::RecurrentGatedDeltaRuleTest -q`
  - `3 passed`
- `pytest examples/models/llama/tests/test_qwen3_5_attention.py -q`
  - `7 passed`
- `pytest examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_tiny_qwen35_export_uses_recurrent_gated_delta_rule -q`
  - `1 passed`

### Real-model CPU validation

On a real `Qwen3.5-0.8B` CPU run, the fused recurrence matched the fallback path on next-token selection with very small logit drift, and improved eager prefill latency on the tested prompt. Observed on local CPU validation:

- same next token from fused path vs fallback
- max logit diff on the order of `1e-5`
- eager prefill speedup of about `1.6x` on the tested prompt

### Windows note

A local Windows-only FFHT/MSVC workaround was used during development to keep the local build usable, but that workaround is intentionally **not** included in this PR.

## Non-goals / separate issues

I did not treat the local `program.fbs` serialization issue as part of this change. This branch does not modify `exir/_serialize/*` or `schema/program.fbs`, and serialization-focused checks passed on both this branch and clean `main` once the local environment was set up correctly.

A separate end-to-end tiny Qwen3.5 `.pte` export probe hit:

- `RuntimeError: Missing out variants: {'aten::alias'}`

That appears to be a separate pre-existing export issue outside this change set.

cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng

---------

Co-authored-by: Digant Desai <digantdesai@meta.com>
Co-authored-by: Nikhil Viswanath Sivakumar <68182521+nil-is-all@users.noreply.github.com>
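For context on the recurrence that the fused op replaces, the sketch below follows the standard gated delta rule update: the per-head state decays by the gate, receives a delta-rule correction scaled by `beta`, and is read with the query. This is an illustrative reference only; the function name, shapes, and the exact gating convention are assumptions, not code from this PR.

```python
import torch


def gated_delta_rule_reference(query, key, value, g, beta, state):
    # Illustrative per-token reference loop (not the PR's implementation).
    # query/key: (batch, heads, seq, k_dim); value: (batch, heads, seq, v_dim)
    # g/beta: (batch, heads, seq); state: (batch, heads, k_dim, v_dim)
    out = torch.zeros_like(value)
    for t in range(query.shape[2]):
        q_t, k_t, v_t = query[:, :, t], key[:, :, t], value[:, :, t]
        # Decay the running state by the (log-space) gate.
        state = state * g[:, :, t].exp()[..., None, None]
        # Delta rule: move the state's prediction for k_t toward v_t, scaled by beta.
        predicted = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - predicted) * beta[:, :, t].unsqueeze(-1)
        state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        # Read the updated state with the query.
        out[:, :, t] = (state * q_t.unsqueeze(-1)).sum(dim=-2)
    return out, state
```

The fused op replaces this loop with a single call that takes the recurrent-state buffer as an argument, which is what the chunked-prefill and parity tests in this PR exercise.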
1 parent 98a1d66 commit 476a7ef

11 files changed

Lines changed: 1227 additions & 94 deletions

examples/models/llama/attention.py

Lines changed: 93 additions & 13 deletions
@@ -1,3 +1,4 @@
+import logging
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import Any, Dict, Optional, Tuple, Type, TypedDict
@@ -52,6 +53,8 @@ def forward(


 ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}
+_RECURRENT_GATED_DELTA_RULE_OP = None
+_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False


 def register_attention(name: str):
@@ -64,6 +67,38 @@ def decorator(cls: Type[Attention]):
     return decorator


+def _get_recurrent_gated_delta_rule_op():
+    global _RECURRENT_GATED_DELTA_RULE_OP
+    global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
+
+    if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP:
+        return _RECURRENT_GATED_DELTA_RULE_OP
+
+    _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
+    try:
+        _RECURRENT_GATED_DELTA_RULE_OP = (
+            torch.ops.llama.recurrent_gated_delta_rule.default
+        )
+        return _RECURRENT_GATED_DELTA_RULE_OP
+    except (AttributeError, RuntimeError):
+        pass
+
+    try:
+        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+    except (ImportError, OSError, RuntimeError):
+        logging.debug("Failed to import custom ops library", exc_info=True)
+        return None
+
+    try:
+        _RECURRENT_GATED_DELTA_RULE_OP = (
+            torch.ops.llama.recurrent_gated_delta_rule.default
+        )
+    except (AttributeError, RuntimeError):
+        _RECURRENT_GATED_DELTA_RULE_OP = None
+
+    return _RECURRENT_GATED_DELTA_RULE_OP
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
@@ -725,28 +760,43 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
         out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
         return out.transpose(1, 2).contiguous()

-    def _recurrent_gated_delta_rule(
+    def _gated_delta_rule_op(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         g: torch.Tensor,
         beta: torch.Tensor,
     ) -> torch.Tensor:
-        # query/key/value: (batch, seq_len, num_heads, head_dim)
-        # g/beta: (batch, seq_len, num_heads)
-        initial_dtype = query.dtype
-        query = _l2norm(query, dim=-1, eps=1e-6)
-        key = _l2norm(key, dim=-1, eps=1e-6)
-        query, key, value, beta, g = [
-            x.transpose(1, 2).contiguous().to(torch.float32)
-            for x in (query, key, value, beta, g)
-        ]
+        batch_size = query.shape[0]
+        recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op()
+        if recurrent_gated_delta_rule_op is not None:
+            return recurrent_gated_delta_rule_op(
+                query,
+                key,
+                value,
+                g,
+                beta,
+                self.recurrent_state[:batch_size],
+            )
+        return self._naive_gated_delta_rule_op(
+            query,
+            key,
+            value,
+            g,
+            beta,
+        )

-        batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    def _naive_gated_delta_rule_op(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+    ) -> torch.Tensor:
+        batch_size, num_heads, sequence_length, _ = key.shape
         v_head_dim = value.shape[-1]
-        scale = 1.0 / (query.shape[-1] ** 0.5)
-        query = query * scale

         core_attn_out = torch.zeros(
             batch_size,
@@ -780,6 +830,36 @@ def _recurrent_gated_delta_rule(
             last_recurrent_state.to(self.recurrent_state.dtype)
         )

+        return core_attn_out
+
+    def _recurrent_gated_delta_rule(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+    ) -> torch.Tensor:
+        # query/key/value: (batch, seq_len, num_heads, head_dim)
+        # g/beta: (batch, seq_len, num_heads)
+        initial_dtype = query.dtype
+        query = _l2norm(query, dim=-1, eps=1e-6)
+        key = _l2norm(key, dim=-1, eps=1e-6)
+        query, key, value, beta, g = [
+            x.transpose(1, 2).contiguous().to(torch.float32)
+            for x in (query, key, value, beta, g)
+        ]
+
+        scale = 1.0 / (query.shape[-1] ** 0.5)
+        query = query * scale
+
+        core_attn_out = self._gated_delta_rule_op(
+            query,
+            key,
+            value,
+            g,
+            beta,
+        )
         return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

     def forward(
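A small usage sketch of the lookup above: with the tightened discovery, a locally built AOT library can be made visible through the `EXECUTORCH_CUSTOM_OPS_AOT_LIB` override mentioned in the summary, and the same `torch.ops` probe used in `_get_recurrent_gated_delta_rule_op` indicates whether the fused path or the Python fallback will run. The library path and the print statements here are illustrative, not part of this change.

```python
import os

# Illustrative path: point discovery at a locally built AOT library before the
# custom_ops module is imported (only needed when relying on the env override).
os.environ["EXECUTORCH_CUSTOM_OPS_AOT_LIB"] = "/path/to/libcustom_ops_aot_lib.so"

import torch
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

try:
    # Same probe as the lazy loader in attention.py.
    fused_op = torch.ops.llama.recurrent_gated_delta_rule.default
    print("fused recurrent_gated_delta_rule is registered:", fused_op is not None)
except (AttributeError, RuntimeError):
    print("fused op not registered; GatedDeltaNet attention will use the Python fallback")
```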

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 72 additions & 0 deletions
@@ -5,7 +5,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import json
+import tempfile
 import unittest
+from pathlib import Path

 from executorch.devtools.backend_debug import get_delegation_info

@@ -25,6 +28,7 @@

 from executorch.examples.models.llama.export_llama_lib import (
     _export_llama,
+    _prepare_for_llama_export,
     build_args_parser,
     get_quantizer_and_quant_params,
 )
@@ -37,6 +41,39 @@


 class ExportLlamaLibTest(unittest.TestCase):
+    def _make_tiny_qwen35_params(self) -> dict:
+        return {
+            "dim": 64,
+            "hidden_dim": 128,
+            "n_heads": 4,
+            "head_dim": 16,
+            "n_kv_heads": 2,
+            "n_layers": 4,
+            "norm_eps": 1e-6,
+            "rope_theta": 10000000.0,
+            "use_scaled_rope": False,
+            "vocab_size": 256,
+            "use_hf_rope": True,
+            "partial_rotary_factor": 0.25,
+            "attention_qkv_bias": False,
+            "use_qk_norm": True,
+            "qk_norm_before_rope": True,
+            "attention_type": "mha",
+            "use_q_gate": True,
+            "rms_norm_add_unit_offset": True,
+            "linear_conv_kernel_dim": 4,
+            "linear_key_head_dim": 8,
+            "linear_value_head_dim": 8,
+            "linear_num_key_heads": 4,
+            "linear_num_value_heads": 4,
+            "layer_types": [
+                "linear_attention",
+                "full_attention",
+                "linear_attention",
+                "full_attention",
+            ],
+        }
+
     def test_has_expected_ops_and_op_counts(self):
         """
         Checks the presence of unwanted expensive ops.
@@ -66,6 +103,41 @@ def test_has_expected_ops_and_op_counts(self):
         for op, _op_info in delegation_info.delegation_by_operator.items():
             self.assertTrue(op not in UNWANTED_OPS)

+    def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self):
+        with tempfile.TemporaryDirectory() as temp_dir:
+            params_path = Path(temp_dir) / "tiny_qwen35.json"
+            params_path.write_text(json.dumps(self._make_tiny_qwen35_params()))
+
+            parser = build_args_parser()
+            args = parser.parse_args(
+                [
+                    "--model",
+                    "qwen3_5_0_8b",
+                    "--params",
+                    str(params_path),
+                    "--use_kv_cache",
+                    "--disable_dynamic_shape",
+                    "--max_seq_length",
+                    "8",
+                    "--max_context_length",
+                    "8",
+                ]
+            )
+
+            llm_config = LlmConfig.from_args(args)
+            builder = _prepare_for_llama_export(llm_config).export()
+            assert builder.pre_autograd_graph_module is not None
+
+            recurrent_nodes = [
+                node
+                for node in builder.pre_autograd_graph_module.graph.nodes
+                if "auto_functionalized_v2" in str(node.target)
+                and node.args
+                and "llama.recurrent_gated_delta_rule" in str(node.args[0])
+            ]
+
+            self.assertEqual(len(recurrent_nodes), 2)
+
     @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
     def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
         llm_config = LlmConfig()

examples/models/llama/tests/test_qwen3_5_attention.py

Lines changed: 105 additions & 0 deletions
@@ -6,7 +6,9 @@

 import unittest

+import executorch.examples.models.llama.attention as attention_module
 import torch
+
 from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
123125
torch.allclose(state_after_first, state_after_second, atol=1e-5)
124126
)
125127

128+
def test_gated_deltanet_chunked_prefill_matches_full_sequence(self):
129+
torch.manual_seed(0)
130+
args = self._make_args(
131+
use_kv_cache=True,
132+
use_q_gate=True,
133+
linear_conv_kernel_dim=4,
134+
linear_key_head_dim=4,
135+
linear_value_head_dim=4,
136+
linear_num_key_heads=2,
137+
linear_num_value_heads=4,
138+
)
139+
rope = Rope(args)
140+
attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
141+
attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
142+
attn_chunked.load_state_dict(attn_full.state_dict())
143+
144+
x = torch.randn(1, 5, args.dim)
145+
dummy_freq = torch.zeros(1, 1)
146+
147+
full_output, _ = attn_full(
148+
x,
149+
dummy_freq,
150+
dummy_freq,
151+
input_pos=torch.tensor([0], dtype=torch.long),
152+
)
153+
154+
chunk_outputs = []
155+
for start, end in ((0, 3), (3, 4), (4, 5)):
156+
output, _ = attn_chunked(
157+
x[:, start:end],
158+
dummy_freq,
159+
dummy_freq,
160+
input_pos=torch.tensor([start], dtype=torch.long),
161+
)
162+
chunk_outputs.append(output)
163+
164+
chunked_output = torch.cat(chunk_outputs, dim=1)
165+
166+
self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5))
167+
self.assertTrue(
168+
torch.allclose(
169+
attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5
170+
)
171+
)
172+
self.assertTrue(
173+
torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5)
174+
)
175+
176+
def test_gated_deltanet_custom_op_matches_fallback(self):
177+
recurrent_op = attention_module._get_recurrent_gated_delta_rule_op()
178+
if recurrent_op is None:
179+
self.skipTest("llama::recurrent_gated_delta_rule is not available")
180+
181+
torch.manual_seed(0)
182+
args = self._make_args(
183+
use_kv_cache=True,
184+
use_q_gate=True,
185+
linear_conv_kernel_dim=4,
186+
linear_key_head_dim=4,
187+
linear_value_head_dim=4,
188+
linear_num_key_heads=2,
189+
linear_num_value_heads=4,
190+
)
191+
rope = Rope(args)
192+
attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
193+
attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
194+
attn_fallback.load_state_dict(attn_custom.state_dict())
195+
196+
query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
197+
key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
198+
value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim)
199+
g = torch.randn(1, 3, attn_custom.num_v_heads)
200+
beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads))
201+
202+
original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP
203+
original_tried_loading = (
204+
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
205+
)
206+
try:
207+
attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op
208+
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
209+
custom_output = attn_custom._recurrent_gated_delta_rule(
210+
query, key, value, g, beta
211+
)
212+
213+
attention_module._RECURRENT_GATED_DELTA_RULE_OP = None
214+
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
215+
fallback_output = attn_fallback._recurrent_gated_delta_rule(
216+
query, key, value, g, beta
217+
)
218+
finally:
219+
attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op
220+
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = (
221+
original_tried_loading
222+
)
223+
224+
self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5))
225+
self.assertTrue(
226+
torch.allclose(
227+
attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5
228+
)
229+
)
230+
126231

127232
if __name__ == "__main__":
128233
unittest.main()

extension/aten_util/make_aten_functor_from_et_functor.h

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@
 #pragma once
 #include <type_traits>
 #include <vector>
-#if __cplusplus < 201703L
+#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \
+    (!defined(_MSC_VER) && __cplusplus < 201703L)
 #error "This header requires C++17"
 #endif
 #include <ATen/native/Resize.h>

0 commit comments
