Add recurrent gated delta rule custom op for Qwen3.5 attention#18088
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18088

Note: links to docs will display an error until the docs builds have completed.

❗ There is 1 currently active SEV. If your PR is affected, please view it below.

✅ You can merge normally (2 unrelated failures). As of commit 7898600 with merge base f9f29e7: BROKEN TRUNK, the following jobs failed but were already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
Adds a fused llama::recurrent_gated_delta_rule custom op and integrates it into Qwen3.5 GatedDeltaNet attention to avoid the Python per-token recurrence loop when the op is available, along with tighter custom-op library discovery/loading and new test coverage.
Changes:
- Implemented and registered `llama::recurrent_gated_delta_rule` (runtime kernel + ATen/AOT registrations) and updated attention to use it with a fallback path.
- Refined `custom_ops_aot_lib` discovery/loading (package-local by default, optional `EXECUTORCH_CUSTOM_OPS_AOT_LIB` override).
- Added tests for recurrent-state correctness/parity, chunked prefill behavior, and export graph op selection.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| extension/llm/custom_ops/test_update_cache.py | Adds unit tests for recurrent gated delta rule correctness, .out behavior, and chunking parity. |
| extension/llm/custom_ops/op_tile_crop_aot.cpp | Replaces WRAP_TO_ATEN usage with explicit ET↔ATen conversion helpers for .out. |
| extension/llm/custom_ops/op_sdpa_aot.cpp | Adds ATen bindings for recurrent op; refactors multiple .out wrappers to explicit conversions. |
| extension/llm/custom_ops/op_sdpa.h | Declares the new recurrent_gated_delta_rule_out kernel signature. |
| extension/llm/custom_ops/op_sdpa.cpp | Implements recurrent kernel logic and registers the ExecuTorch kernel. |
| extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp | Refactors .out binding to explicit ET↔ATen conversion helpers. |
| extension/llm/custom_ops/custom_ops.py | Tightens custom op library discovery/loading; adds meta impl for recurrent op. |
| extension/llm/custom_ops/CMakeLists.txt | Adds MSVC /Zc:__cplusplus compile option. |
| examples/models/llama/tests/test_qwen3_5_attention.py | Adds chunked prefill parity + fused-op vs fallback parity tests. |
| examples/models/llama/tests/test_export_llama_lib.py | Adds tiny Qwen3.5 export test asserting recurrent op selection in graph. |
| examples/models/llama/attention.py | Adds lazy lookup/loading for fused recurrent op and uses it when available. |
```python
def _get_custom_ops_library_override() -> Path | None:
    override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB")
    if override is None:
        return None

    lib_path = Path(override).expanduser().resolve()
    assert lib_path.is_file(), (
        "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing "
        f"custom_ops_aot_lib, but got {lib_path}"
    )
    return lib_path


def _find_custom_ops_library() -> Path:
    override = _get_custom_ops_library_override()
    if override is not None:
        return override

    package_path = Path(__file__).parent.resolve()
    logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}")
    candidates = []
    patterns = (
        "**/custom_ops_aot_lib.dll",
        "**/custom_ops_aot_lib.so",
        "**/custom_ops_aot_lib.dylib",
    )

    for pattern in patterns:
        candidates.extend(package_path.glob(pattern))

    libs = sorted({path.resolve() for path in candidates if path.is_file()})
    assert libs, f"Could not find custom_ops_aot_lib under {package_path}"
    return max(libs, key=lambda path: path.stat().st_mtime)
```
Avoid using assert for runtime validation of EXECUTORCH_CUSTOM_OPS_AOT_LIB / library discovery. Assertions can be stripped with python -O, turning these into silent misconfigurations; raise a ValueError/FileNotFoundError with the same message instead.
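A minimal sketch of that suggestion, with an illustrative helper name (not the exact patch):

```python
from pathlib import Path


def _validated_override(override: str) -> Path:
    # Raise an explicit exception instead of assert, so `python -O`
    # cannot strip the validation.
    lib_path = Path(override).expanduser().resolve()
    if not lib_path.is_file():
        raise FileNotFoundError(
            "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing "
            f"custom_ops_aot_lib, but got {lib_path}"
        )
    return lib_path
```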
```python
try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except Exception:
    return None
```
_get_recurrent_gated_delta_rule_op() swallows all exceptions when importing executorch.extension.llm.custom_ops.custom_ops. Catching broad Exception can hide real load/link errors and make debugging difficult; consider narrowing to ImportError/OSError (or logging the exception at debug level) so unexpected failures surface.
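A sketch of that narrowing, with the swallowed exception logged at debug level (the helper name is illustrative, not the PR's code):

```python
import logging

logger = logging.getLogger(__name__)


def _try_import_custom_ops() -> bool:
    # Catch only the failure modes expected when importing/loading the
    # custom-op library, and log them so real errors stay debuggable.
    try:
        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
    except (ImportError, OSError) as e:
        logger.debug("custom_ops import failed: %s", e)
        return False
    return True
```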
```cpp
std::vector<float> kv_mem(v_head_dim);
std::vector<float> delta(v_head_dim);
```
recurrent_gated_delta_rule_out allocates std::vector buffers (kv_mem, delta) inside the per-(batch, head) loop. For long sequences / many heads this adds repeated heap allocations and can dominate runtime; allocate these buffers once per call (or reuse a scratch buffer) and resize as needed, or use stack/arena allocation when sizes are small.
```cpp
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}
```
The to_et_arg / copy_et_result_to_out helpers are duplicated here and in other custom-op AOT wrappers (tile_crop / sdpa / fast_hadamard_transform). Consider factoring them into a shared utility header to reduce copy-paste and keep conversion semantics consistent across ops.
```python
if os.name == "nt":
    os.add_dll_directory(str(lib_path.parent))
    torch_lib_dir = Path(torch.__file__).resolve().parent / "lib"
    if torch_lib_dir.is_dir():
        os.add_dll_directory(str(torch_lib_dir))
```
On Windows, os.add_dll_directory() returns a handle that must be kept alive; otherwise the directory is removed immediately (CPython refcounting), which can cause torch.ops.load_library() to fail to resolve dependent DLLs. Store the returned handles (e.g., in a module-level list) at least through the load (and ideally for process lifetime).
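A minimal sketch of the suggested fix, assuming a module-level holder (names illustrative):

```python
import os
from pathlib import Path

import torch

# Keep the handles returned by os.add_dll_directory alive for the process
# lifetime; if they are garbage-collected, Windows stops searching those
# directories and dependent DLLs may fail to resolve.
_DLL_DIR_HANDLES: list = []


def _add_windows_dll_dirs(lib_path: Path) -> None:
    if os.name != "nt":
        return
    _DLL_DIR_HANDLES.append(os.add_dll_directory(str(lib_path.parent)))
    torch_lib_dir = Path(torch.__file__).resolve().parent / "lib"
    if torch_lib_dir.is_dir():
        _DLL_DIR_HANDLES.append(os.add_dll_directory(str(torch_lib_dir)))
```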
@lucylq this is a PR for optimizing the recurrence in Qwen 3.5, which we discussed here: #17801 (comment). I'm next going to make a PR for quantization; I'll let you know once that's up.
```cpp
m.impl("tile_crop", torch::executor::native::tile_crop_aten);
m.impl(
    "tile_crop.out",
    WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2));
```
what problem did you run into with WRAP_TO_ATEN?
MSVC was failing to compile the affected .out wrappers when building custom_ops_aot_lib through the WRAP_TO_ATEN instantiation path. I rewrote those wrappers with explicit conversions, and it now compiles cleanly on Linux and Windows.
Can you tell me a bit more about the serialization issue you ran into, as well as the MSVC one?
```python
    )
    return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

core_attn_out = torch.zeros(
```
Can you put this logic in a function called something like "naive_gated_delta_rule_op", and have the if statement switch between them to tidy this function up a bit?
Fixed it so that _recurrent_gated_delta_rule() switches between _gated_delta_rule_op() and _naive_gated_delta_rule_op().
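For reference, a minimal sketch of the per-token recurrence both paths compute, under assumed per-head shapes and an exp-gated decay (function and argument names are illustrative, not the PR's actual helpers):

```python
import torch


def naive_gated_delta_rule(q, k, v, g, beta, state):
    """q, k: [seq, k_dim]; v: [seq, v_dim]; g, beta: [seq];
    state: [k_dim, v_dim] recurrent memory carried across tokens."""
    outputs = []
    for t in range(q.shape[0]):
        state = state * g[t].exp()                # gated decay of the memory
        kv_mem = state.T @ k[t]                   # memory's prediction of v_t
        delta = (v[t] - kv_mem) * beta[t]         # prediction error, scaled
        state = state + torch.outer(k[t], delta)  # rank-1 delta-rule write
        outputs.append(state.T @ q[t])            # read out with the query
    return torch.stack(outputs), state
```

The `kv_mem` and `delta` names mirror the scratch buffers in the C++ kernel shown earlier in this thread.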
```cmake
set(_common_compile_options
    $<$<CXX_COMPILER_ID:MSVC>:/wd4996>
    $<$<CXX_COMPILER_ID:MSVC>:/Zc:__cplusplus>
```
What codepath are you going down that isn't triggering properly without this? Typically the c10 pattern is to just have explicit MSVC conditions and not rely on the C++ version on Windows, IIRC. I could be wrong on that though.
You're right, I didn't need it. I changed make_aten_functor_from_et_functor.h to check for MSVC explicitly instead, and now it works without /Zc:__cplusplus.
@pytorchbot label "release notes: examples"
I believe the serialization issue is separate from this PR, as I reproduced it both here and on main. _prepare_schema() wants program.fbs to be under exir/_serialize, but in a normal source checkout it's under schema/. The error is that it can't find program.fbs.

I had two separate local build issues on MSVC: … I also dropped the …

Happy to answer more questions @JacobSzwejbka
Taking a look.
```cpp
std::vector<float> kv_mem(v_head_dim);
std::vector<float> delta(v_head_dim);

for (int64_t batch = 0; batch < batch_size; ++batch) {
```
This is a naive implementation in terms of performance, with all the loops etc. Do you have plans to improve it, or to add a new version and leave this as a reference C++ implementation?
digantdesai left a comment:
Ok with merging this as is. Love the tests. Thanks.
Let me try to merge this change.
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
```python
try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except (ImportError, OSError):
```
Hi @Phineas1500, could you check the lintrunner logs and resolve the issues brought up? Once they're fixed, it should be good to merge.
Will do @nil-is-all!
@digantdesai I made a commit addressing the linter issues. It should hopefully be good to merge after the tests run.
Apologies if I clicked the wrong button. Could a maintainer approve the twelve workflows?
Hi @Phineas1500, approved the workflows to run. Could you take a look at the lintrunner error logs and resolve them?
@nil-is-all resolved the issues from the lintrunner logs |
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.
```python
    raise FileNotFoundError(
        f"Could not find custom_ops_aot_lib under {package_path}"
    )
return max(libs, key=lambda path: path.stat().st_mtime)
```
_find_custom_ops_library() collects all matching custom_ops_aot_lib binaries under the package dir and then picks the newest by mtime. If multiple candidates exist (e.g., debug/release builds or stale artifacts), this can silently load the wrong library and make behavior depend on filesystem timestamps. Consider failing fast when >1 candidate is found (and listing them), unless the user provides an explicit override via EXECUTORCH_CUSTOM_OPS_AOT_LIB.
Suggested change (replace the mtime-based pick):

```python
if len(libs) > 1:
    discovered = "\n - " + "\n - ".join(str(p) for p in libs)
    raise RuntimeError(
        "Found multiple candidate custom_ops_aot_lib libraries under "
        f"{package_path}:\n{discovered}\n"
        "Please set the EXECUTORCH_CUSTOM_OPS_AOT_LIB environment variable "
        "to the desired library to disambiguate."
    )
return libs[0]
```
```cpp
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}
```
This file introduces local helpers (to_et_arg / copy_et_result_to_out) that duplicate the conversion+resize+copy logic already provided by WRAP_TO_ATEN in executorch/extension/aten_util/make_aten_functor_from_et_functor.h. Keeping multiple bespoke wrappers across ops increases maintenance cost and the risk of subtle divergence. Prefer using WRAP_TO_ATEN again, or moving these helpers into a shared utility used by all the AOT wrapper files.
```cpp
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}
```
This file reimplements argument conversion and out-copying helpers (to_et_arg / copy_et_result_to_out) that largely mirror WRAP_TO_ATEN’s wrapper_impl (type_convert + resize_output + copy_). Since the same helpers are also added in other custom-op AOT files, consider using WRAP_TO_ATEN directly or centralizing the helpers to avoid wrapper logic drifting across files.
```cpp
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}
```
The new to_et_arg / copy_et_result_to_out helpers duplicate the WRAP_TO_ATEN wrapper logic from executorch/extension/aten_util/make_aten_functor_from_et_functor.h, and similar copies exist in other custom-op wrapper files. Consider switching back to WRAP_TO_ATEN or sharing a single helper to reduce duplicated conversion/copy code paths.
@nil-is-all do you think it's good to merge? I don't believe the current errors are related to this PR's code.
@digantdesai @mergennachin Is there anything else you think I should do? After this PR is merged, I'm thinking of making PRs adding q8da4w support, dynamic-shape export for DeltaNet, and better multi-token prefill.
Hey, sorry; this is good to go. Let me run Claude through to catch nit issues. If something major comes up we should fix it, otherwise this is good.
@claude Review this PR
Claude finished @digantdesai's task in 1m 54s.

PR Review: Add recurrent gated delta rule custom op for Qwen3.5 attention

Overall this is a well-structured PR with good test coverage. The kernel implementation is correct, the Python fallback path is cleanly preserved, and the test suite covers key correctness scenarios. Below are findings organized by severity.

Issues

1. Heap allocation inside hot loop (C++ kernel). This was flagged by Copilot and acknowledged by @digantdesai as a known future improvement, so not blocking, but worth noting.

2. Duplicated …
digantdesai left a comment:
Thanks, and sorry for the delay. I will try to merge it.
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
```python
try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except (ImportError, OSError, RuntimeError):
```
_get_recurrent_gated_delta_rule_op() attempts to import executorch.extension.llm.custom_ops.custom_ops as a best-effort fallback, but it doesn't catch FileNotFoundError. custom_ops.py can raise FileNotFoundError when custom_ops_aot_lib isn't present, which would crash attention initialization instead of cleanly falling back to the Python implementation. Consider catching FileNotFoundError here (or making custom_ops.py raise a RuntimeError that is already handled).
Suggested change:

```python
except (ImportError, FileNotFoundError, OSError, RuntimeError):
```
```python
_ = portable_lib

if not _is_custom_ops_registered():
```
If EXECUTORCH_CUSTOM_OPS_AOT_LIB is set, the intent is to allow an explicit override for loading custom_ops_aot_lib. Right now _load_custom_ops_library() only runs when _is_custom_ops_registered() is false, so an override can be ignored in environments where some ops were already registered (e.g., via another shared library), preventing the override from taking effect. Consider changing the top-level guard to attempt _load_custom_ops_library() when the override env var is set (or at least when the specific expected ops like recurrent_gated_delta_rule are missing).
Suggested change (replace the registration guard):

```python
_custom_ops_library_override = _get_custom_ops_library_override()
if _custom_ops_library_override is not None or not _is_custom_ops_registered():
```
```cpp
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}
```
This file now uses std::forward, but it doesn't include <utility>. Relying on transitive includes can make builds brittle across toolchains; add an explicit #include <utility> to ensure std::forward is always available.
```cpp
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}
```
This file now uses std::forward, but it doesn't include <utility>. Relying on transitive includes can make builds brittle across toolchains; add an explicit #include <utility> to ensure std::forward is always available.
```cpp
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}
```
This file now uses std::forward, but it doesn't include <utility>. Relying on transitive includes can make builds brittle across toolchains; add an explicit #include <utility> to ensure std::forward is always available.
Summary
This PR adds a fused `llama::recurrent_gated_delta_rule` custom op and wires Qwen3.5 GatedDeltaNet attention to use it instead of the Python per-token recurrence loop when the op is available. It also tightens local custom-op loading so we no longer implicitly scan repo-local `cmake-out*` directories, and adds coverage for recurrent-state correctness, chunked prefill behavior, and export graph selection.

What changed

- Added the `llama::recurrent_gated_delta_rule` runtime kernel and AOT registrations.
- Tightened `custom_ops_aot_lib` discovery: package-local by default, with an optional `EXECUTORCH_CUSTOM_OPS_AOT_LIB` override and no implicit `cmake-out*` scanning (see the usage sketch below).
- Added tests covering `.out` variant behavior and selection of `llama.recurrent_gated_delta_rule` in the export graph.
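A minimal sketch of the override mechanism (the env var and import path are from this PR; the library path is hypothetical):

```python
import os

# Point discovery at an explicit library instead of the package-local scan.
os.environ["EXECUTORCH_CUSTOM_OPS_AOT_LIB"] = "/abs/path/to/custom_ops_aot_lib.so"

# Importing the module loads the library and registers the custom ops,
# making torch.ops.llama.recurrent_gated_delta_rule resolvable.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
```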
Validation

Linux CPU-only (aarch64)

Built `custom_ops_aot_lib` successfully and loaded it via `EXECUTORCH_CUSTOM_OPS_AOT_LIB`. Passed:

- `pytest extension/llm/custom_ops/test_update_cache.py::RecurrentGatedDeltaRuleTest -q`: 3 passed
- `pytest examples/models/llama/tests/test_qwen3_5_attention.py -q`: 7 passed
- `pytest examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_tiny_qwen35_export_uses_recurrent_gated_delta_rule -q`: 1 passed

Real-model CPU validation
On a real `Qwen3.5-0.8B` CPU run, fused recurrence matched the fallback path on next-token selection with very small logit drift, and improved eager prefill latency on the tested prompt. Observed on local CPU validation:
- logit drift within about `1e-5`
- roughly `1.6x` faster eager prefill on the tested prompt
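A minimal sketch of this kind of parity probe, assuming a hypothetical `run_model` helper that returns logits with the fused op toggled on or off (not the PR's actual validation script):

```python
import torch


def check_parity(run_model, prompt_tokens):
    logits_fused = run_model(prompt_tokens, use_fused_op=True)
    logits_fallback = run_model(prompt_tokens, use_fused_op=False)
    # Maximum absolute logit drift between the two paths.
    drift = (logits_fused - logits_fallback).abs().max().item()
    # Next-token parity: both paths should pick the same argmax token.
    same_next = torch.equal(
        logits_fused[..., -1, :].argmax(dim=-1),
        logits_fallback[..., -1, :].argmax(dim=-1),
    )
    return drift, same_next
```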
Windows note

A local Windows-only FFHT/MSVC workaround was used during development to keep the local build usable, but that workaround is intentionally not included in this PR.
Non-goals / separate issues
I did not treat the local `program.fbs` serialization issue as part of this change. This branch does not modify `exir/_serialize/*` or `schema/program.fbs`, and serialization-focused checks passed on both this branch and clean `main` once the local environment was set up correctly.

A separate end-to-end tiny Qwen3.5 `.pte` export probe hit:

RuntimeError: Missing out variants: {'aten::alias'}

That appears to be a separate pre-existing export issue outside this change set.
cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng