You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add recurrent gated delta rule custom op for Qwen3.5 attention (pytorch#18088)
## Summary
This PR adds a fused `llama::recurrent_gated_delta_rule` custom op and
wires Qwen3.5 GatedDeltaNet attention to use it instead of the Python
per-token recurrence loop when the op is available.
It also tightens local custom-op loading so we no longer implicitly scan
repo-local `cmake-out*` directories, and adds coverage for
recurrent-state correctness, chunked prefill behavior, and export graph
selection.
## What changed
- added `llama::recurrent_gated_delta_rule` runtime and AOT
registrations
- updated Qwen3.5 GatedDeltaNet attention to use the fused op with
Python fallback preserved
- tightened `custom_ops_aot_lib` discovery:
- default to package-local discovery
- allow explicit override via `EXECUTORCH_CUSTOM_OPS_AOT_LIB`
- removed implicit repo-local `cmake-out*` scanning
- added tests for:
- recurrent op parity vs reference
- `.out` variant behavior
- chunked-state parity vs full-sequence execution
- custom-op vs fallback attention parity
- tiny Qwen3.5 export selecting `llama.recurrent_gated_delta_rule`
## Validation
### Linux CPU-only (aarch64)
Built `custom_ops_aot_lib` successfully and loaded it via
`EXECUTORCH_CUSTOM_OPS_AOT_LIB`.
Passed:
- `pytest
extension/llm/custom_ops/test_update_cache.py::RecurrentGatedDeltaRuleTest
-q`
- `3 passed`
- `pytest examples/models/llama/tests/test_qwen3_5_attention.py -q`
- `7 passed`
- `pytest
examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_tiny_qwen35_export_uses_recurrent_gated_delta_rule
-q`
- `1 passed`
### Real-model CPU validation
On a real `Qwen3.5-0.8B` CPU run, fused recurrence matched the fallback
path on next-token selection with very small logit drift, and improved
eager prefill latency on the tested prompt.
Observed on local CPU validation:
- same next token from fused path vs fallback
- max logit diff on the order of `1e-5`
- eager prefill speedup about `1.6x` on the tested prompt
### Windows note
A local Windows-only FFHT/MSVC workaround was used during development to
keep the local build usable, but that workaround is intentionally
**not** included in this PR.
## Non-goals / separate issues
I did not treat the local `program.fbs` serialization issue as part of
this change.
This branch does not modify `exir/_serialize/*` or `schema/program.fbs`,
and serialization-focused checks passed on both this branch and clean
`main` once the local environment was set up correctly.
A separate end-to-end tiny Qwen3.5 `.pte` export probe hit:
- `RuntimeError: Missing out variants: {'aten::alias'}`
That appears to be a separate pre-existing export issue outside this
change set.
cc @larryliu0820@mergennachin@cccclai@helunwencser@jackzhxng
---------
Co-authored-by: Digant Desai <digantdesai@meta.com>
Co-authored-by: Nikhil Viswanath Sivakumar <68182521+nil-is-all@users.noreply.github.com>
0 commit comments