Skip to content

Commit 0852840

Browse files
Revert "Add recurrent gated delta rule custom op for Qwen3.5 attentio… (#19178)
This reverts commit 476a7ef. Broke some stuff internally @digantdesai to reland
1 parent f5c1c7e commit 0852840

11 files changed

Lines changed: 94 additions & 1227 deletions

examples/models/llama/attention.py

Lines changed: 13 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
from abc import ABC, abstractmethod
32
from enum import Enum
43
from typing import Any, Dict, Optional, Tuple, Type, TypedDict
@@ -53,8 +52,6 @@ def forward(
5352

5453

5554
ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}
56-
_RECURRENT_GATED_DELTA_RULE_OP = None
57-
_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False
5855

5956

6057
def register_attention(name: str):
@@ -67,38 +64,6 @@ def decorator(cls: Type[Attention]):
6764
return decorator
6865

6966

70-
def _get_recurrent_gated_delta_rule_op():
71-
global _RECURRENT_GATED_DELTA_RULE_OP
72-
global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
73-
74-
if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP:
75-
return _RECURRENT_GATED_DELTA_RULE_OP
76-
77-
_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
78-
try:
79-
_RECURRENT_GATED_DELTA_RULE_OP = (
80-
torch.ops.llama.recurrent_gated_delta_rule.default
81-
)
82-
return _RECURRENT_GATED_DELTA_RULE_OP
83-
except (AttributeError, RuntimeError):
84-
pass
85-
86-
try:
87-
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401
88-
except (ImportError, OSError, RuntimeError):
89-
logging.debug("Failed to import custom ops library", exc_info=True)
90-
return None
91-
92-
try:
93-
_RECURRENT_GATED_DELTA_RULE_OP = (
94-
torch.ops.llama.recurrent_gated_delta_rule.default
95-
)
96-
except (AttributeError, RuntimeError):
97-
_RECURRENT_GATED_DELTA_RULE_OP = None
98-
99-
return _RECURRENT_GATED_DELTA_RULE_OP
100-
101-
10267
class KVCache(nn.Module):
10368
def __init__(
10469
self,
@@ -760,43 +725,28 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
760725
out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
761726
return out.transpose(1, 2).contiguous()
762727

763-
def _gated_delta_rule_op(
728+
def _recurrent_gated_delta_rule(
764729
self,
765730
query: torch.Tensor,
766731
key: torch.Tensor,
767732
value: torch.Tensor,
768733
g: torch.Tensor,
769734
beta: torch.Tensor,
770735
) -> torch.Tensor:
771-
batch_size = query.shape[0]
772-
recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op()
773-
if recurrent_gated_delta_rule_op is not None:
774-
return recurrent_gated_delta_rule_op(
775-
query,
776-
key,
777-
value,
778-
g,
779-
beta,
780-
self.recurrent_state[:batch_size],
781-
)
782-
return self._naive_gated_delta_rule_op(
783-
query,
784-
key,
785-
value,
786-
g,
787-
beta,
788-
)
736+
# query/key/value: (batch, seq_len, num_heads, head_dim)
737+
# g/beta: (batch, seq_len, num_heads)
738+
initial_dtype = query.dtype
739+
query = _l2norm(query, dim=-1, eps=1e-6)
740+
key = _l2norm(key, dim=-1, eps=1e-6)
741+
query, key, value, beta, g = [
742+
x.transpose(1, 2).contiguous().to(torch.float32)
743+
for x in (query, key, value, beta, g)
744+
]
789745

790-
def _naive_gated_delta_rule_op(
791-
self,
792-
query: torch.Tensor,
793-
key: torch.Tensor,
794-
value: torch.Tensor,
795-
g: torch.Tensor,
796-
beta: torch.Tensor,
797-
) -> torch.Tensor:
798-
batch_size, num_heads, sequence_length, _ = key.shape
746+
batch_size, num_heads, sequence_length, k_head_dim = key.shape
799747
v_head_dim = value.shape[-1]
748+
scale = 1.0 / (query.shape[-1] ** 0.5)
749+
query = query * scale
800750

801751
core_attn_out = torch.zeros(
802752
batch_size,
@@ -830,36 +780,6 @@ def _naive_gated_delta_rule_op(
830780
last_recurrent_state.to(self.recurrent_state.dtype)
831781
)
832782

833-
return core_attn_out
834-
835-
def _recurrent_gated_delta_rule(
836-
self,
837-
query: torch.Tensor,
838-
key: torch.Tensor,
839-
value: torch.Tensor,
840-
g: torch.Tensor,
841-
beta: torch.Tensor,
842-
) -> torch.Tensor:
843-
# query/key/value: (batch, seq_len, num_heads, head_dim)
844-
# g/beta: (batch, seq_len, num_heads)
845-
initial_dtype = query.dtype
846-
query = _l2norm(query, dim=-1, eps=1e-6)
847-
key = _l2norm(key, dim=-1, eps=1e-6)
848-
query, key, value, beta, g = [
849-
x.transpose(1, 2).contiguous().to(torch.float32)
850-
for x in (query, key, value, beta, g)
851-
]
852-
853-
scale = 1.0 / (query.shape[-1] ** 0.5)
854-
query = query * scale
855-
856-
core_attn_out = self._gated_delta_rule_op(
857-
query,
858-
key,
859-
value,
860-
g,
861-
beta,
862-
)
863783
return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
864784

865785
def forward(

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import json
9-
import tempfile
108
import unittest
11-
from pathlib import Path
129

1310
from executorch.devtools.backend_debug import get_delegation_info
1411

@@ -28,7 +25,6 @@
2825

2926
from executorch.examples.models.llama.export_llama_lib import (
3027
_export_llama,
31-
_prepare_for_llama_export,
3228
build_args_parser,
3329
get_quantizer_and_quant_params,
3430
)
@@ -41,39 +37,6 @@
4137

4238

4339
class ExportLlamaLibTest(unittest.TestCase):
44-
def _make_tiny_qwen35_params(self) -> dict:
45-
return {
46-
"dim": 64,
47-
"hidden_dim": 128,
48-
"n_heads": 4,
49-
"head_dim": 16,
50-
"n_kv_heads": 2,
51-
"n_layers": 4,
52-
"norm_eps": 1e-6,
53-
"rope_theta": 10000000.0,
54-
"use_scaled_rope": False,
55-
"vocab_size": 256,
56-
"use_hf_rope": True,
57-
"partial_rotary_factor": 0.25,
58-
"attention_qkv_bias": False,
59-
"use_qk_norm": True,
60-
"qk_norm_before_rope": True,
61-
"attention_type": "mha",
62-
"use_q_gate": True,
63-
"rms_norm_add_unit_offset": True,
64-
"linear_conv_kernel_dim": 4,
65-
"linear_key_head_dim": 8,
66-
"linear_value_head_dim": 8,
67-
"linear_num_key_heads": 4,
68-
"linear_num_value_heads": 4,
69-
"layer_types": [
70-
"linear_attention",
71-
"full_attention",
72-
"linear_attention",
73-
"full_attention",
74-
],
75-
}
76-
7740
def test_has_expected_ops_and_op_counts(self):
7841
"""
7942
Checks the presence of unwanted expensive ops.
@@ -103,41 +66,6 @@ def test_has_expected_ops_and_op_counts(self):
10366
for op, _op_info in delegation_info.delegation_by_operator.items():
10467
self.assertTrue(op not in UNWANTED_OPS)
10568

106-
def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self):
107-
with tempfile.TemporaryDirectory() as temp_dir:
108-
params_path = Path(temp_dir) / "tiny_qwen35.json"
109-
params_path.write_text(json.dumps(self._make_tiny_qwen35_params()))
110-
111-
parser = build_args_parser()
112-
args = parser.parse_args(
113-
[
114-
"--model",
115-
"qwen3_5_0_8b",
116-
"--params",
117-
str(params_path),
118-
"--use_kv_cache",
119-
"--disable_dynamic_shape",
120-
"--max_seq_length",
121-
"8",
122-
"--max_context_length",
123-
"8",
124-
]
125-
)
126-
127-
llm_config = LlmConfig.from_args(args)
128-
builder = _prepare_for_llama_export(llm_config).export()
129-
assert builder.pre_autograd_graph_module is not None
130-
131-
recurrent_nodes = [
132-
node
133-
for node in builder.pre_autograd_graph_module.graph.nodes
134-
if "auto_functionalized_v2" in str(node.target)
135-
and node.args
136-
and "llama.recurrent_gated_delta_rule" in str(node.args[0])
137-
]
138-
139-
self.assertEqual(len(recurrent_nodes), 2)
140-
14169
@unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
14270
def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
14371
llm_config = LlmConfig()

examples/models/llama/tests/test_qwen3_5_attention.py

Lines changed: 0 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66

77
import unittest
88

9-
import executorch.examples.models.llama.attention as attention_module
109
import torch
11-
1210
from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
1311
from executorch.examples.models.llama.model_args import ModelArgs
1412
from executorch.examples.models.llama.norm import RMSNorm
@@ -125,109 +123,6 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self):
125123
torch.allclose(state_after_first, state_after_second, atol=1e-5)
126124
)
127125

128-
def test_gated_deltanet_chunked_prefill_matches_full_sequence(self):
129-
torch.manual_seed(0)
130-
args = self._make_args(
131-
use_kv_cache=True,
132-
use_q_gate=True,
133-
linear_conv_kernel_dim=4,
134-
linear_key_head_dim=4,
135-
linear_value_head_dim=4,
136-
linear_num_key_heads=2,
137-
linear_num_value_heads=4,
138-
)
139-
rope = Rope(args)
140-
attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
141-
attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
142-
attn_chunked.load_state_dict(attn_full.state_dict())
143-
144-
x = torch.randn(1, 5, args.dim)
145-
dummy_freq = torch.zeros(1, 1)
146-
147-
full_output, _ = attn_full(
148-
x,
149-
dummy_freq,
150-
dummy_freq,
151-
input_pos=torch.tensor([0], dtype=torch.long),
152-
)
153-
154-
chunk_outputs = []
155-
for start, end in ((0, 3), (3, 4), (4, 5)):
156-
output, _ = attn_chunked(
157-
x[:, start:end],
158-
dummy_freq,
159-
dummy_freq,
160-
input_pos=torch.tensor([start], dtype=torch.long),
161-
)
162-
chunk_outputs.append(output)
163-
164-
chunked_output = torch.cat(chunk_outputs, dim=1)
165-
166-
self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5))
167-
self.assertTrue(
168-
torch.allclose(
169-
attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5
170-
)
171-
)
172-
self.assertTrue(
173-
torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5)
174-
)
175-
176-
def test_gated_deltanet_custom_op_matches_fallback(self):
177-
recurrent_op = attention_module._get_recurrent_gated_delta_rule_op()
178-
if recurrent_op is None:
179-
self.skipTest("llama::recurrent_gated_delta_rule is not available")
180-
181-
torch.manual_seed(0)
182-
args = self._make_args(
183-
use_kv_cache=True,
184-
use_q_gate=True,
185-
linear_conv_kernel_dim=4,
186-
linear_key_head_dim=4,
187-
linear_value_head_dim=4,
188-
linear_num_key_heads=2,
189-
linear_num_value_heads=4,
190-
)
191-
rope = Rope(args)
192-
attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
193-
attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
194-
attn_fallback.load_state_dict(attn_custom.state_dict())
195-
196-
query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
197-
key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
198-
value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim)
199-
g = torch.randn(1, 3, attn_custom.num_v_heads)
200-
beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads))
201-
202-
original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP
203-
original_tried_loading = (
204-
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
205-
)
206-
try:
207-
attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op
208-
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
209-
custom_output = attn_custom._recurrent_gated_delta_rule(
210-
query, key, value, g, beta
211-
)
212-
213-
attention_module._RECURRENT_GATED_DELTA_RULE_OP = None
214-
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
215-
fallback_output = attn_fallback._recurrent_gated_delta_rule(
216-
query, key, value, g, beta
217-
)
218-
finally:
219-
attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op
220-
attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = (
221-
original_tried_loading
222-
)
223-
224-
self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5))
225-
self.assertTrue(
226-
torch.allclose(
227-
attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5
228-
)
229-
)
230-
231126

232127
if __name__ == "__main__":
233128
unittest.main()

extension/aten_util/make_aten_functor_from_et_functor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
#pragma once
1616
#include <type_traits>
1717
#include <vector>
18-
#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \
19-
(!defined(_MSC_VER) && __cplusplus < 201703L)
18+
#if __cplusplus < 201703L
2019
#error "This header requires C++17"
2120
#endif
2221
#include <ATen/native/Resize.h>

0 commit comments

Comments
 (0)