
Add recurrent gated delta rule custom op for Qwen3.5 attention#18088

Merged
digantdesai merged 9 commits into pytorch:main from Phineas1500:feature/recurrent-gated-delta-rule-windows
Apr 24, 2026

Conversation

Contributor

@Phineas1500 commented Mar 11, 2026

Summary

This PR adds a fused llama::recurrent_gated_delta_rule custom op and wires Qwen3.5 GatedDeltaNet attention to use it instead of the Python per-token recurrence loop when the op is available.

It also tightens local custom-op loading so we no longer implicitly scan repo-local cmake-out* directories, and adds coverage for recurrent-state correctness, chunked prefill behavior, and export graph selection.
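For context, the fused op replaces a per-token recurrence of roughly the following shape. This is a minimal sketch of the standard gated delta rule, assuming a per-token decay gate g and write strength beta; tensor names, shapes, and the gating convention are illustrative, not the exact signature of the custom op.

import torch

def naive_gated_delta_rule(q, k, v, g, beta, state):
    # q, k: (B, H, S, k_dim); v: (B, H, S, v_dim)
    # g, beta: (B, H, S); state: (B, H, k_dim, v_dim)
    outs = []
    for t in range(q.shape[2]):
        q_t, k_t, v_t = q[..., t, :], k[..., t, :], v[..., t, :]
        # Gated decay of the running state (g assumed to be a log-decay).
        state = state * g[..., t, None, None].exp()
        # Read out current memory for this key, then apply the delta rule.
        kv_mem = torch.einsum("bhkv,bhk->bhv", state, k_t)
        delta = beta[..., t, None] * (v_t - kv_mem)
        state = state + torch.einsum("bhk,bhv->bhkv", k_t, delta)
        outs.append(torch.einsum("bhkv,bhk->bhv", state, q_t))
    return torch.stack(outs, dim=2), state

Each step decays the (k_dim × v_dim) state, reads out the memory for the incoming key, and writes back a beta-scaled delta correction; the fused kernel performs this loop in C++ instead of Python.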

What changed

  • added llama::recurrent_gated_delta_rule runtime and AOT registrations
  • updated Qwen3.5 GatedDeltaNet attention to use the fused op with Python fallback preserved
  • tightened custom_ops_aot_lib discovery:
    • default to package-local discovery
    • allow explicit override via EXECUTORCH_CUSTOM_OPS_AOT_LIB
    • removed implicit repo-local cmake-out* scanning
  • added tests for:
    • recurrent op parity vs reference
    • .out variant behavior
    • chunked-state parity vs full-sequence execution
    • custom-op vs fallback attention parity
    • tiny Qwen3.5 export selecting llama.recurrent_gated_delta_rule

Validation

Linux CPU-only (aarch64)

Built custom_ops_aot_lib successfully and loaded it via EXECUTORCH_CUSTOM_OPS_AOT_LIB.
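For reference, a hypothetical way to point the loader at a built library (the path is illustrative):

import os

# Hypothetical path; set before importing the custom-ops module, which
# reads the variable at import time during library discovery.
os.environ["EXECUTORCH_CUSTOM_OPS_AOT_LIB"] = "/path/to/libcustom_ops_aot_lib.so"

from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401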

Passed:

  • pytest extension/llm/custom_ops/test_update_cache.py::RecurrentGatedDeltaRuleTest -q
    • 3 passed
  • pytest examples/models/llama/tests/test_qwen3_5_attention.py -q
    • 7 passed
  • pytest examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_tiny_qwen35_export_uses_recurrent_gated_delta_rule -q
    • 1 passed

Real-model CPU validation

On a real Qwen3.5-0.8B CPU run, fused recurrence matched the fallback path on next-token selection with very small logit drift, and improved eager prefill latency on the tested prompt.

Observed on local CPU validation:

  • same next token from fused path vs fallback
  • max logit diff on the order of 1e-5
  • eager prefill speedup about 1.6x on the tested prompt
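A minimal sketch of this kind of parity check, assuming fused and fallback logit tensors are already in hand (names illustrative):

import torch

def check_parity(logits_fused: torch.Tensor, logits_fallback: torch.Tensor) -> None:
    # Same greedy next-token choice from both paths.
    assert torch.equal(
        logits_fused.argmax(dim=-1), logits_fallback.argmax(dim=-1)
    )
    # Elementwise logit drift; ~1e-5 was observed on the tested prompt.
    max_diff = (logits_fused - logits_fallback).abs().max().item()
    print(f"max logit diff: {max_diff:.2e}")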

Windows note

A local Windows-only FFHT/MSVC workaround was used during development to keep the local build usable, but that workaround is intentionally not included in this PR.

Non-goals / separate issues

I did not treat the local program.fbs serialization issue as part of this change.

This branch does not modify exir/_serialize/* or schema/program.fbs, and serialization-focused checks passed on both this branch and clean main once the local environment was set up correctly.

A separate end-to-end tiny Qwen3.5 .pte export probe hit:

  • RuntimeError: Missing out variants: {'aten::alias'}

That appears to be a separate pre-existing export issue outside this change set.

cc @larryliu0820 @mergennachin @cccclai @helunwencser @jackzhxng

Copilot AI review requested due to automatic review settings March 11, 2026 04:36

pytorch-bot Bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18088

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (2 Unrelated Failures)

As of commit 7898600 with merge base f9f29e7:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla Bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) — Mar 11, 2026
Contributor

Copilot AI left a comment


Pull request overview

Adds a fused llama::recurrent_gated_delta_rule custom op and integrates it into Qwen3.5 GatedDeltaNet attention to avoid the Python per-token recurrence loop when the op is available, along with tighter custom-op library discovery/loading and new test coverage.

Changes:

  • Implemented and registered llama::recurrent_gated_delta_rule (runtime kernel + ATen/AOT registrations) and updated attention to use it with a fallback path.
  • Refined custom_ops_aot_lib discovery/loading (package-local by default, optional EXECUTORCH_CUSTOM_OPS_AOT_LIB override).
  • Added tests for recurrent-state correctness/parity, chunked prefill behavior, and export graph op selection.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

Summary per file:

  • extension/llm/custom_ops/test_update_cache.py — Adds unit tests for recurrent gated delta rule correctness, .out behavior, and chunking parity.
  • extension/llm/custom_ops/op_tile_crop_aot.cpp — Replaces WRAP_TO_ATEN usage with explicit ET↔ATen conversion helpers for .out.
  • extension/llm/custom_ops/op_sdpa_aot.cpp — Adds ATen bindings for the recurrent op; refactors multiple .out wrappers to explicit conversions.
  • extension/llm/custom_ops/op_sdpa.h — Declares the new recurrent_gated_delta_rule_out kernel signature.
  • extension/llm/custom_ops/op_sdpa.cpp — Implements the recurrent kernel logic and registers the ExecuTorch kernel.
  • extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp — Refactors the .out binding to explicit ET↔ATen conversion helpers.
  • extension/llm/custom_ops/custom_ops.py — Tightens custom op library discovery/loading; adds a meta impl for the recurrent op.
  • extension/llm/custom_ops/CMakeLists.txt — Adds an MSVC /Zc:__cplusplus compile option.
  • examples/models/llama/tests/test_qwen3_5_attention.py — Adds chunked prefill parity and fused-op vs fallback parity tests.
  • examples/models/llama/tests/test_export_llama_lib.py — Adds a tiny Qwen3.5 export test asserting recurrent op selection in the graph.
  • examples/models/llama/attention.py — Adds lazy lookup/loading for the fused recurrent op and uses it when available.


Comment on lines +37 to +68
def _get_custom_ops_library_override() -> Path | None:
    override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB")
    if override is None:
        return None

    lib_path = Path(override).expanduser().resolve()
    assert lib_path.is_file(), (
        "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing "
        f"custom_ops_aot_lib, but got {lib_path}"
    )
    return lib_path


def _find_custom_ops_library() -> Path:
    override = _get_custom_ops_library_override()
    if override is not None:
        return override

    package_path = Path(__file__).parent.resolve()
    logging.info(f"Looking for libcustom_ops_aot_lib.so in {package_path}")
    candidates = []
    patterns = (
        "**/custom_ops_aot_lib.dll",
        "**/custom_ops_aot_lib.so",
        "**/custom_ops_aot_lib.dylib",
    )

    for pattern in patterns:
        candidates.extend(package_path.glob(pattern))

    libs = sorted({path.resolve() for path in candidates if path.is_file()})
    assert libs, f"Could not find custom_ops_aot_lib under {package_path}"
    return max(libs, key=lambda path: path.stat().st_mtime)

Copilot AI Mar 11, 2026


Avoid using assert for runtime validation of EXECUTORCH_CUSTOM_OPS_AOT_LIB / library discovery. Assertions can be stripped with python -O, turning these into silent misconfigurations; raise a ValueError/FileNotFoundError with the same message instead.
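A sketch of that suggestion applied to the override helper, keeping the same message but raising an exception that survives python -O:

import os
from pathlib import Path

def _get_custom_ops_library_override() -> Path | None:
    override = os.environ.get("EXECUTORCH_CUSTOM_OPS_AOT_LIB")
    if override is None:
        return None
    lib_path = Path(override).expanduser().resolve()
    if not lib_path.is_file():
        # Unlike assert, this cannot be stripped by python -O.
        raise FileNotFoundError(
            "EXECUTORCH_CUSTOM_OPS_AOT_LIB must point to an existing "
            f"custom_ops_aot_lib, but got {lib_path}"
        )
    return lib_path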

Comment on lines +81 to +84
try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except Exception:
    return None

Copilot AI Mar 11, 2026


_get_recurrent_gated_delta_rule_op() swallows all exceptions when importing executorch.extension.llm.custom_ops.custom_ops. Catching broad Exception can hide real load/link errors and make debugging difficult; consider narrowing to ImportError/OSError (or logging the exception at debug level) so unexpected failures surface.

Comment thread on extension/llm/custom_ops/op_sdpa.cpp (outdated)
Comment on lines +756 to +758
std::vector<float> kv_mem(v_head_dim);
std::vector<float> delta(v_head_dim);


Copilot AI Mar 11, 2026


recurrent_gated_delta_rule_out allocates std::vector buffers (kv_mem, delta) inside the per-(batch, head) loop. For long sequences / many heads this adds repeated heap allocations and can dominate runtime; allocate these buffers once per call (or reuse a scratch buffer) and resize as needed, or use stack/arena allocation when sizes are small.

Comment on lines +20 to +35
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}

Copilot AI Mar 11, 2026


The to_et_arg / copy_et_result_to_out helpers are duplicated here and in other custom-op AOT wrappers (tile_crop / sdpa / fast_hadamard_transform). Consider factoring them into a shared utility header to reduce copy-paste and keep conversion semantics consistent across ops.

Comment thread on extension/llm/custom_ops/custom_ops.py (outdated)
Comment on lines +82 to +86
if os.name == "nt":
    os.add_dll_directory(str(lib_path.parent))
    torch_lib_dir = Path(torch.__file__).resolve().parent / "lib"
    if torch_lib_dir.is_dir():
        os.add_dll_directory(str(torch_lib_dir))

Copilot AI Mar 11, 2026


On Windows, os.add_dll_directory() returns a handle that must be kept alive; otherwise the directory is removed immediately (CPython refcounting), which can cause torch.ops.load_library() to fail to resolve dependent DLLs. Store the returned handles (e.g., in a module-level list) at least through the load (and ideally for process lifetime).
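A sketch of that fix; a later review note confirms the PR adopts this pattern under the name _CUSTOM_OPS_DLL_DIR_HANDLES (the helper name here is hypothetical):

import os

# Module-level list keeps the handles alive for the process lifetime; if a
# handle is garbage-collected, Windows drops the directory from the DLL
# search path again.
_CUSTOM_OPS_DLL_DIR_HANDLES: list = []

def _add_dll_search_directory(path: str) -> None:  # hypothetical helper
    if os.name == "nt":
        _CUSTOM_OPS_DLL_DIR_HANDLES.append(os.add_dll_directory(path))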

@Phineas1500
Contributor Author

@lucylq this is a PR for optimizing the recurrence in Qwen 3.5, which we discussed here: #17801 (comment)

I'm next going to make a PR for quantization. I'll let you know once that's up.

m.impl("tile_crop", torch::executor::native::tile_crop_aten);
m.impl(
"tile_crop.out",
WRAP_TO_ATEN(torch::executor::native::tile_crop_out_no_context, 2));
Contributor


what problem did you run into with WRAP_TO_ATEN?

Contributor Author


MSVC was failing to compile the touched .out wrappers in the WRAP_TO_ATEN instantiation path when building custom_ops_aot_lib. I switched those wrappers to explicit conversions, and it now compiles cleanly on both Linux and Windows.

@JacobSzwejbka
Contributor

Can you tell me a bit more about the serialization issue you ran into as well as the MSVC one?

)
return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

core_attn_out = torch.zeros(
Contributor


Can you put this logic in some function called something like "naive_gated_delta_rule_op" and then just have the if statement switch between them, to tidy this function up a bit?

Contributor Author


Fixed it so _recurrent_gated_delta_rule() switches between _gated_delta_rule_op() and _naive_gated_delta_rule_op().
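In outline, the dispatch described above looks like this (signature illustrative; the real code lives in examples/models/llama/attention.py):

def _recurrent_gated_delta_rule(*args):
    # Use the fused custom op when it loaded successfully; otherwise fall
    # back to the Python per-token recurrence.
    if _get_recurrent_gated_delta_rule_op() is not None:
        return _gated_delta_rule_op(*args)
    return _naive_gated_delta_rule_op(*args)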

Comment thread on extension/llm/custom_ops/CMakeLists.txt (outdated)

set(_common_compile_options
    $<$<CXX_COMPILER_ID:MSVC>:/wd4996>
    $<$<CXX_COMPILER_ID:MSVC>:/Zc:__cplusplus>
Contributor


What codepath are you going down that isn't triggering properly without this? Typically the c10 pattern is to have explicit MSVC conditions and not rely on the C++ version on Windows, IIRC. I could be wrong on that though.

Contributor Author


You're right, I didn't need it. I changed make_aten_functor_from_et_functor.h to check for MSVC explicitly instead, and now it builds without /Zc:__cplusplus.

@Phineas1500
Contributor Author

@pytorchbot label "release notes: examples"

@pytorch-bot Bot added the release notes: examples label (changes to any of our example LLM integrations, such as Llama3 and Llava) — Mar 13, 2026
@Phineas1500
Contributor Author

> Can you tell me a bit more about the serialization issue you ran into as well as the MSVC one?

I believe the serialization issue is separate from this PR, as I reproduced it here and in main. _prepare_schema() wants program.fbs to be under exir/_serialize, but in a normal source checkout, it's under schema/. The error is that it can't find program.fbs.

I had two separate local build issues on MSVC:

  1. fht_avx.c uses inline asm that cl.exe doesn't accept. I kept my local workaround out of the PR because it's unrelated.
  2. WRAP_TO_ATEN was failing for the touched .out wrappers, so I switched them to explicit ET↔ATen conversions in a recent commit.

I also dropped the /Zc:__cplusplus change from this PR.

Happy to answer more questions @JacobSzwejbka

@digantdesai
Contributor

Taking a look.

std::vector<float> kv_mem(v_head_dim);
std::vector<float> delta(v_head_dim);

for (int64_t batch = 0; batch < batch_size; ++batch) {
Contributor

@digantdesai Mar 18, 2026


This is a naive implementation, in terms of performance with all the loops etc. Do you have plans to improve this or add a new version and leave this as a reference C++ implementation?

Contributor

@digantdesai left a comment


Ok with merging this as is. Love the tests. Thanks.

Copilot AI review requested due to automatic review settings March 19, 2026 18:34
@digantdesai
Contributor

Let me try to merge this change.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.



Comment thread on examples/models/llama/attention.py (outdated)

try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except (ImportError, OSError):
@nil-is-all
Contributor

Hi @Phineas1500, could you check the lintrunner logs and resolve the issues brought up? Once they are fixed, it should be good to merge.
cc @digantdesai

@Phineas1500
Contributor Author

> Hi @Phineas1500, could you check the lintrunner logs and resolve the issues brought up? Once they are fixed, it should be good to merge. cc @digantdesai

Will do @nil-is-all !

@Phineas1500
Copy link
Copy Markdown
Contributor Author

@digantdesai made a commit addressing the linter issues. It should hopefully be good to merge after the tests run.

@Phineas1500 requested a review from digantdesai — March 26, 2026 20:03
@Phineas1500
Contributor Author

Apologies if I clicked the wrong button. Could a maintainer approve the twelve workflows?

cc @digantdesai @mergennachin

@nil-is-all
Contributor

Hi @Phineas1500, approved the workflows to run. Could you take a look at the lintrunner error logs and resolve them?

Copilot AI review requested due to automatic review settings March 31, 2026 16:53
@Phineas1500
Contributor Author

@nil-is-all resolved the issues from the lintrunner logs

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.



if not libs:
    raise FileNotFoundError(
        f"Could not find custom_ops_aot_lib under {package_path}"
    )
return max(libs, key=lambda path: path.stat().st_mtime)

Copilot AI Mar 31, 2026


_find_custom_ops_library() collects all matching custom_ops_aot_lib binaries under the package dir and then picks the newest by mtime. If multiple candidates exist (e.g., debug/release builds or stale artifacts), this can silently load the wrong library and make behavior depend on filesystem timestamps. Consider failing fast when >1 candidate is found (and listing them), unless the user provides an explicit override via EXECUTORCH_CUSTOM_OPS_AOT_LIB.

Suggested change:

-    return max(libs, key=lambda path: path.stat().st_mtime)
+    if len(libs) > 1:
+        discovered = "\n - " + "\n - ".join(str(p) for p in libs)
+        raise RuntimeError(
+            "Found multiple candidate custom_ops_aot_lib libraries under "
+            f"{package_path}:\n{discovered}\n"
+            "Please set the EXECUTORCH_CUSTOM_OPS_AOT_LIB environment variable "
+            "to the desired library to disambiguate."
+        )
+    return libs[0]

Comment on lines +19 to +34
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}

Copilot AI Mar 31, 2026


This file introduces local helpers (to_et_arg / copy_et_result_to_out) that duplicate the conversion+resize+copy logic already provided by WRAP_TO_ATEN in executorch/extension/aten_util/make_aten_functor_from_et_functor.h. Keeping multiple bespoke wrappers across ops increases maintenance cost and the risk of subtle divergence. Prefer using WRAP_TO_ATEN again, or moving these helpers into a shared utility used by all the AOT wrapper files.

Comment on lines +20 to +35
namespace {
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}

Copilot AI Mar 31, 2026


This file reimplements argument conversion and out-copying helpers (to_et_arg / copy_et_result_to_out) that largely mirror WRAP_TO_ATEN’s wrapper_impl (type_convert + resize_output + copy_). Since the same helpers are also added in other custom-op AOT files, consider using WRAP_TO_ATEN directly or centralizing the helpers to avoid wrapper logic drifting across files.

Comment on lines +16 to +30
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

at::Tensor& copy_et_result_to_out(Tensor& et_result, at::Tensor& out) {
  auto converted_result =
      executorch::extension::internal::type_convert<Tensor&, at::Tensor>(
          et_result)
          .call();
  at::native::resize_output(out, converted_result.sizes());
  out.copy_(converted_result);
  return out;
}

Copilot AI Mar 31, 2026


The new to_et_arg / copy_et_result_to_out helpers duplicate the WRAP_TO_ATEN wrapper logic from executorch/extension/aten_util/make_aten_functor_from_et_functor.h, and similar copies exist in other custom-op wrapper files. Consider switching back to WRAP_TO_ATEN or sharing a single helper to reduce duplicated conversion/copy code paths.

@nil-is-all added the module: llm label (issues related to LLM examples and apps, and to the extensions/llm/ code) — Apr 1, 2026
@Phineas1500
Contributor Author

@nil-is-all do you think it's good to merge? I don't believe the current errors are related to this PR's code

@Phineas1500
Contributor Author

Phineas1500 commented Apr 2, 2026

@digantdesai @mergennachin is there anything else you think I should do? After this PR is merged, I'm thinking of making PRs adding q8da4w support, dynamic-shape export for DeltaNet, and better multi-token prefill.

@digantdesai
Contributor

Hey, sorry, this is good to go. Let me run Claude through it to catch nit issues. If something major comes up we should fix it; otherwise this is good.

@digantdesai
Contributor

@claude Review this PR


claude Bot commented Apr 22, 2026

Claude finished @digantdesai's task in 1m 54s.


PR Review: Add recurrent gated delta rule custom op for Qwen3.5 attention

  • Read all changed files
  • Review C++ kernel implementation (op_sdpa.cpp, op_sdpa.h)
  • Review AOT/ATen bindings (op_sdpa_aot.cpp, op_tile_crop_aot.cpp, op_fast_hadamard_transform_aten.cpp)
  • Review Python op discovery and loading (custom_ops.py)
  • Review attention integration (attention.py)
  • Review tests
  • Review MSVC / CMake changes
  • Post review summary

Overall this is a well-structured PR with good test coverage. The kernel implementation is correct, the Python fallback path is cleanly preserved, and the test suite covers key correctness scenarios. Below are findings organized by severity.


Issues

1. Heap allocation inside hot loop (C++ kernel)

op_sdpa.cpp:736-737 — The kv_mem and delta vectors are declared at function scope, above the batch/head loops, but they are still heap-allocated std::vector<float> objects. That is fine for a reference implementation, but since this path is hit during inference, the heap allocations could be avoided with raw arrays or stack buffers when v_head_dim is bounded (or at least by allocating once and reusing).

This was flagged by Copilot and acknowledged by @digantdesai as a known future improvement, so not blocking, but worth noting.

2. Duplicated to_et_arg / copy_et_result_to_out helpers across 3 files

The same to_et_arg and copy_et_result_to_out helper templates are copy-pasted into:

  • op_sdpa_aot.cpp:21-35
  • op_tile_crop_aot.cpp:20-34
  • op_fast_hadamard_transform_aten.cpp:16-30

This creates maintenance risk — if the conversion logic needs a fix, three files must be updated in sync. These should be extracted to a shared header (e.g., aot_conversion_helpers.h) in a follow-up.


3. _find_custom_ops_library silently picks newest by mtime when multiple candidates exist

custom_ops.py:73 — When multiple custom_ops_aot_lib binaries exist under the package directory, the code silently picks the newest by st_mtime. This can lead to loading the wrong library (e.g., a stale debug build) with no indication. Consider either warning or erroring when len(libs) > 1 and no explicit override is set.

4. Validation before resize in recurrent_gated_delta_rule_out

op_sdpa.cpp:689-700 — The output tensor is resized before argument validation. If validation fails, the output may have already been resized to an inconsistent state. Swapping these two blocks so validation runs first would be more defensive:

// Validate first
ET_KERNEL_CHECK(ctx, validate_..., InvalidArgument, output);
// Then resize
ET_KERNEL_CHECK_MSG(ctx, resize_tensor(output, ...) == Error::Ok, ...);

Nits

5. assert used for runtime validation in custom_ops.py

custom_ops.py:335-350 (_validate_recurrent_gated_delta_rule_params) — Uses Python assert statements for validating tensor shapes and dtypes. These are stripped under python -O. This pattern is consistent with the existing validation functions in the same file (e.g., _validate_params), so it's not unique to this PR, but ValueError / RuntimeError would be more robust.

6. Exception catching in _get_recurrent_gated_delta_rule_op could be narrower

attention.py:83 — The except (AttributeError, RuntimeError) for the initial torch.ops.llama.recurrent_gated_delta_rule.default check is appropriate. The second except (ImportError, OSError, RuntimeError) at line 88 for importing custom_ops is also reasonable given that custom_ops.py can raise RuntimeError if ops fail to register. Good that this was addressed after an earlier review round.

7. Minor: g.dim() == 3 validation inconsistency

In _validate_recurrent_gated_delta_rule_params (Python, custom_ops.py:344), the dimension names are (B, S, H) in the assertion messages. But in the C++ kernel (op_sdpa.cpp:195-196), the g tensor is expected as (B, H, S) shape (since the Python side does .transpose(1, 2) before calling the op). The validation is functionally correct, just noting that the shape convention differs between the meta kernel check and the C++ kernel — the meta kernel validates the pre-transpose shape while the C++ validates the post-transpose shape. This is fine but could be confusing to future readers.
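To make the convention concrete, a small illustration of the two layouts described (shapes illustrative):

import torch

g = torch.randn(2, 16, 4)                  # (B, S, H): layout the meta kernel validates
g_kernel = g.transpose(1, 2).contiguous()  # (B, H, S): layout the C++ kernel consumes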

8. _CUSTOM_OPS_DLL_DIR_HANDLES is a good fix

custom_ops.py:26 — Storing the os.add_dll_directory() handles in a module-level list to keep them alive is the correct fix for the Windows DLL directory lifetime issue flagged in earlier reviews. Well done.


Test Quality

The tests are thorough and well-designed:

  • RecurrentGatedDeltaRuleTest (test_update_cache.py:436-585): Tests reference parity, .out variant behavior, and chunked-state parity. Good use of multiple test cases and subTest.
  • Qwen35AttentionTest (test_qwen3_5_attention.py): Tests state reset, no-leak semantics, chunked prefill parity, and custom-op-vs-fallback parity. The fallback test at line 176 correctly monkeypatches the module globals.
  • Export test (test_export_llama_lib.py:106-139): Verifies the exported graph contains exactly 2 recurrent_gated_delta_rule nodes. Clean and targeted.
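The export assertion amounts to a node count of this general shape (a hedged sketch; the actual test lives in test_export_llama_lib.py):

def count_recurrent_nodes(graph) -> int:
    # Count call_function nodes whose target resolves to the custom op.
    return sum(
        1
        for node in graph.nodes
        if node.op == "call_function"
        and "recurrent_gated_delta_rule" in str(node.target)
    )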

Summary

This is a solid PR. The kernel implementation is correct, the fallback path is clean, and the test coverage is good. The main actionable items for follow-up are:

  1. Extract duplicated AOT conversion helpers into a shared header (not blocking)
  2. Consider warning on ambiguous library discovery when multiple candidates exist
  3. Swap resize/validation order in the C++ kernel for defensive correctness

None of these are blockers for merge. The PR is ready as-is per @digantdesai's earlier approval.

Copilot AI review requested due to automatic review settings April 23, 2026 23:10
Contributor

@digantdesai left a comment


Thanks, and sorry for the delay. I will try to merge it.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.




try:
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
except (ImportError, OSError, RuntimeError):

Copilot AI Apr 23, 2026


_get_recurrent_gated_delta_rule_op() attempts to import executorch.extension.llm.custom_ops.custom_ops as a best-effort fallback, but it doesn't catch FileNotFoundError. custom_ops.py can raise FileNotFoundError when custom_ops_aot_lib isn't present, which would crash attention initialization instead of cleanly falling back to the Python implementation. Consider catching FileNotFoundError here (or making custom_ops.py raise a RuntimeError that is already handled).

Suggested change:

-except (ImportError, OSError, RuntimeError):
+except (ImportError, FileNotFoundError, OSError, RuntimeError):

_ = portable_lib


if not _is_custom_ops_registered():

Copilot AI Apr 23, 2026


If EXECUTORCH_CUSTOM_OPS_AOT_LIB is set, the intent is to allow an explicit override for loading custom_ops_aot_lib. Right now _load_custom_ops_library() only runs when _is_custom_ops_registered() is false, so an override can be ignored in environments where some ops were already registered (e.g., via another shared library), preventing the override from taking effect. Consider changing the top-level guard to attempt _load_custom_ops_library() when the override env var is set (or at least when the specific expected ops like recurrent_gated_delta_rule are missing).

Suggested change:

-if not _is_custom_ops_registered():
+_custom_ops_library_override = _get_custom_ops_library_override()
+if _custom_ops_library_override is not None or not _is_custom_ops_registered():

Comment on lines +20 to +24
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

Copilot AI Apr 23, 2026


This file now uses std::forward, but it doesn't include <utility>. Relying on transitive includes can make builds brittle across toolchains; add an explicit #include <utility> to ensure std::forward is always available.

Comment on lines +21 to +25
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

Copilot AI Apr 23, 2026


This file now uses std::forward, but it doesn't include <utility>. Relying on transitive includes can make builds brittle across toolchains; add an explicit #include <utility> to ensure std::forward is always available.

Comment on lines +16 to +20
template <typename EType, typename AType>
auto to_et_arg(AType&& value) {
  return executorch::extension::internal::type_convert<AType, EType>(
      std::forward<AType>(value));
}

Copilot AI Apr 23, 2026


This file now uses std::forward, but it doesn't include <utility>. Relying on transitive includes can make builds brittle across toolchains; add an explicit #include <utility> to ensure std::forward is always available.

@digantdesai merged commit 476a7ef into pytorch:main — Apr 24, 2026
173 of 175 checks passed

Labels

CLA Signed · module: llm · release notes: examples
