Skip to content

Commit 13d36cf

Browse files
ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix (ggml-org#22199)
* ggml-webgpu: add tile flash attention fallback * ggml-webgpu: add new fields and discard usage of mnk for tile version * ggml-webgpu: modify the vec path to discard the mnk parameter * ggml-webgpu: enable flash attention vec and tile version for browser * ggml-webgpu: staging KV for flash attention tile version * formatting * turn on subgroup uniformity check * remove Q_TILE as it is always 1 for vec path * make row_max and exp_sum to local register * make different bindings with same underlying buffer to have the same usage flags * move path selection into the shader library and have the host consume a single flash-attn decision object. * turn off skip_validation and address buffer overlapping when nwg==1 * formatting * merge binding when kv overlap
1 parent f65bc34 commit 13d36cf

6 files changed

Lines changed: 809 additions & 392 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 180 additions & 146 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 119 additions & 74 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,26 +138,55 @@ struct Params {
138138
};
139139

140140
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
141+
#ifdef KV_OVERLAP
142+
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
143+
#define V K
144+
#else
141145
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
142146
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
147+
#endif
143148

144149
#if defined(MASK) && defined(SINKS)
150+
#ifdef KV_OVERLAP
151+
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
152+
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
153+
#define DST_BINDING 4
154+
#define PARAMS_BINDING 5
155+
#else
145156
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
146157
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
147158
#define DST_BINDING 5
148159
#define PARAMS_BINDING 6
160+
#endif
149161
#elif defined(MASK)
162+
#ifdef KV_OVERLAP
163+
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
164+
#define DST_BINDING 3
165+
#define PARAMS_BINDING 4
166+
#else
150167
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
151168
#define DST_BINDING 4
152169
#define PARAMS_BINDING 5
170+
#endif
153171
#elif defined(SINKS)
172+
#ifdef KV_OVERLAP
173+
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
174+
#define DST_BINDING 3
175+
#define PARAMS_BINDING 4
176+
#else
154177
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
155178
#define DST_BINDING 4
156179
#define PARAMS_BINDING 5
180+
#endif
181+
#else
182+
#ifdef KV_OVERLAP
183+
#define DST_BINDING 2
184+
#define PARAMS_BINDING 3
157185
#else
158186
#define DST_BINDING 3
159187
#define PARAMS_BINDING 4
160188
#endif
189+
#endif
161190

162191
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
163192
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
// The shader uses f16 storage/shared-memory types and subgroup builtins/operations.
enable f16;
2+
enable subgroups;
3+
4+
// Tile-shape macros consumed by the declarations and by main() below.
// NOTE(review): these look like default values that the host-side shader
// library may override per pipeline -- confirm against the preprocessor.
// HEAD_DIM_QK: elements per Q/K row; sizes q_shmem and the Q row stride in it.
#define HEAD_DIM_QK 64
5+
// HEAD_DIM_V: elements per V row and per output row.
#define HEAD_DIM_V 64
6+
// KV_STAGE_STRIDE: row stride (in f16 elements) of the kv_shmem staging buffer.
#define KV_STAGE_STRIDE 64
7+
// Q_TILE: query rows handled per workgroup; main() assigns one subgroup per row.
#define Q_TILE 4
8+
// KV_TILE: maximum number of K/V rows staged per iteration of the KV loop.
#define KV_TILE 64
9+
// WG_SIZE: invocations per workgroup (used in @workgroup_size and as load stride).
#define WG_SIZE 128
10+
11+
// Uniform parameter block read by main(). Field order defines the uniform
// layout, so it must stay in sync with whatever fills the buffer on the host.
struct Params {
12+
// Base offsets of each tensor inside its bound buffer. As used in main():
// offset_q / offset_mask / offset_sinks index their arrays directly, while
// offset_k / offset_v / offset_dst take part in scalar-element offsets that
// are shifted right by 2 to index the vec4-typed K, V and dst arrays.
offset_q: u32,
13+
offset_k: u32,
14+
offset_v: u32,
15+
offset_mask: u32,
16+
offset_sinks: u32,
17+
offset_dst: u32,
18+
19+
// Problem shape: number of query heads and the Q / KV sequence lengths.
n_heads: u32,
20+
seq_len_q: u32,
21+
seq_len_kv: u32,
22+
23+
// Strides, in the same units as the matching offset above.
// *1 is multiplied by the row index, *2 by the head index, *3 by the batch index.
stride_q1: u32,
24+
stride_q2: u32,
25+
stride_q3: u32,
26+
stride_k1: u32,
27+
stride_k2: u32,
28+
stride_k3: u32,
29+
stride_v1: u32,
30+
stride_v2: u32,
31+
stride_v3: u32,
32+
// Only a batch stride is supplied for the mask; rows are seq_len_kv apart in main().
stride_mask3: u32,
33+
34+
// Number of query heads sharing one K/V head (main() uses head_idx / q_per_kv).
q_per_kv: u32,
35+
36+
// scale: multiplied into Q while it is loaded into shared memory.
scale: f32,
37+
// max_bias: when > 0.0 a per-head slope derived from m0/m1/n_head_log2 scales the mask.
max_bias: f32,
38+
// logit_softcap: multiplies tanh(score) when LOGIT_SOFTCAP is defined.
logit_softcap: f32,
39+
// n_head_log2, m0, m1: inputs to the per-head slope computation in main().
n_head_log2: f32,
40+
m0: f32,
41+
m1: f32,
42+
};
43+
44+
// Bind-group layout. Binding indices are assigned densely in the order
// Q, K, [V], [mask], [sinks], dst, params; optional entries are omitted and
// every later index shifts down accordingly.
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
45+
// With KV_OVERLAP only one K/V binding is declared and the name V is aliased
// to K by the preprocessor, so all following bindings are one index lower.
#ifdef KV_OVERLAP
46+
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
47+
#define V K
48+
#else
49+
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
50+
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
51+
#endif
52+
53+
// Case 1: both mask and sinks are present.
#if defined(MASK) && defined(SINKS)
54+
#ifdef KV_OVERLAP
55+
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
56+
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
57+
#define DST_BINDING 4
58+
#define PARAMS_BINDING 5
59+
#else
60+
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
61+
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
62+
#define DST_BINDING 5
63+
#define PARAMS_BINDING 6
64+
#endif
65+
// Case 2: mask only.
#elif defined(MASK)
66+
#ifdef KV_OVERLAP
67+
@group(0) @binding(2) var<storage, read_write> mask: array<f16>;
68+
#define DST_BINDING 3
69+
#define PARAMS_BINDING 4
70+
#else
71+
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
72+
#define DST_BINDING 4
73+
#define PARAMS_BINDING 5
74+
#endif
75+
// Case 3: sinks only.
#elif defined(SINKS)
76+
#ifdef KV_OVERLAP
77+
@group(0) @binding(2) var<storage, read_write> sinks: array<f32>;
78+
#define DST_BINDING 3
79+
#define PARAMS_BINDING 4
80+
#else
81+
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
82+
#define DST_BINDING 4
83+
#define PARAMS_BINDING 5
84+
#endif
85+
// Case 4: neither mask nor sinks.
#else
86+
#ifdef KV_OVERLAP
87+
#define DST_BINDING 2
88+
#define PARAMS_BINDING 3
89+
#else
90+
#define DST_BINDING 3
91+
#define PARAMS_BINDING 4
92+
#endif
93+
#endif
94+
95+
// Output is written as vec4<f32>; params is the uniform block declared above.
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
96+
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
97+
98+
// Initial value for running maxima and for unused score slots.
const FLOAT_MIN: f32 = -1.0e9;
99+
// Number of 4-element chunks per Q/K row and per V row.
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
100+
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
101+
// Per-lane register counts (ceil divisions by MAX_SUBGROUP_SIZE) for the score
// and output arrays declared in main().
// NOTE(review): MAX_SUBGROUP_SIZE is not defined in this file's visible text;
// presumably injected by the host-side preprocessor -- confirm.
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
102+
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
103+
104+
// q_shmem: the scaled Q tile, Q_TILE rows of HEAD_DIM_QK f16 values.
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
105+
// kv_shmem: staging buffer reused for the K tile and then the V tile of each iteration.
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
106+
// p_shmem: unnormalised softmax weights, one KV_TILE-wide row per query row.
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
107+
108+
// Tiled flash-attention entry point.
// Each workgroup processes Q_TILE consecutive query rows of one head in one
// batch; the subgroup with index `subgroup_id` owns row q_row_start + subgroup_id.
// K and V are staged through kv_shmem in tiles of up to KV_TILE rows while each
// invocation keeps a running max, a running exp-sum and its output accumulators
// in function-scope variables.
@compute @workgroup_size(WG_SIZE)
109+
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
110+
@builtin(local_invocation_id) local_id: vec3<u32>,
111+
@builtin(subgroup_id) subgroup_id: u32,
112+
@builtin(subgroup_size) subgroup_size: u32,
113+
@builtin(num_subgroups) num_subgroups: u32,
114+
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
115+
// Bail out if there are fewer subgroups than query rows in the tile: rows are
// mapped one-to-one onto subgroups below. Also guards the divisions by subgroup_size.
if (subgroup_size == 0u || num_subgroups < Q_TILE) {
116+
return;
117+
}
118+
119+
// Decompose the flat workgroup index wg_id.x into (batch, head, row tile).
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
120+
let wg_per_batch = wg_per_head * params.n_heads;
121+
122+
// dst is addressed as [batch][q_row][head][HEAD_DIM_V] in scalar f32 elements.
let dst2_stride = HEAD_DIM_V * params.n_heads;
123+
let dst3_stride = dst2_stride * params.seq_len_q;
124+
125+
let batch_idx = wg_id.x / wg_per_batch;
126+
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
127+
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
128+
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
129+
let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
130+
let wg_in_batch = wg_id.x % wg_per_batch;
131+
132+
let head_idx = wg_in_batch / wg_per_head;
133+
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
134+
// q_per_kv consecutive query heads read the same K/V head.
let k_head_idx = head_idx / params.q_per_kv;
135+
let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2;
136+
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
137+
138+
let wg_in_head = wg_in_batch % wg_per_head;
139+
let q_row_start = wg_in_head * Q_TILE;
140+
let global_q_row = q_row_start + subgroup_id;
141+
// A subgroup does score/output work only if it maps to a real query row.
let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q;
142+
143+
#ifdef MASK
144+
// Mask base for this tile: batch offset plus q_row_start rows of seq_len_kv
// entries. No per-head term is applied here.
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
145+
#endif
146+
147+
let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
148+
149+
// Per-head slope applied to mask values: 1.0 when max_bias <= 0, otherwise
// m0^(head + 1) for head < n_head_log2 and m1^(2 * (head - n_head_log2) + 1) for the rest.
let head = f32(head_idx);
150+
let slope = select(1.0,
151+
select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
152+
pow(params.m0, head + 1.0),
153+
head < params.n_head_log2),
154+
params.max_bias > 0.0);
155+
156+
// Cooperative load of the Q tile into q_shmem as f16, pre-multiplied by
// params.scale; rows at or beyond seq_len_q are stored as zero.
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
157+
let q_tile_row = elem_idx / HEAD_DIM_QK;
158+
let q_col = elem_idx % HEAD_DIM_QK;
159+
let head_q_row = q_row_start + q_tile_row;
160+
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
161+
q_shmem[elem_idx] = f16(select(
162+
0.0,
163+
Q[global_q_row_offset + q_col] * params.scale,
164+
head_q_row < params.seq_len_q));
165+
}
166+
167+
// Make the Q tile visible to the whole workgroup before it is read.
workgroupBarrier();
168+
169+
// Running softmax state and output accumulators for this invocation.
var row_max = FLOAT_MIN;
170+
var exp_sum = 0.0;
171+
var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>;
172+
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
173+
out_regs[reg_idx] = vec4<f32>(0.0);
174+
}
175+
176+
// Start of this subgroup's row in q_shmem and in p_shmem.
let q_base = subgroup_id * HEAD_DIM_QK;
177+
let subgroup_p_offset = subgroup_id * KV_TILE;
178+
179+
// Walk K/V in tiles of up to KV_TILE rows.
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
180+
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
181+
// Slots actually needed at the runtime subgroup size, capped by the array sizes.
let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size);
182+
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
183+
var local_scores: array<f32, SCORE_REGS_PER_LANE>;
184+
for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) {
185+
local_scores[slot] = FLOAT_MIN;
186+
}
187+
188+
// Stage the K tile: the workgroup copies vec4 chunks into kv_shmem using
// KV_STAGE_STRIDE as the row stride. The >> 2u turns a scalar f16 element
// offset into an index into the vec4-typed K array.
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
189+
let kv_local = vec_idx_local / Q_CHUNKS;
190+
let chunk = vec_idx_local % Q_CHUNKS;
191+
let global_k_row = kv_tile + kv_local;
192+
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
193+
let k4 = K[k_vec_index];
194+
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
195+
kv_shmem[kv_off + 0u] = k4.x;
196+
kv_shmem[kv_off + 1u] = k4.y;
197+
kv_shmem[kv_off + 2u] = k4.z;
198+
kv_shmem[kv_off + 3u] = k4.w;
199+
}
200+
201+
workgroupBarrier();
202+
203+
// Score pass: each lane handles KV rows sg_inv_id, sg_inv_id + subgroup_size, ...
// and computes the dot product with this subgroup's Q row.
var local_max = FLOAT_MIN;
204+
if (row_active) {
205+
for (var slot = 0u; slot < score_slots; slot += 1u) {
206+
let kv_local = sg_inv_id + slot * subgroup_size;
207+
if (kv_local >= kv_count) {
208+
continue;
209+
}
210+
211+
let global_k_row = kv_tile + kv_local;
212+
var dot_val = 0.0;
213+
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
214+
let q_off = q_base + chunk * 4u;
215+
let qv = vec4<f32>(
216+
f32(q_shmem[q_off + 0u]),
217+
f32(q_shmem[q_off + 1u]),
218+
f32(q_shmem[q_off + 2u]),
219+
f32(q_shmem[q_off + 3u]));
220+
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
221+
let kv = vec4<f32>(
222+
f32(kv_shmem[kv_off + 0u]),
223+
f32(kv_shmem[kv_off + 1u]),
224+
f32(kv_shmem[kv_off + 2u]),
225+
f32(kv_shmem[kv_off + 3u]));
226+
dot_val += dot(qv, kv);
227+
}
228+
#ifdef LOGIT_SOFTCAP
229+
// NOTE(review): tanh is applied to the already-scaled score; presumably the
// host folds 1/logit_softcap into params.scale -- confirm.
dot_val = params.logit_softcap * tanh(dot_val);
230+
#endif
231+
#ifdef MASK
232+
// Mask rows are seq_len_kv apart; subgroup_id selects the row inside the tile.
let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row;
233+
dot_val += slope * f32(mask[mask_idx]);
234+
#endif
235+
local_scores[slot] = dot_val;
236+
local_max = max(local_max, dot_val);
237+
}
238+
}
239+
240+
// Fold this tile into the running softmax: take the new row maximum across
// the subgroup and rescale the previous sum and accumulators to it.
let tile_max = subgroupMax(local_max);
241+
let new_max = max(row_max, tile_max);
242+
let cur_exp = exp(row_max - new_max);
243+
exp_sum *= cur_exp;
244+
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
245+
out_regs[reg_idx] *= cur_exp;
246+
}
247+
248+
// Convert scores to unnormalised weights and publish them in p_shmem for
// the weighted-V pass below.
var local_sum = 0.0;
249+
for (var slot = 0u; slot < score_slots; slot += 1u) {
250+
let kv_local = sg_inv_id + slot * subgroup_size;
251+
if (row_active && kv_local < kv_count) {
252+
let p = exp(local_scores[slot] - new_max);
253+
p_shmem[subgroup_p_offset + kv_local] = p;
254+
local_sum += p;
255+
}
256+
}
257+
258+
// Every subgroup must be done reading K from kv_shmem before it is
// overwritten with V.
workgroupBarrier();
259+
260+
// Stage the V tile into the same kv_shmem buffer.
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
261+
let kv_local = vec_idx_local / V_CHUNKS;
262+
let chunk = vec_idx_local % V_CHUNKS;
263+
let global_v_row = kv_tile + kv_local;
264+
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
265+
let v4 = V[v_vec_index];
266+
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
267+
kv_shmem[kv_off + 0u] = v4.x;
268+
kv_shmem[kv_off + 1u] = v4.y;
269+
kv_shmem[kv_off + 2u] = v4.z;
270+
kv_shmem[kv_off + 3u] = v4.w;
271+
}
272+
273+
workgroupBarrier();
274+
275+
// Subgroup-wide sum of this tile's weights, added to the running denominator.
let tile_sum = subgroupAdd(local_sum);
276+
exp_sum += tile_sum;
277+
row_max = new_max;
278+
279+
// Weighted-V pass: each lane owns V chunks sg_inv_id + reg_idx * subgroup_size
// and accumulates the sum over KV rows of p * V[row][chunk].
if (row_active) {
280+
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
281+
let chunk = sg_inv_id + reg_idx * subgroup_size;
282+
if (chunk >= V_CHUNKS) {
283+
continue;
284+
}
285+
286+
var acc = out_regs[reg_idx];
287+
for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
288+
let p = p_shmem[subgroup_p_offset + kv_local];
289+
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
290+
let v4 = vec4<f32>(
291+
f32(kv_shmem[kv_off + 0u]),
292+
f32(kv_shmem[kv_off + 1u]),
293+
f32(kv_shmem[kv_off + 2u]),
294+
f32(kv_shmem[kv_off + 3u]));
295+
acc += p * v4;
296+
}
297+
out_regs[reg_idx] = acc;
298+
}
299+
}
300+
301+
// Keep kv_shmem and p_shmem intact until every subgroup has consumed them;
// the next iteration overwrites both.
workgroupBarrier();
302+
}
303+
304+
#ifdef SINKS
305+
// Fold the per-head sink logit into the softmax: rescale to the new maximum
// and add exp(sink_score - sink_max) to the denominator only (no value term).
if (row_active) {
306+
let sink_score = sinks[params.offset_sinks + head_idx];
307+
let sink_max = max(row_max, sink_score);
308+
let sink_scale = exp(row_max - sink_max);
309+
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
310+
out_regs[reg_idx] *= sink_scale;
311+
}
312+
exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max);
313+
row_max = sink_max;
314+
}
315+
#endif
316+
317+
// Normalise and store the output row. A zero exp_sum produces zeros instead
// of dividing by zero.
if (row_active) {
318+
let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
319+
let row_base = dst_global_offset + subgroup_id * dst2_stride;
320+
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
321+
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
322+
let chunk = sg_inv_id + reg_idx * subgroup_size;
323+
if (chunk >= V_CHUNKS) {
324+
continue;
325+
}
326+
// >> 2u converts the scalar f32 offset into an index into the vec4-typed dst.
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
327+
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
328+
}
329+
}
330+
}

ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct Params {
1515
nblk1: u32,
1616
};
1717

18-
@group(0) @binding(0) var<storage, read> mask: array<f16>;
18+
@group(0) @binding(0) var<storage, read_write> mask: array<f16>;
1919
@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
2020
@group(0) @binding(2) var<uniform> params: Params;
2121

0 commit comments

Comments
 (0)