[ExecuTorch][WebGPU] SDPA: branchless aligned/tail loads in the QK/AV kernels#20493

Open
JulianCloudNTH wants to merge 2 commits into gh/JulianCloudNTH/63/base from gh/JulianCloudNTH/63/head

JulianCloudNTH (Contributor) commented Jun 24, 2026

Stack from ghstack (oldest at bottom):

Branchless aligned/tail loads + vec4 storage bindings — drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as array<vec4<f32>> so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings).

Problem: The tiled QK/AV vec4 loaders run 4 per-lane if bounds checks on every load, in every contraction iteration (8 loads/iter). But head_dim is always a multiple of 4, so the D-axis checks always pass, and the AV context axis only needs a bounds check on the last ragged chunk. Separately, the storage buffers were declared array<f32>, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses.
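
To make the dead-check argument concrete, here is a minimal CPU-side sketch in C++ (it is not the WGSL from this PR). The names `Vec4`, `load_vec4_checked`, `load_vec4_unchecked`, and `loaders_agree` are hypothetical and chosen only for illustration. The sketch assumes a row of length D with D % 4 == 0 and chunk offsets that are multiples of 4.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using Vec4 = std::array<float, 4>;

// Checked load: every lane tests its own index against the row length D.
// Hypothetical stand-in for the old per-lane `if` pattern.
Vec4 load_vec4_checked(const std::vector<float>& row, std::size_t d, std::size_t D) {
  Vec4 v{0.0f, 0.0f, 0.0f, 0.0f};
  for (std::size_t lane = 0; lane < 4; ++lane) {
    if (d + lane < D) {
      v[lane] = row[d + lane];
    }
  }
  return v;
}

// Unchecked load: only valid when D % 4 == 0, d % 4 == 0 and d < D.
// Under those preconditions d <= D - 4, so d + 3 < D and no lane can be out of range.
Vec4 load_vec4_unchecked(const std::vector<float>& row, std::size_t d, std::size_t D) {
  assert(D % 4 == 0 && d % 4 == 0 && d < D);
  return Vec4{row[d], row[d + 1], row[d + 2], row[d + 3]};
}

// Returns true if both loaders produce the same value for every 4-wide chunk of the row.
bool loaders_agree(const std::vector<float>& row) {
  const std::size_t D = row.size();
  for (std::size_t d = 0; d < D; d += 4) {
    if (load_vec4_checked(row, d, D) != load_vec4_unchecked(row, d, D)) {
      return false;
    }
  }
  return true;
}
```

For any row whose length is a multiple of 4, `loaders_agree` should return true, which is the sense in which the per-lane D checks are dead code.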

Solution: Remove the dead checks, split the ragged axis, and vectorize the bindings:

  • Before: load_q_vec4/load_k_vec4 (and AV load_a_vec4/load_v_d4) do 4 per-lane bounds ifs per call; the AV c4 loop runs checked loads for every chunk; t_q/t_k_cache/t_v_cache/t_out are array<f32> accessed element-by-element.
  • After: QK loads are a plain unchecked vec4 (D%4==0, host-guarded); AV runs a branch-free aligned body over c4 in [0, context_len - context_len%4) then a 0-or-1 checked tail; the head-dim-indexed buffers t_q/t_k_cache/t_v_cache/t_out are array<vec4<f32>> indexed [base/4u], and AV writes a single aligned store_out_vec4.
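
The aligned-body-plus-tail split over the context axis can be sketched on the CPU as follows. This is a simplified single-output illustration in C++, not the kernel itself: the real shader accumulates vec4 tiles, and `av_dot_scalar` and `av_dot_split` are hypothetical names. The sketch assumes the attention weights `a` and one column of V (`v`) have the same length, `context_len`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference: one multiply-add per context position, in increasing order.
float av_dot_scalar(const std::vector<float>& a, const std::vector<float>& v) {
  assert(a.size() == v.size());
  float acc = 0.0f;
  for (std::size_t c = 0; c < a.size(); ++c) {
    acc += a[c] * v[c];
  }
  return acc;
}

// Split version: an unchecked body over full chunks of 4, then a checked tail.
float av_dot_split(const std::vector<float>& a, const std::vector<float>& v) {
  assert(a.size() == v.size());
  const std::size_t context_len = a.size();
  const std::size_t aligned_end = context_len - context_len % 4;
  float acc = 0.0f;

  // Aligned body: c4 in [0, aligned_end). Every lane is in range, so no bounds check.
  for (std::size_t c4 = 0; c4 < aligned_end; c4 += 4) {
    for (std::size_t lane = 0; lane < 4; ++lane) {
      acc += a[c4 + lane] * v[c4 + lane];
    }
  }

  // Tail: runs zero times or once. Out-of-range lanes load 0 and contribute 0.
  if (aligned_end < context_len) {
    for (std::size_t lane = 0; lane < 4; ++lane) {
      const std::size_t c = aligned_end + lane;
      const float w = (c < context_len) ? a[c] : 0.0f;
      const float x = (c < context_len) ? v[c] : 0.0f;
      acc += w * x;
    }
  }
  return acc;
}
```

Both functions visit the in-range positions in the same order, so the split only changes where the bounds check happens, not which products are summed.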

Implementation:

  • QK: load_q_vec4/load_k_vec4 drop the per-lane D checks and return t_q[base/4u] / t_k_cache[base/4u].
  • AV: branch-free load_a_vec4_nc/load_v_d4_nc for the aligned body; checked load_a_vec4/load_v_d4 for the tail; V reads t_v_cache[base/4u]; output is one aligned store_out_vec4.
  • Bindings: t_q, t_k_cache (QK) and t_v_cache, t_out (AV) are array<vec4<f32>>. t_attn_weights and the softmax buffer stay array<f32> — they are context_len-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a vec4 binding there would need a padded scratch row.
  • Host: add a D % 4 == 0 guard in Sdpa.cpp — WGSL has no SDPA_PAD_D pad-load, so fail loud rather than read past the row; this guard also makes every [base/4u] index 4-aligned and every buffer a 16-byte multiple.
  • Test: add a reject_d6 (head_dim=6) config + an expect_reject harness branch asserting the guard rejects a non-aligned head_dim at load.
  • Mirrors Vulkan sdpa_compute_out_tiled.glsl (aligned/tail split) and Vulkan's array<vec4> SDPA bindings.
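
As a rough illustration of the host guard and the `[base/4u]` indexing, the following C++ sketch shows the idea. It is not the code in Sdpa.cpp: `check_head_dim_aligned`, `row_base`, and `vec4_index` are hypothetical helpers, and the exception type and message are assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical host-side guard: reject any head_dim that is not a multiple of 4
// instead of letting the shader read past the end of a row.
void check_head_dim_aligned(std::uint32_t head_dim) {
  if (head_dim % 4u != 0u) {
    throw std::invalid_argument(
        "SDPA: head_dim must be a multiple of 4, got " + std::to_string(head_dim));
  }
}

// Scalar (f32 element) offset of element d in a row, with row stride head_dim.
std::uint32_t row_base(std::uint32_t row, std::uint32_t head_dim, std::uint32_t d) {
  return row * head_dim + d;
}

// Maps a scalar f32 offset to an index into an array of 4-float vectors.
// Requires a 4-aligned offset, which holds when head_dim % 4 == 0 and d % 4 == 0.
std::uint32_t vec4_index(std::uint32_t base) {
  assert(base % 4u == 0u);
  return base / 4u;
}
```

With head_dim = 64, row 3 and d = 8 give a scalar offset of 200 and a vec4 index of 50. With head_dim = 6 the guard throws before any index is computed.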

Constraints:

  • Requires head_dim % 4 == 0 (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing.
  • Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the vec4 bindings read/write the same bytes as the scalar version.
  • No KV-cache layout, dispatch, or uniform change.
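
The "same bytes" part of the bit-identical claim can be checked on the CPU with a sketch like the one below. It is only an analogy for rebinding a buffer as `array<vec4<f32>>`: `Float4`, `as_float4_buffer`, and `same_elements` are hypothetical names, and the sketch assumes the buffer length is a multiple of 4 floats.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Four tightly packed floats, standing in for a vec4<f32> element.
struct Float4 {
  float x, y, z, w;
};
static_assert(sizeof(Float4) == 4 * sizeof(float), "Float4 must be tightly packed");

// Copies a scalar buffer into a Float4-typed buffer without changing any bytes.
std::vector<Float4> as_float4_buffer(const std::vector<float>& scalars) {
  assert(scalars.size() % 4 == 0);
  std::vector<Float4> out(scalars.size() / 4);
  if (!scalars.empty()) {
    std::memcpy(out.data(), scalars.data(), scalars.size() * sizeof(float));
  }
  return out;
}

// Returns true if element i of the vector view holds scalars[4*i .. 4*i+3].
bool same_elements(const std::vector<float>& scalars, const std::vector<Float4>& vecs) {
  if (vecs.size() * 4 != scalars.size()) {
    return false;
  }
  for (std::size_t i = 0; i < vecs.size(); ++i) {
    const std::size_t base = 4 * i;
    if (vecs[i].x != scalars[base] || vecs[i].y != scalars[base + 1] ||
        vecs[i].z != scalars[base + 2] || vecs[i].w != scalars[base + 3]) {
      return false;
    }
  }
  return true;
}
```

The vector view at index `base / 4` exposes exactly the four scalars starting at `base`, so switching the binding type changes how the data is addressed but not what is read or written.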

Co-authored with Claude Code.
@exported-using-ghexport

Differential Revision: D109521069

[ghstack-poisoned]
pytorch-bot Bot commented Jun 24, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20493

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures, 4 Unrelated Failures

As of commit 8ba8c50 with merge base 0e65ba6 (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

[ghstack-poisoned]
JulianCloudNTH added a commit that referenced this pull request Jun 24, 2026
… kernels

Pull Request resolved: #20493

ghstack-source-id: 396717582
@exported-using-ghexport

Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/)

SS-JIA (Contributor) left a comment

Review automatically exported from Phabricator review in Meta.

Labels

CLA Signed, meta-exported
