[ExecuTorch][WebGPU] SDPA: skip QK contraction for fully-masked causal tiles#20492
JulianCloudNTH wants to merge 1 commit into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20492
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV
There is 1 currently active SEV. If your PR is affected, please view it below:
❌ 3 New Failures, 3 Unrelated Failures
As of commit 40acc76 with merge base 0e65ba6:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Stack from ghstack (oldest at bottom):
Skip the QK contraction for fully-masked causal tiles — at S=128 prefill ~48% of the (query, key) tiles are entirely above the diagonal and contribute nothing; this elides their dot products (prefill-only; bit-identical output).
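The ~48% figure can be sanity-checked by counting tiles directly. The sketch below is illustrative only: the tile sizes `TM = TN = 4`, `input_pos = 0`, and a context length equal to `S` are assumptions chosen for the example (the description does not state them), and `count_fully_masked_tiles` is a hypothetical helper, not part of the kernel.

```python
def count_fully_masked_tiles(S, context_len, input_pos, TM, TN):
    """Return (skipped, total) tile counts under the causal skip predicate."""
    skipped = 0
    total = 0
    for s0 in range(0, S, TM):                # first query row of the tile
        for c0 in range(0, context_len, TN):  # first key column of the tile
            total += 1
            # The tile is fully masked when its first key column lies past
            # the last position visible to the tile's last query row.
            if c0 > s0 + (TM - 1) + input_pos:
                skipped += 1
    return skipped, total


# Assumed configuration: S=128 prefill from position 0 with 4x4 tiles.
skipped, total = count_fully_masked_tiles(
    S=128, context_len=128, input_pos=0, TM=4, TN=4
)
print(skipped, total, skipped / total)  # 496 1024 0.484375
```

Under these assumed tile sizes, 496 of 1024 tiles satisfy the predicate, which is consistent with the ~48% quoted above; other tile sizes would give a somewhat different fraction.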
Problem: For causal prefill, ~half the (query S-tile, key context-tile) pairs are entirely above the diagonal, yet the kernel still computes their full `d4` dot product before masking the result to `NEG_INF`.

Solution: Skip the contraction for fully-masked tiles; the existing per-element mask still writes the sentinel:
- Before: each `(s0, c0)` tile runs the full `d4` dot-product loop, then `store_qk` masks above-diagonal elements to `NEG_INF`.
- After: a fully-masked tile (`c0 > s0 + TM-1 + input_pos`) breaks the `d4` loop immediately (`acc` stays 0); `store_qk` masks every element to `NEG_INF` exactly as before.

Implementation:
- `skip_tile = c0 > s0 + (TM - 1) + params.input_pos`, folded into the `d4` loop break condition.
- `sdpa_compute_attn_weights_tiled.glsl` (`tile_in_mask_region`).

Constraints:
- `S=1` never triggers it (`c0 <= input_pos < input_pos + TM - 1`).
- `NEG_INF` stays the WGSL-safe `-1.0e30` (WGSL forbids a literal `-inf`); does not copy Vulkan's `-1.0/0.0`.

Co-authored with Claude Code.
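To illustrate why skipping the contraction leaves the output unchanged, here is a minimal CPU sketch in Python. It is not the WGSL kernel: `attn_weights_tiled` is a hypothetical stand-in, the head dimension is reduced one scalar at a time rather than in `d4` (vec4) chunks, and the tile sizes and shapes are assumptions. The per-element mask at the end plays the role of `store_qk`.

```python
import random

NEG_INF = -1.0e30  # finite sentinel, as in the description


def attn_weights_tiled(q, k, input_pos, TM, TN, skip_masked_tiles):
    """q is S x D, k is C x D (lists of lists).

    Returns (S x C weights, number of multiply-adds performed).
    """
    S, C, D = len(q), len(k), len(q[0])
    out = [[0.0] * C for _ in range(S)]
    macs = 0
    for s0 in range(0, S, TM):
        for c0 in range(0, C, TN):
            rows = range(s0, min(s0 + TM, S))
            cols = range(c0, min(c0 + TN, C))
            # Same predicate as the description: tile entirely above the diagonal.
            skip_tile = skip_masked_tiles and c0 > s0 + (TM - 1) + input_pos
            acc = {(s, c): 0.0 for s in rows for c in cols}
            if not skip_tile:  # the real kernel breaks out of the d4 loop instead
                for d in range(D):
                    for s in rows:
                        for c in cols:
                            acc[(s, c)] += q[s][d] * k[c][d]
                            macs += 1
            # Per-element causal mask, applied whether or not the tile was skipped.
            for s in rows:
                for c in cols:
                    out[s][c] = NEG_INF if c > s + input_pos else acc[(s, c)]
    return out, macs


rng = random.Random(0)
q = [[rng.uniform(-1, 1) for _ in range(8)] for _ in range(16)]
k = [[rng.uniform(-1, 1) for _ in range(8)] for _ in range(16)]

full, macs_full = attn_weights_tiled(q, k, 0, 4, 4, skip_masked_tiles=False)
fast, macs_fast = attn_weights_tiled(q, k, 0, 4, 4, skip_masked_tiles=True)
print(full == fast, macs_full, macs_fast)  # True 2048 1280

# Decode-style call (S=1, context = input_pos + 1): the predicate never fires.
_, dec_full = attn_weights_tiled(q[:1], k[:6], 5, 4, 4, skip_masked_tiles=False)
_, dec_fast = attn_weights_tiled(q[:1], k[:6], 5, 4, 4, skip_masked_tiles=True)
print(dec_full == dec_fast)  # True
```

Every element of a skipped tile satisfies `c > s + input_pos`, so the mask overwrites the untouched zero accumulator with `NEG_INF` and the two outputs compare equal; only the multiply-add count drops.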
@exported-using-ghexport
Differential Revision: D109517773