Skip to content

Commit a0a730a

Browse files
[ExecuTorch][WebGPU] Register-tile the SDPA QK/AV kernels (#20507)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #20405 by @JulianCloudNTH (^ please use this as the source of truth for the PR details, comments, and reviews). ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/base; ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/head; Merge bot PR base: https://github.com/pytorch/executorch/tree/main; Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/49/orig. @diff-train-skip-merge --------- Co-authored-by: Julian Ng-Thow-Hing <juliannth@meta.com>
1 parent 5a920c3 commit a0a730a

7 files changed

Lines changed: 417 additions & 83 deletions

File tree

backends/webgpu/runtime/ops/sdpa/Sdpa.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ namespace executorch::backends::webgpu {
2626

2727
namespace {
2828

29+
// Register-tile dims; MUST match the TM/TN constants in the register-tiled WGSL kernels.
30+
constexpr int64_t kSdpaTileM = 4;
31+
constexpr int64_t kSdpaTileN = 4;
32+
2933
// Uniform param structs (all 16-byte aligned, matching the WGSL Params).
3034
struct UpdateCacheParams {
3135
uint32_t numel;
@@ -335,6 +339,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
335339
if (k.dims[kn - 1] != D || v.dims[v.dims.size() - 1] != D) {
336340
throw std::runtime_error("WebGPU sdpa: k/v head_dim must match q");
337341
}
342+
// QK/AV read D as vec4 (no SDPA_PAD_D); head_dim must be a multiple of 4.
343+
if (D % 4 != 0) {
344+
throw std::runtime_error(
345+
"WebGPU sdpa: head_dim (D) must be a multiple of 4");
346+
}
338347
if (v.dims[v.dims.size() - 2] != Hkv) {
339348
throw std::runtime_error("WebGPU sdpa: v num_heads must match k");
340349
}
@@ -464,14 +473,16 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
464473
dynamic_pos,
465474
"update_cache(V)");
466475

467-
// --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element.
476+
// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
468477
{
469478
if (aw_floats > UINT32_MAX) {
470479
throw std::runtime_error(
471480
"WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
472481
}
482+
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
483+
utils::div_up(context_len, kSdpaTileN);
473484
const uint32_t wgc = utils::compute_1d_workgroup_count(
474-
device, static_cast<uint32_t>(aw_floats), qk_wg, "QK");
485+
device, static_cast<uint32_t>(qk_tiles), qk_wg, "QK");
475486
AttnWeightsParams p = make_attn_weights_params(
476487
S, Hq, Hkv, D, context_len, input_pos, g, scale);
477488
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
@@ -515,12 +526,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
515526
softmax_buf = ubuf;
516527
}
517528

518-
// --- Dispatch 5: AV -> out. One thread per (s,h,d) output element.
529+
// --- Dispatch 5: AV -> out. One thread per TM x TN tile.
519530
{
520-
const uint64_t out_floats = static_cast<uint64_t>(S) *
521-
static_cast<uint64_t>(Hq) * static_cast<uint64_t>(D);
531+
const int64_t av_tiles =
532+
Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(D, kSdpaTileN);
522533
const uint32_t wgc = utils::compute_1d_workgroup_count(
523-
device, static_cast<uint32_t>(out_floats), av_wg, "AV");
534+
device, static_cast<uint32_t>(av_tiles), av_wg, "AV");
524535
ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
525536
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
526537
BufferBinding bindings[3] = {
@@ -591,9 +602,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
591602
AttnWeightsParams qp =
592603
make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
593604
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
605+
const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) *
606+
utils::div_up(ctx, kSdpaTileN);
594607
const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
595608
gr.device(),
596-
static_cast<uint32_t>(aw_floats),
609+
static_cast<uint32_t>(qk_tiles),
597610
qk_wg,
598611
"QK(resize)");
599612
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;
Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
2-
@group(0) @binding(1) var<storage, read> t_q: array<f32>;
3-
@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
2+
@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
3+
@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;
44

55
struct Params {
66
S: u32,
@@ -19,37 +19,98 @@ const NEG_INF: f32 = -1.0e30;
1919

2020
override wg_size: u32 = 64;
2121

22+
const TM: u32 = 4u;
23+
const TN: u32 = 4u;
24+
25+
// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
26+
// Loads one vec4 (four consecutive head-dim lanes) of Q for query row `s`,
// head `h`, starting at scalar head-dim offset `d4`.
// Returns a zero vector when `s` is past params.S, so rows of a partially
// filled tile contribute nothing to the dot products.
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
27+
if (s >= params.S) {
28+
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
29+
}
30+
// Scalar (f32) offset into Q, indexed as [s][h][d].
let base = s * params.Hq * params.D + h * params.D + d4;
31+
// t_q is bound as array<vec4<f32>>, so convert the scalar offset to a vec4
// index. NOTE(review): exact only if `base` is a multiple of 4, i.e. D % 4 == 0
// and d4 advances in steps of 4 — confirm callers keep that invariant.
return t_q[base / 4u];
32+
}
33+
34+
// Loads one vec4 (four consecutive head-dim lanes) of the K cache for context
// position `c`, KV head `kvh`, starting at scalar head-dim offset `d4`.
// Returns a zero vector when `c` is past params.context_len, so columns of a
// partially filled tile contribute nothing to the dot products.
fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
35+
if (c >= params.context_len) {
36+
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
37+
}
38+
// Scalar (f32) offset into the K cache, indexed as [c][kvh][d].
let base = c * params.Hkv * params.D + kvh * params.D + d4;
39+
// t_k_cache is bound as array<vec4<f32>>, so convert the scalar offset to a
// vec4 index. NOTE(review): exact only if `base` is a multiple of 4, i.e.
// D % 4 == 0 and d4 advances in steps of 4 — confirm callers keep that invariant.
return t_k_cache[base / 4u];
40+
}
41+
42+
// Writes one attention-weight element for query row `s`, context column `c`
// and head `h`. `raw` is the unscaled Q.K dot product.
// Out-of-range (s, c) pairs are ignored, so callers may pass every lane of a
// tile without checking bounds themselves.
fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
43+
if (s >= params.S || c >= params.context_len) {
44+
return;
45+
}
46+
var val = raw * params.scale;
47+
// Causal mask: query row s may not attend to context positions c beyond
// s + input_pos; such entries are overwritten with NEG_INF.
48+
if (c > s + params.input_pos) {
49+
val = NEG_INF;
50+
}
51+
// Flat index into attn_weights, indexed as [h][s][c].
let idx = h * params.S * params.context_len + s * params.context_len + c;
52+
t_attn_weights[idx] = val;
53+
}
54+
2255
@compute @workgroup_size(wg_size, 1, 1)
2356
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
24-
let total = params.Hq * params.S * params.context_len;
25-
let idx = gid.x;
26-
if (idx >= total) {
57+
let nrt = (params.S + TM - 1u) / TM;
58+
let nct = (params.context_len + TN - 1u) / TN;
59+
let tiles = nrt * nct;
60+
let total = tiles * params.Hq;
61+
if (gid.x >= total) {
2762
return;
2863
}
29-
let c = idx % params.context_len;
30-
let s = (idx / params.context_len) % params.S;
31-
let h = idx / (params.context_len * params.S);
3264

65+
let h = gid.x / tiles;
66+
let rem = gid.x % tiles;
67+
let row_tile = rem / nct;
68+
let col_tile = rem % nct;
3369
let kvh = h / params.g;
70+
let s0 = row_tile * TM;
71+
let c0 = col_tile * TN;
3472

35-
let q_base = s * params.Hq * params.D + h * params.D;
36-
let k_base = c * params.Hkv * params.D + kvh * params.D;
73+
var acc: array<vec4<f32>, 4>;
74+
acc[0] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
75+
acc[1] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
76+
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
77+
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
3778

38-
var acc: f32 = 0.0;
39-
var d: u32 = 0u;
79+
// Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl.
80+
let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos;
81+
var d4: u32 = 0u;
4082
loop {
41-
if (d >= params.D) {
83+
if (d4 >= params.D || skip_tile) {
4284
break;
4385
}
44-
acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
45-
d = d + 1u;
86+
var q: array<vec4<f32>, TM>;
87+
var k: array<vec4<f32>, TN>;
88+
for (var i: u32 = 0u; i < TM; i = i + 1u) {
89+
q[i] = load_q_vec4(s0 + i, h, d4);
90+
}
91+
for (var j: u32 = 0u; j < TN; j = j + 1u) {
92+
k[j] = load_k_vec4(c0 + j, kvh, d4);
93+
}
94+
for (var i: u32 = 0u; i < TM; i = i + 1u) {
95+
acc[i] += vec4<f32>(
96+
dot(q[i], k[0]),
97+
dot(q[i], k[1]),
98+
dot(q[i], k[2]),
99+
dot(q[i], k[3]));
100+
}
101+
d4 = d4 + 4u;
46102
}
47-
acc = acc * params.scale;
48103

49-
// Causal mask: position c may not attend beyond s + input_pos.
50-
if (c > s + params.input_pos) {
51-
acc = NEG_INF;
104+
var m: u32 = 0u;
105+
loop {
106+
if (m >= TM) {
107+
break;
108+
}
109+
let av = acc[m];
110+
store_qk(s0 + m, c0 + 0u, h, av.x);
111+
store_qk(s0 + m, c0 + 1u, h, av.y);
112+
store_qk(s0 + m, c0 + 2u, h, av.z);
113+
store_qk(s0 + m, c0 + 3u, h, av.w);
114+
m = m + 1u;
52115
}
53-
54-
t_attn_weights[idx] = acc;
55116
}

backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT.
16-
// wgsl-sha256: 7410869c1c35f09777851bf49b835dc8fecaff3f327aa64a9c900ac0cc3445e1
16+
// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254
1717
inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"(
1818
@group(0) @binding(0) var<storage, read_write> t_attn_weights: array<f32>;
19-
@group(0) @binding(1) var<storage, read> t_q: array<f32>;
20-
@group(0) @binding(2) var<storage, read> t_k_cache: array<f32>;
19+
@group(0) @binding(1) var<storage, read> t_q: array<vec4<f32>>;
20+
@group(0) @binding(2) var<storage, read> t_k_cache: array<vec4<f32>>;
2121
2222
struct Params {
2323
S: u32,
@@ -36,39 +36,100 @@ const NEG_INF: f32 = -1.0e30;
3636
3737
override wg_size: u32 = 64;
3838
39+
const TM: u32 = 4u;
40+
const TN: u32 = 4u;
41+
42+
// D is a multiple of 4 (host-guarded), so a d4 chunk is fully in-bounds — no per-lane check.
43+
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
44+
if (s >= params.S) {
45+
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
46+
}
47+
let base = s * params.Hq * params.D + h * params.D + d4;
48+
return t_q[base / 4u];
49+
}
50+
51+
fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
52+
if (c >= params.context_len) {
53+
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
54+
}
55+
let base = c * params.Hkv * params.D + kvh * params.D + d4;
56+
return t_k_cache[base / 4u];
57+
}
58+
59+
fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
60+
if (s >= params.S || c >= params.context_len) {
61+
return;
62+
}
63+
var val = raw * params.scale;
64+
// Causal mask: position c may not attend beyond s + input_pos.
65+
if (c > s + params.input_pos) {
66+
val = NEG_INF;
67+
}
68+
let idx = h * params.S * params.context_len + s * params.context_len + c;
69+
t_attn_weights[idx] = val;
70+
}
71+
3972
@compute @workgroup_size(wg_size, 1, 1)
4073
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
41-
let total = params.Hq * params.S * params.context_len;
42-
let idx = gid.x;
43-
if (idx >= total) {
74+
let nrt = (params.S + TM - 1u) / TM;
75+
let nct = (params.context_len + TN - 1u) / TN;
76+
let tiles = nrt * nct;
77+
let total = tiles * params.Hq;
78+
if (gid.x >= total) {
4479
return;
4580
}
46-
let c = idx % params.context_len;
47-
let s = (idx / params.context_len) % params.S;
48-
let h = idx / (params.context_len * params.S);
4981
82+
let h = gid.x / tiles;
83+
let rem = gid.x % tiles;
84+
let row_tile = rem / nct;
85+
let col_tile = rem % nct;
5086
let kvh = h / params.g;
87+
let s0 = row_tile * TM;
88+
let c0 = col_tile * TN;
5189
52-
let q_base = s * params.Hq * params.D + h * params.D;
53-
let k_base = c * params.Hkv * params.D + kvh * params.D;
90+
var acc: array<vec4<f32>, 4>;
91+
acc[0] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
92+
acc[1] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
93+
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
94+
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
5495
55-
var acc: f32 = 0.0;
56-
var d: u32 = 0u;
96+
// Skip fully-masked causal tiles; mirrors Vulkan attn_weights_tiled.glsl.
97+
let skip_tile = c0 > s0 + (TM - 1u) + params.input_pos;
98+
var d4: u32 = 0u;
5799
loop {
58-
if (d >= params.D) {
100+
if (d4 >= params.D || skip_tile) {
59101
break;
60102
}
61-
acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
62-
d = d + 1u;
103+
var q: array<vec4<f32>, TM>;
104+
var k: array<vec4<f32>, TN>;
105+
for (var i: u32 = 0u; i < TM; i = i + 1u) {
106+
q[i] = load_q_vec4(s0 + i, h, d4);
107+
}
108+
for (var j: u32 = 0u; j < TN; j = j + 1u) {
109+
k[j] = load_k_vec4(c0 + j, kvh, d4);
110+
}
111+
for (var i: u32 = 0u; i < TM; i = i + 1u) {
112+
acc[i] += vec4<f32>(
113+
dot(q[i], k[0]),
114+
dot(q[i], k[1]),
115+
dot(q[i], k[2]),
116+
dot(q[i], k[3]));
117+
}
118+
d4 = d4 + 4u;
63119
}
64-
acc = acc * params.scale;
65120
66-
// Causal mask: position c may not attend beyond s + input_pos.
67-
if (c > s + params.input_pos) {
68-
acc = NEG_INF;
121+
var m: u32 = 0u;
122+
loop {
123+
if (m >= TM) {
124+
break;
125+
}
126+
let av = acc[m];
127+
store_qk(s0 + m, c0 + 0u, h, av.x);
128+
store_qk(s0 + m, c0 + 1u, h, av.y);
129+
store_qk(s0 + m, c0 + 2u, h, av.z);
130+
store_qk(s0 + m, c0 + 3u, h, av.w);
131+
m = m + 1u;
69132
}
70-
71-
t_attn_weights[idx] = acc;
72133
}
73134
)";
74135

0 commit comments

Comments
 (0)