[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels#20405
[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels#20405JulianCloudNTH wants to merge 2 commits into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20405
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 3 New Failures, 3 Unrelated FailuresAs of commit 3ce91e0 with merge base 0e65ba6 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
SS-JIA
left a comment
There was a problem hiding this comment.
Review automatically exported from Phabricator review in Meta.
Stack from ghstack (oldest at bottom):
+32% SDPA attention-compute (AV +40%) — register-tile the QK and AV kernels (isolated GPU-timestamp A/B, decode S=1, Chrome Canary / M4 Pro). A kernel-time win, not a wall-clock
forward()win —forward()stays bound by the submit/sync/readback floor (the separate fusion axis).Problem: The naive QK/AV kernels compute one output element per thread, so each thread re-loads Q/K/V and the dot products are scalar — poor register reuse, ALU/latency-bound.
Solution: Each thread computes a 4×4 output tile with the dot products vec4-packed in registers:
(head, S-tile, {ctx,D}-tile); 4×4 register tile, vec4 dot products. A floating-point accumulation reorder of the same products — no algorithm change.Implementation:
sdpa_compute_attn_weights.wgsl(QK): one thread per(head, S-tile, ctx-tile), gridHq · ceil(S/4) · ceil(ctx/4); tile registers arearray<vec4<f32>, TM/TN>loaded viaforloops.sdpa_compute_out.wgsl(AV): one thread per(head, S-tile, D-tile), gridHq · ceil(S/4) · ceil(D/4).Sdpa.cpp: dispatch math moves from an element count to a tile count (kSdpaTileM/N=4, sharedutils::div_up), keeping the uint32 scratch-overflow guard.utils::div_upmirrors Vulkan'sutils::div_up.Constraints:
update_cache, the bind-group layouts, and the scratch-buffer sizes (Hq*S*ctx) are unchanged.store_qkis identical). See DESIGN_DECISIONS.md.@exported-using-ghexport
Differential Revision: D109081409
Differential Revision: D109081409