Metal backend: Add gated delta rule kernel for linear attention #18878
manuelcandales merged 29 commits into main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18878
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review this code
Claude finished @metascroy's task in 2m 37s
**Review: Metal gated delta rule kernel**

Overall this is a well-structured addition that follows existing Metal backend patterns. The kernel logic is correct and the in-place state mutation approach is sound. A few issues worth addressing:

**Bug: Missing error check on output tensor creation.** The result of creating the output tensor goes unchecked. Compare the `op_gather_qmv.mm` pattern:

```cpp
AOTITorchError create_result = aoti_torch_create_tensor_from_blob_v2(
    out_contents_ptr, 2, output_sizes.data(), output_strides.data(),
    0, dtype, 13, 0, &out_tensor_handle, 0, nullptr, 0);
if (create_result != Error::Ok || !out_tensor_handle) {
  ET_LOG(Error, "...: Failed to create output tensor");
  aoti_torch_mps_free(out_contents_ptr);
  return Error::Internal;
}
```

In […] the corresponding check is absent.

**Bug: Python reference impl silently assumes Hk == Hv.** The reference implementation broadcasts […]. Currently all instantiated kernels use […].

**Nit: Stale comment in instantiation block.** Line 124 says "Qwen 3.5 MoE dimensions: Dk=128, Dv=128, Hk=16, Hv=32" but line 143 instantiates […].

**Nit: Tolerance relaxation for gather_qmv bfloat16.** The bfloat16 tolerances for […].

Minor observations (non-blocking): […]
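The Hk == Hv concern above can be illustrated with a short numpy sketch. This is a hypothetical example, not the PR's reference impl: the tensor layout `(B, S, H, D)`, the variable names, and the grouped-query pairing of `Hv // Hk` value heads per key head are all assumptions for illustration.

```python
import numpy as np

# Hypothetical GQA-style head broadcasting for the Hk != Hv case.
# Dims borrowed from the stale comment quoted in the review above.
B, S, Dk, Dv = 1, 4, 64, 64
Hk, Hv = 16, 32

k = np.random.randn(B, S, Hk, Dk)
v = np.random.randn(B, S, Hv, Dv)

group = Hv // Hk  # two value heads share each key head
k_b = np.repeat(k, group, axis=2)  # repeat key heads: (B, S, Hv, Dk)

assert k_b.shape == (B, S, Hv, Dk)
```

A reference impl that only ever broadcasts with Hk == Hv shapes will not catch a missing `repeat` like this, which is why exercising an Hk != Hv instantiation in tests is worthwhile.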
Adds Metal kernel for the gated delta rule recurrence used by Qwen 3.5
MoE's GatedDeltaNet linear attention layers, ported from the Metal shader
in the MLX delegate PR (#18785). The kernel processes the full sequence
sequentially within a single GPU dispatch, keeping recurrent state in
per-thread registers.
Grid: [32, Dv, B*Hv], Threadgroup: [32, 4, 1]. Each simdgroup of 32
threads handles Dk/32 elements of the key dimension with SIMD reduction
for dot products.
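The lane partitioning can be mimicked in numpy; this mirrors the shader only in spirit (the strided element assignment per lane is an assumption). Each of the 32 lanes owns Dk/32 strided elements and accumulates a partial dot product, which a `simd_sum`-style reduction then combines.

```python
import numpy as np

LANES, Dk = 32, 128  # one simdgroup; Dk/32 = 4 elements per lane
a = np.random.randn(Dk)
b = np.random.randn(Dk)

partials = np.zeros(LANES)
for lane in range(LANES):
    # lane handles elements lane, lane + 32, lane + 64, ...
    partials[lane] = a[lane::LANES] @ b[lane::LANES]

# The reduction across lanes recovers the full dot product.
assert np.isclose(partials.sum(), a @ b)
```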
The op mutates the recurrent state buffer in-place (mutates_args).
Instantiated for both real model (Dk=128, Dv=128, Hk=32, Hv=32) and
tiny test (Dk=64, Dv=64, Hk=4, Hv=4) dimensions.
Includes: Metal shader + C++ host dispatch, Python custom op definition
(metal::gated_delta_rule) with reference CPU impl and Meta impl, C shim
dict, fallback kernel registration, CMakeLists entry, and test module.
Authored with Claude.