
Commit 2fdc6d5

[ExecuTorch][WebGPU] SDPA: branchless aligned/tail loads in the QK/AV kernels
Pull Request resolved: #20493

**Branchless aligned/tail loads + vec4 storage bindings**: drop the always-true per-lane bounds checks in the tiled QK/AV hot loops, split the AV context contraction into a branch-free aligned body plus a checked tail, and declare the head-dim-indexed SDPA storage buffers as `array<vec4<f32>>` so the loads/stores are forced-vectorized (addresses review feedback to mirror Vulkan's vec4 bindings).

**Problem**: The tiled QK/AV vec4 loaders run 4 per-lane `if` bounds checks on every load, every contraction iteration (8 loads/iter). But `head_dim` is always a multiple of 4, so the D-axis checks never fire, and the AV context axis only needs a bounds check on the last ragged chunk. Separately the storage buffers were declared `array<f32>`, so the 4-lane loads/stores were not guaranteed to compile to aligned 128-bit vector accesses.

**Solution**: Remove the dead checks, split the ragged axis, and vectorize the bindings:
- **Before**: `load_q_vec4`/`load_k_vec4` (and AV `load_a_vec4`/`load_v_d4`) do 4 per-lane bounds `if`s per call; the AV `c4` loop runs checked loads for every chunk; `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<f32>` accessed element-by-element.
- **After**: QK loads are a plain unchecked `vec4` (D%4==0, host-guarded); AV runs a branch-free aligned body over `c4 in [0, context_len - context_len%4)` then a 0-or-1 checked tail; the head-dim-indexed buffers `t_q`/`t_k_cache`/`t_v_cache`/`t_out` are `array<vec4<f32>>` indexed `[base/4u]`, and AV writes a single aligned `store_out_vec4`.

**Implementation**:
- QK: `load_q_vec4`/`load_k_vec4` drop the per-lane D checks and return `t_q[base/4u]` / `t_k_cache[base/4u]`.
- AV: branch-free `load_a_vec4_nc`/`load_v_d4_nc` for the aligned body; checked `load_a_vec4`/`load_v_d4` for the tail; V reads `t_v_cache[base/4u]`; output is one aligned `store_out_vec4`.
- Bindings: `t_q`, `t_k_cache` (QK) and `t_v_cache`, `t_out` (AV) are `array<vec4<f32>>`. `t_attn_weights` and the softmax buffer stay `array<f32>`: they are `context_len`-indexed (row stride not 4-aligned) and written per-element under the causal mask, so a `vec4` binding there would need a padded scratch row.
- Host: add a `D % 4 == 0` guard in `Sdpa.cpp`. WGSL has no `SDPA_PAD_D` pad-load, so fail loud rather than read past the row; this guard also makes every `[base/4u]` index 4-aligned and every buffer a 16-byte multiple.
- Test: add a `reject_d6` (head_dim=6) config + an `expect_reject` harness branch asserting the guard rejects a non-aligned head_dim at load.
- Mirrors Vulkan `sdpa_compute_out_tiled.glsl` (aligned/tail split) and Vulkan's `array<vec4>` SDPA bindings.

**Constraints**:
- Requires `head_dim % 4 == 0` (true for every Llama config, D=64); enforced by a loud host throw, not a silent narrowing.
- Bit-identical output: the aligned body processes the same chunks in the same accumulation order as the scalar loop, the tail's out-of-range lanes contribute 0, and the `vec4` bindings read/write the same bytes as the scalar version.
- No KV-cache layout, dispatch, or uniform change.

Co-authored with Claude Code.

ghstack-source-id: 396717582
@exported-using-ghexport

Differential Revision: [D109521069](https://our.internmc.facebook.com/intern/diff/D109521069/)
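The reason the attention-weight buffers keep an `array<f32>` binding can be checked with a little index arithmetic. The following is an illustrative Python sketch, not code from this commit: `row_start` and `all_rows_vec4_aligned` are hypothetical helper names, and the row layout is assumed from the `base` computation in `load_a_vec4_nc`.

```python
def row_start(h: int, s: int, S: int, context_len: int) -> int:
    """Flat f32 index of the first attention weight for head h, query row s."""
    return h * S * context_len + s * context_len


def all_rows_vec4_aligned(H: int, S: int, context_len: int) -> bool:
    """True when every row start is a multiple of 4 f32 elements."""
    return all(
        row_start(h, s, S, context_len) % 4 == 0
        for h in range(H)
        for s in range(S)
    )


# context_len = 6: the second row starts at element 6, in the middle of a vec4 slot.
print(all_rows_vec4_aligned(H=2, S=3, context_len=6))
# context_len = 8: every row starts on a vec4 boundary.
print(all_rows_vec4_aligned(H=2, S=3, context_len=8))
```

Because `context_len` changes with every decoded token, the rows are only occasionally aligned, which is consistent with the commit's note that a `vec4` binding there would need a padded scratch row.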
1 parent 26bd1c1 commit 2fdc6d5

7 files changed

Lines changed: 132 additions & 78 deletions


backends/webgpu/runtime/ops/sdpa/Sdpa.cpp

Lines changed: 5 additions & 0 deletions
@@ -339,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
   if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) {
     throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q");
   }
+  // QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4.
+  if (D % 4 != 0) {
+    throw std::runtime_error(
+        "WebGPU sdpa: head_dim (D) must be a multiple of 4");
+  }
   if (v.dims[v.dims.size() - 2] != Hkv) {
     throw std::runtime_error("WebGPU sdpa: v num_heads must match k");
   }
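The effect of the host guard on buffer sizes can be sketched as follows. This is a hedged Python illustration rather than the actual host code: `check_head_dim` and `q_buffer_bytes` are hypothetical names, and a dense f32 `[S, Hq, D]` layout is assumed.

```python
F32_BYTES = 4
VEC4_BYTES = 16


def check_head_dim(D: int) -> None:
    """Mirror of the host guard: reject head_dim values that are not a multiple of 4."""
    if D % 4 != 0:
        raise ValueError("WebGPU sdpa: head_dim (D) must be a multiple of 4")


def q_buffer_bytes(S: int, Hq: int, D: int) -> int:
    """Byte size of a dense f32 [S, Hq, D] tensor (assumed layout)."""
    return S * Hq * D * F32_BYTES


check_head_dim(64)  # passes silently
# With D % 4 == 0 the buffer is a whole number of 16-byte vec4 slots.
assert q_buffer_bytes(4, 4, 64) % VEC4_BYTES == 0

try:
    check_head_dim(6)
except ValueError as exc:
    print(f"rejected: {exc}")
```

Failing before any shader runs keeps the unchecked `[base/4u]` loads safe by construction, instead of relying on a per-lane check inside the hot loop.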

backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl

Lines changed: 9 additions & 18 deletions
@@ -1,6 +1,6 @@
 @group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
-@group(0) @binding(1) var<storage, read> t_q: array<f32>;
-@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
+@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
+@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;

 struct Params {
   S: u32,
@@ -22,30 +22,21 @@ override wg_size: u32 = 64;
 const TM: u32 = 4u;
 const TN: u32 = 4u;

+// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
 fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
-  var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (s >= params.S) {
-    return r;
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
   }
-  let base = s * params.Hq * params.D + h * params.D;
-  if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; }
-  if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; }
-  if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; }
-  if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; }
-  return r;
+  let base = s * params.Hq * params.D + h * params.D + d4;
+  return t_q[base / 4u];
 }

 fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
-  var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (c >= params.context_len) {
-    return r;
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
   }
-  let base = c * params.Hkv * params.D + kvh * params.D;
-  if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; }
-  if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; }
-  if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; }
-  if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; }
-  return r;
+  let base = c * params.Hkv * params.D + kvh * params.D + d4;
+  return t_k_cache[base / 4u];
 }

 fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
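To see why the old and new `load_q_vec4` return the same lanes when `D % 4 == 0`, the two loaders can be emulated on a plain list. This is a minimal Python sketch under stated assumptions: the function names and the list-of-lists `vec4` view are illustrative stand-ins for the WGSL bindings, not part of the commit.

```python
def load_q_checked(t_q, s, h, d4, S, Hq, D):
    """Old loader: scalar buffer, one bounds check per lane."""
    r = [0.0, 0.0, 0.0, 0.0]
    if s >= S:
        return r
    base = s * Hq * D + h * D
    for lane in range(4):
        if d4 + lane < D:
            r[lane] = t_q[base + d4 + lane]
    return r


def load_q_vec4(t_q_vec4, s, h, d4, S, Hq, D):
    """New loader: buffer viewed as vec4 slots, indexed by base // 4."""
    if s >= S:
        return [0.0, 0.0, 0.0, 0.0]
    base = s * Hq * D + h * D + d4
    return list(t_q_vec4[base // 4])


S, Hq, D = 2, 3, 8  # D % 4 == 0, as the host guard requires
flat = [float(i) for i in range(S * Hq * D)]
as_vec4 = [flat[i:i + 4] for i in range(0, len(flat), 4)]

for s in range(S + 1):  # s == S exercises the out-of-range row
    for h in range(Hq):
        for d4 in range(0, D, 4):
            old = load_q_checked(flat, s, h, d4, S, Hq, D)
            new = load_q_vec4(as_vec4, s, h, d4, S, Hq, D)
            assert old == new
```

Every term of `base` is a multiple of 4 once `D` is, so `base // 4` lands exactly on the slot holding elements `base .. base + 3`.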

backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h

Lines changed: 10 additions & 19 deletions
@@ -13,11 +13,11 @@
 namespace executorch::backends::webgpu {

 // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT.
-// wgsl-sha256: d177264689e6c50e1794a0599808f3cfe6f30ba99c5084d3c8324da4b9f89d10
+// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254
 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"(
 @group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
-@group(0) @binding(1) var<storage, read> t_q: array<f32>;
-@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
+@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
+@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;

 struct Params {
   S: u32,
@@ -39,30 +39,21 @@ override wg_size: u32 = 64;
 const TM: u32 = 4u;
 const TN: u32 = 4u;

+// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
 fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
-  var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (s >= params.S) {
-    return r;
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
   }
-  let base = s * params.Hq * params.D + h * params.D;
-  if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; }
-  if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; }
-  if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; }
-  if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; }
-  return r;
+  let base = s * params.Hq * params.D + h * params.D + d4;
+  return t_q[base / 4u];
 }

 fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
-  var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (c >= params.context_len) {
-    return r;
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
   }
-  let base = c * params.Hkv * params.D + kvh * params.D;
-  if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; }
-  if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; }
-  if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; }
-  if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; }
-  return r;
+  let base = c * params.Hkv * params.D + kvh * params.D + d4;
+  return t_k_cache[base / 4u];
 }

 fn store_qk(s: u32, c: u32, h: u32, raw: f32) {

backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl

Lines changed: 42 additions & 20 deletions
@@ -1,6 +1,6 @@
-@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
 @group(0) @binding(1) var<storage, read> t_attn_weights_softmax: array<f32>;
-@group(0) @binding(2) var<storage, read> t_v_cache: array<f32>;
+@group(0) @binding(2) var<storage, read> t_v_cache: array<vec4<f32>>;

 struct Params {
   S: u32,
@@ -19,6 +19,7 @@ override wg_size: u32 = 64;
 const TM: u32 = 4u;
 const TN: u32 = 4u;

+// Checked loaders mask context lanes past context_len (D%4==0, host-guarded).
 fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
   var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (s >= params.S) {
@@ -33,24 +34,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
 }

 fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
-  var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (c >= params.context_len) {
-    return r;
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
   }
   let base = c * params.Hkv * params.D + kvh * params.D + d0;
-  if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; }
-  if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; }
-  if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; }
-  if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; }
-  return r;
+  return t_v_cache[base / 4u];
+}
+
+// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len.
+fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4<f32> {
+  if (s >= params.S) {
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+  }
+  let base = h * params.S * params.context_len + s * params.context_len + c4;
+  return vec4<f32>(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]);
+}
+
+fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
+  let base = c * params.Hkv * params.D + kvh * params.D + d0;
+  return t_v_cache[base / 4u];
 }

-fn store_out(s: u32, d: u32, h: u32, val: f32) {
-  if (s >= params.S || d >= params.D) {
+fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4<f32>) {
+  if (s >= params.S) {
     return;
   }
-  let idx = s * params.Hq * params.D + h * params.D + d;
-  t_out[idx] = val;
+  let idx = s * params.Hq * params.D + h * params.D + d0;
+  t_out[idx / 4u] = val;
 }

 @compute @workgroup_size(wg_size, 1, 1)
@@ -77,11 +87,28 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
   acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

+  // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl.
+  let ctx_aligned = params.context_len - (params.context_len & 3u);
   var c4: u32 = 0u;
   loop {
-    if (c4 >= params.context_len) {
+    if (c4 >= ctx_aligned) {
       break;
     }
+    let a0 = load_a_vec4_nc(s0 + 0u, h, c4);
+    let a1 = load_a_vec4_nc(s0 + 1u, h, c4);
+    let a2 = load_a_vec4_nc(s0 + 2u, h, c4);
+    let a3 = load_a_vec4_nc(s0 + 3u, h, c4);
+    let v0 = load_v_d4_nc(c4 + 0u, kvh, d0);
+    let v1 = load_v_d4_nc(c4 + 1u, kvh, d0);
+    let v2 = load_v_d4_nc(c4 + 2u, kvh, d0);
+    let v3 = load_v_d4_nc(c4 + 3u, kvh, d0);
+    acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3;
+    acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
+    acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
+    acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
+    c4 = c4 + 4u;
+  }
+  if (c4 < params.context_len) {
     let a0 = load_a_vec4(s0 + 0u, h, c4);
     let a1 = load_a_vec4(s0 + 1u, h, c4);
     let a2 = load_a_vec4(s0 + 2u, h, c4);
@@ -94,19 +121,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
     acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
     acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
-    c4 = c4 + 4u;
   }

   var m: u32 = 0u;
   loop {
     if (m >= TM) {
       break;
     }
-    let ov = acc[m];
-    store_out(s0 + m, d0 + 0u, h, ov.x);
-    store_out(s0 + m, d0 + 1u, h, ov.y);
-    store_out(s0 + m, d0 + 2u, h, ov.z);
-    store_out(s0 + m, d0 + 3u, h, ov.w);
+    store_out_vec4(s0 + m, d0, h, acc[m]);
     m = m + 1u;
   }
 }
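The aligned-body-plus-tail split in `main` can be emulated in a few lines to check that it accumulates exactly as the fully checked loop did. The Python below is a simplified sketch, assuming a single query row and plain lists in place of the storage buffers; all function names are illustrative and do not appear in the shader.

```python
ZERO4 = [0.0, 0.0, 0.0, 0.0]


def fma4(acc, a, vs):
    """acc += a[0]*vs[0] + a[1]*vs[1] + a[2]*vs[2] + a[3]*vs[3], lane by lane."""
    return [
        acc[i] + (a[0] * vs[0][i] + a[1] * vs[1][i] + a[2] * vs[2][i] + a[3] * vs[3][i])
        for i in range(4)
    ]


def load_a_checked(weights, c4):
    """Four weights starting at c4; lanes past the end read as 0."""
    n = len(weights)
    return [weights[c4 + i] if c4 + i < n else 0.0 for i in range(4)]


def load_v_checked(values, c):
    """Value row c, or zeros when c is past the end."""
    return values[c] if c < len(values) else ZERO4


def av_checked(weights, values):
    """Before: every chunk goes through the checked loaders."""
    acc = list(ZERO4)
    for c4 in range(0, len(weights), 4):
        a = load_a_checked(weights, c4)
        vs = [load_v_checked(values, c4 + i) for i in range(4)]
        acc = fma4(acc, a, vs)
    return acc


def av_split(weights, values):
    """After: unchecked aligned body, then at most one checked tail chunk."""
    n = len(weights)
    ctx_aligned = n - (n & 3)  # same as n - n % 4 for non-negative n
    acc = list(ZERO4)
    c4 = 0
    while c4 < ctx_aligned:
        acc = fma4(acc, weights[c4:c4 + 4], values[c4:c4 + 4])
        c4 += 4
    if c4 < n:
        a = load_a_checked(weights, c4)
        vs = [load_v_checked(values, c4 + i) for i in range(4)]
        acc = fma4(acc, a, vs)
    return acc


for n in range(0, 11):
    w = [0.1 * (i + 1) for i in range(n)]
    v = [[float(i + j) / 7.0 for j in range(4)] for i in range(n)]
    assert av_checked(w, v) == av_split(w, v)
```

The exact `==` comparison holds because both paths perform the same floating-point operations in the same order; the split only changes which chunks pay for bounds checks.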

backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h

Lines changed: 43 additions & 21 deletions
@@ -13,11 +13,11 @@
 namespace executorch::backends::webgpu {

 // @generated from sdpa_compute_out.wgsl - DO NOT EDIT.
-// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc
+// wgsl-sha256: 2ffa0eb520b1054e43a10fd13e6b287bd35777f1cfc29bd39e9d668772528191
 inline constexpr const char* kSdpaComputeOutWGSL = R"(
-@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
+@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
 @group(0) @binding(1) var<storage, read> t_attn_weights_softmax: array<f32>;
-@group(0) @binding(2) var<storage, read> t_v_cache: array<f32>;
+@group(0) @binding(2) var<storage, read> t_v_cache: array<vec4<f32>>;

 struct Params {
   S: u32,
@@ -36,6 +36,7 @@ override wg_size: u32 = 64;
 const TM: u32 = 4u;
 const TN: u32 = 4u;

+// Checked loaders mask context lanes past context_len (D%4==0, host-guarded).
 fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
   var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (s >= params.S) {
@@ -50,24 +51,33 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
 }

 fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
-  var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   if (c >= params.context_len) {
-    return r;
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
   }
   let base = c * params.Hkv * params.D + kvh * params.D + d0;
-  if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; }
-  if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; }
-  if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; }
-  if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; }
-  return r;
+  return t_v_cache[base / 4u];
+}
+
+// Branch-free loaders for the aligned body: caller guarantees c4..c4+3 < context_len.
+fn load_a_vec4_nc(s: u32, h: u32, c4: u32) -> vec4<f32> {
+  if (s >= params.S) {
+    return vec4<f32>(0.0, 0.0, 0.0, 0.0);
+  }
+  let base = h * params.S * params.context_len + s * params.context_len + c4;
+  return vec4<f32>(t_attn_weights_softmax[base], t_attn_weights_softmax[base + 1u], t_attn_weights_softmax[base + 2u], t_attn_weights_softmax[base + 3u]);
+}
+
+fn load_v_d4_nc(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
+  let base = c * params.Hkv * params.D + kvh * params.D + d0;
+  return t_v_cache[base / 4u];
 }

-fn store_out(s: u32, d: u32, h: u32, val: f32) {
-  if (s >= params.S || d >= params.D) {
+fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4<f32>) {
+  if (s >= params.S) {
     return;
   }
-  let idx = s * params.Hq * params.D + h * params.D + d;
-  t_out[idx] = val;
+  let idx = s * params.Hq * params.D + h * params.D + d0;
+  t_out[idx / 4u] = val;
 }

 @compute @workgroup_size(wg_size, 1, 1)
@@ -94,11 +104,28 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
   acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
   acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);

+  // Branch-free aligned body + checked tail; mirrors Vulkan out_tiled.glsl.
+  let ctx_aligned = params.context_len - (params.context_len & 3u);
   var c4: u32 = 0u;
   loop {
-    if (c4 >= params.context_len) {
+    if (c4 >= ctx_aligned) {
       break;
     }
+    let a0 = load_a_vec4_nc(s0 + 0u, h, c4);
+    let a1 = load_a_vec4_nc(s0 + 1u, h, c4);
+    let a2 = load_a_vec4_nc(s0 + 2u, h, c4);
+    let a3 = load_a_vec4_nc(s0 + 3u, h, c4);
+    let v0 = load_v_d4_nc(c4 + 0u, kvh, d0);
+    let v1 = load_v_d4_nc(c4 + 1u, kvh, d0);
+    let v2 = load_v_d4_nc(c4 + 2u, kvh, d0);
+    let v3 = load_v_d4_nc(c4 + 3u, kvh, d0);
+    acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3;
+    acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
+    acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
+    acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
+    c4 = c4 + 4u;
+  }
+  if (c4 < params.context_len) {
     let a0 = load_a_vec4(s0 + 0u, h, c4);
     let a1 = load_a_vec4(s0 + 1u, h, c4);
     let a2 = load_a_vec4(s0 + 2u, h, c4);
@@ -111,19 +138,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
     acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
     acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
-    c4 = c4 + 4u;
   }

   var m: u32 = 0u;
   loop {
     if (m >= TM) {
       break;
     }
-    let ov = acc[m];
-    store_out(s0 + m, d0 + 0u, h, ov.x);
-    store_out(s0 + m, d0 + 1u, h, ov.y);
-    store_out(s0 + m, d0 + 2u, h, ov.z);
-    store_out(s0 + m, d0 + 3u, h, ov.w);
+    store_out_vec4(s0 + m, d0, h, acc[m]);
     m = m + 1u;
   }
 }

backends/webgpu/test/ops/sdpa/test_sdpa.py

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ class SdpaConfig:
     # Llama 3.2 1B shape: realistic prefill (S=128 at pos 0) + decode (S=1 at pos 127).
     SdpaConfig("llama1b_prefill", 32, 8, 64, 128, 512, 0),
     SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127),
+    # D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load.
+    SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0),
 ]


backends/webgpu/test/test_webgpu_native.cpp

Lines changed: 21 additions & 0 deletions
@@ -435,6 +435,7 @@ struct SdpaConfig {
   int input_pos; // prior tokens already in the cache (decode)
   float denom; // ramp divisor (mirrors Python); small -> large logits
   bool required = false; // CI (SDPA dir set): absent .pte = FAIL, not skip
+  bool expect_reject = false; // load MUST fail (e.g. D%4 guard), no golden
 };

 static const SdpaConfig kSdpaConfigs[] = {
@@ -454,6 +455,17 @@ static const SdpaConfig kSdpaConfigs[] = {
     // pos 127).
     {"llama1b_prefill", 32, 8, 64, 128, 512, 0, 16.0f},
     {"llama1b_decode", 32, 8, 64, 1, 512, 127, 16.0f},
+    // D=6 is not a multiple of 4: the head_dim%4 guard must reject it at load.
+    {"reject_d6",
+     4,
+     4,
+     6,
+     4,
+     16,
+     0,
+     16.0f,
+     /*required=*/false,
+     /*expect_reject=*/true},
 };

 // Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync).
@@ -507,6 +519,15 @@ static bool test_sdpa_config(

   Module module(model_path);
   auto err = module.load_forward();
+  if (cfg.expect_reject) {
+    // D not a multiple of 4 must be rejected at load by the head_dim guard.
+    if (err != Error::Ok) {
+      printf("PASS: %s rejected at load (error %d)\n", cfg.name, (int)err);
+      return true;
+    }
+    printf("FAIL: %s loaded OK; head_dim%%4 guard did not fire\n", cfg.name);
+    return false;
+  }
   if (err != Error::Ok) {
     printf("FAIL: could not load forward method (error %d)\n", (int)err);
     return false;
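The decision logic of the new `expect_reject` branch is small enough to restate as a truth table. The sketch below is a hedged Python paraphrase of the load step only; `judge_load` is a hypothetical name, and the real harness goes on to a golden comparison for configs that load.

```python
def judge_load(expect_reject: bool, load_ok: bool) -> bool:
    """Return True when the observed load outcome matches the config's expectation."""
    if expect_reject:
        # A reject config passes only when load fails; a clean load means
        # the head_dim % 4 guard did not fire.
        return not load_ok
    # Ordinary configs must load before any golden comparison can run.
    return load_ok


for expect_reject in (False, True):
    for load_ok in (False, True):
        verdict = "ok" if judge_load(expect_reject, load_ok) else "FAIL"
        print(f"expect_reject={expect_reject} load_ok={load_ok} -> {verdict}")
```

Inverting the expectation keeps `reject_d6` in the same config table as the passing shapes, so the guard is exercised on every run without a separate test entry point.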
