Skip to content

Commit 78dd8c1

Browse files
[ExecuTorch][WebGPU] Coalesce SDPA AV V-cache reads along contiguous head-dim
Pull Request resolved: #20459

**~19% faster SDPA attention-output (AV) stage** — 393→317 µs on llama3 prefill (Chrome Canary / M4 Pro).

**Problem**: Each V-cache load reads 4 strided context rows × 1 head-dim lane, so the reads are not coalesced.

**Solution**: Flip the access pattern to read 4 contiguous head-dim lanes per context row:
- **Before**: `load_v_vec4(d, kvh, c4)` → 4 strided rows; each `dot()` reduces over the 4 context entries and yields one D-lane of the result
- **After**: `load_v_d4(c, kvh, d0)` → 4 contiguous D-lanes (16 contiguous bytes of the f32 V-cache buffer); each attention weight is broadcast as a scalar across them

**Implementation**:
- Reindex the V-load helper (`load_v_vec4` → `load_v_d4`) to read contiguous head-dim lanes
- Replace `dot(A, V)` with `acc += A[c] * V_vec4(d0:d0+3)`
- Mirrors Vulkan `load_v_cache_d4` coalescing pattern

**Constraints**:
- No KV-cache layout change (still `[C, Hkv, D]`)
- Output numerically equivalent, not bit-identical (FP-reassociated, max abs diff 1.43e-6 vs torch)

ghstack-source-id: 395771238
@exported-using-ghexport

Differential Revision: [D109339276](https://our.internmc.facebook.com/intern/diff/D109339276/)
1 parent c13a2ce commit 78dd8c1

2 files changed

Lines changed: 31 additions & 33 deletions

File tree

backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
3232
return r;
3333
}
3434

35-
fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4<f32> {
35+
fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
3636
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
37-
if (d >= params.D) {
37+
if (c >= params.context_len) {
3838
return r;
3939
}
40-
let stride = params.Hkv * params.D;
41-
let off = kvh * params.D + d;
42-
if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; }
43-
if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; }
44-
if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; }
45-
if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; }
40+
let base = c * params.Hkv * params.D + kvh * params.D + d0;
41+
if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; }
42+
if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; }
43+
if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; }
44+
if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; }
4645
return r;
4746
}
4847

@@ -87,14 +86,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
8786
let a1 = load_a_vec4(s0 + 1u, h, c4);
8887
let a2 = load_a_vec4(s0 + 2u, h, c4);
8988
let a3 = load_a_vec4(s0 + 3u, h, c4);
90-
let v0 = load_v_vec4(d0 + 0u, kvh, c4);
91-
let v1 = load_v_vec4(d0 + 1u, kvh, c4);
92-
let v2 = load_v_vec4(d0 + 2u, kvh, c4);
93-
let v3 = load_v_vec4(d0 + 3u, kvh, c4);
94-
acc[0] += vec4<f32>(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3));
95-
acc[1] += vec4<f32>(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3));
96-
acc[2] += vec4<f32>(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3));
97-
acc[3] += vec4<f32>(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3));
89+
let v0 = load_v_d4(c4 + 0u, kvh, d0);
90+
let v1 = load_v_d4(c4 + 1u, kvh, d0);
91+
let v2 = load_v_d4(c4 + 2u, kvh, d0);
92+
let v3 = load_v_d4(c4 + 3u, kvh, d0);
93+
acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3;
94+
acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
95+
acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
96+
acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
9897
c4 = c4 + 4u;
9998
}
10099

backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// @generated from sdpa_compute_out.wgsl - DO NOT EDIT.
16-
// wgsl-sha256: 4ffc13bad0bf56b87a57f75307f29e851dd2bd6bf0dba094488df5d262e910e3
16+
// wgsl-sha256: 545f624567b08eba407954034df821010e49124fa6f8fd6b05c64ca4354ee4cc
1717
inline constexpr const char* kSdpaComputeOutWGSL = R"(
1818
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
1919
@group(0) @binding(1) var<storage, read> t_attn_weights_softmax: array<f32>;
@@ -49,17 +49,16 @@ fn load_a_vec4(s: u32, h: u32, c4: u32) -> vec4<f32> {
4949
return r;
5050
}
5151
52-
fn load_v_vec4(d: u32, kvh: u32, c4: u32) -> vec4<f32> {
52+
fn load_v_d4(c: u32, kvh: u32, d0: u32) -> vec4<f32> {
5353
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
54-
if (d >= params.D) {
54+
if (c >= params.context_len) {
5555
return r;
5656
}
57-
let stride = params.Hkv * params.D;
58-
let off = kvh * params.D + d;
59-
if (c4 + 0u < params.context_len) { r.x = t_v_cache[(c4 + 0u) * stride + off]; }
60-
if (c4 + 1u < params.context_len) { r.y = t_v_cache[(c4 + 1u) * stride + off]; }
61-
if (c4 + 2u < params.context_len) { r.z = t_v_cache[(c4 + 2u) * stride + off]; }
62-
if (c4 + 3u < params.context_len) { r.w = t_v_cache[(c4 + 3u) * stride + off]; }
57+
let base = c * params.Hkv * params.D + kvh * params.D + d0;
58+
if (d0 + 0u < params.D) { r.x = t_v_cache[base + 0u]; }
59+
if (d0 + 1u < params.D) { r.y = t_v_cache[base + 1u]; }
60+
if (d0 + 2u < params.D) { r.z = t_v_cache[base + 2u]; }
61+
if (d0 + 3u < params.D) { r.w = t_v_cache[base + 3u]; }
6362
return r;
6463
}
6564
@@ -104,14 +103,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
104103
let a1 = load_a_vec4(s0 + 1u, h, c4);
105104
let a2 = load_a_vec4(s0 + 2u, h, c4);
106105
let a3 = load_a_vec4(s0 + 3u, h, c4);
107-
let v0 = load_v_vec4(d0 + 0u, kvh, c4);
108-
let v1 = load_v_vec4(d0 + 1u, kvh, c4);
109-
let v2 = load_v_vec4(d0 + 2u, kvh, c4);
110-
let v3 = load_v_vec4(d0 + 3u, kvh, c4);
111-
acc[0] += vec4<f32>(dot(a0, v0), dot(a0, v1), dot(a0, v2), dot(a0, v3));
112-
acc[1] += vec4<f32>(dot(a1, v0), dot(a1, v1), dot(a1, v2), dot(a1, v3));
113-
acc[2] += vec4<f32>(dot(a2, v0), dot(a2, v1), dot(a2, v2), dot(a2, v3));
114-
acc[3] += vec4<f32>(dot(a3, v0), dot(a3, v1), dot(a3, v2), dot(a3, v3));
106+
let v0 = load_v_d4(c4 + 0u, kvh, d0);
107+
let v1 = load_v_d4(c4 + 1u, kvh, d0);
108+
let v2 = load_v_d4(c4 + 2u, kvh, d0);
109+
let v3 = load_v_d4(c4 + 3u, kvh, d0);
110+
acc[0] += a0.x * v0 + a0.y * v1 + a0.z * v2 + a0.w * v3;
111+
acc[1] += a1.x * v0 + a1.y * v1 + a1.z * v2 + a1.w * v3;
112+
acc[2] += a2.x * v0 + a2.y * v1 + a2.z * v2 + a2.w * v3;
113+
acc[3] += a3.x * v0 + a3.y * v1 + a3.z * v2 + a3.w * v3;
115114
c4 = c4 + 4u;
116115
}
117116

0 commit comments

Comments
 (0)