Skip to content

Commit 239a497

Browse files
authored
ggml-webgpu: address precision issues for multimodal (#22808)
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
1 parent 89730c8 commit 239a497

6 files changed

Lines changed: 295 additions & 186 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 129 additions & 64 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ struct webgpu_capabilities {
187187
uint32_t sg_mat_k = 0;
188188

189189
uint32_t subgroup_size = 0;
190+
uint32_t min_subgroup_size = 0;
190191
uint32_t max_subgroup_size = 0;
191192
size_t memset_bytes_per_thread;
192193
};
@@ -1442,6 +1443,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
14421443
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
14431444
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
14441445
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
1446+
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
14451447
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
14461448

14471449
// Get or create pipeline
@@ -1750,6 +1752,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
17501752
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
17511753
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
17521754
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
1755+
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
17531756
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
17541757
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
17551758
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
@@ -3469,6 +3472,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
34693472
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
34703473
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
34713474
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
3475+
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
34723476
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
34733477

34743478
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
@@ -3667,8 +3671,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
36673671
#endif
36683672
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
36693673

3670-
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
3671-
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
3674+
// Runtime subgroup size can be any supported size in this range. Shaders
3675+
// that allocate per-lane register arrays must size them for the minimum.
3676+
ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize;
36723677
ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
36733678
// Initialize device
36743679
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
@@ -4024,11 +4029,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
40244029
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
40254030
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
40264031
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
4032+
shader_lib_ctx.max_wg_size =
4033+
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
40274034
shader_lib_ctx.wg_mem_limit_bytes =
40284035
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
40294036
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
40304037
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
40314038
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
4039+
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
40324040
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
40334041

40344042
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
@@ -4040,19 +4048,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
40404048
break;
40414049
}
40424050
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
4043-
const size_t min_bytes =
4044-
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
4045-
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
4051+
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
4052+
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
4053+
decisions.kv_direct, decisions.path);
40464054
if (min_bytes > limit_bytes) {
40474055
supports_op = false;
40484056
}
40494057
break;
40504058
}
40514059

40524060
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
4053-
const size_t min_bytes =
4054-
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
4055-
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
4061+
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
4062+
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
4063+
decisions.kv_direct, decisions.path);
40564064
if (min_bytes > limit_bytes) {
40574065
supports_op = false;
40584066
}
@@ -4063,9 +4071,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
40634071
supports_op = false;
40644072
break;
40654073
}
4066-
const size_t min_bytes =
4067-
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
4068-
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
4074+
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
4075+
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
4076+
decisions.kv_direct, decisions.path);
40694077
if (min_bytes > limit_bytes) {
40704078
supports_op = false;
40714079
}

ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,33 @@
11
enable f16;
22
enable subgroups;
33

4+
#ifdef Q_F16
5+
#define Q_TYPE f16
6+
#else
7+
#define Q_TYPE f32
8+
#endif
9+
10+
#ifdef KV_F32
11+
#define KV_TYPE f32
12+
#else
13+
#define KV_TYPE f16
14+
#endif
15+
16+
#ifdef DST_F16
17+
#define DST_TYPE f16
18+
#else
19+
#define DST_TYPE f32
20+
#endif
21+
422
#define HEAD_DIM_QK 64
523
#define HEAD_DIM_V 64
624
#define KV_STAGE_STRIDE 64
725
#define Q_TILE 4
826
#define KV_TILE 64
927
#define WG_SIZE 128
28+
#ifndef MIN_SUBGROUP_SIZE
29+
#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE
30+
#endif
1031

1132
struct Params {
1233
offset_q: u32,
@@ -41,13 +62,13 @@ struct Params {
4162
m1: f32,
4263
};
4364

44-
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
65+
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
4566
#ifdef KV_OVERLAP
46-
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
67+
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
4768
#define V K
4869
#else
49-
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
50-
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
70+
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
71+
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
5172
#endif
5273

5374
#if defined(MASK) && defined(SINKS)
@@ -92,17 +113,17 @@ struct Params {
92113
#endif
93114
#endif
94115

95-
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
116+
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
96117
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
97118

98119
const FLOAT_MIN: f32 = -1.0e9;
99120
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
100121
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
101-
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
102-
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
122+
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
123+
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
103124

104-
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
105-
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
125+
var<workgroup> q_shmem: array<f32, Q_TILE * HEAD_DIM_QK>;
126+
var<workgroup> kv_shmem: array<f32, KV_TILE * KV_STAGE_STRIDE>;
106127
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
107128

108129
@compute @workgroup_size(WG_SIZE)
@@ -158,10 +179,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
158179
let q_col = elem_idx % HEAD_DIM_QK;
159180
let head_q_row = q_row_start + q_tile_row;
160181
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
161-
q_shmem[elem_idx] = f16(select(
182+
q_shmem[elem_idx] = select(
162183
0.0,
163-
Q[global_q_row_offset + q_col] * params.scale,
164-
head_q_row < params.seq_len_q));
184+
f32(Q[global_q_row_offset + q_col]) * params.scale,
185+
head_q_row < params.seq_len_q);
165186
}
166187

167188
workgroupBarrier();
@@ -192,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
192213
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
193214
let k4 = K[k_vec_index];
194215
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
195-
kv_shmem[kv_off + 0u] = k4.x;
196-
kv_shmem[kv_off + 1u] = k4.y;
197-
kv_shmem[kv_off + 2u] = k4.z;
198-
kv_shmem[kv_off + 3u] = k4.w;
216+
kv_shmem[kv_off + 0u] = f32(k4.x);
217+
kv_shmem[kv_off + 1u] = f32(k4.y);
218+
kv_shmem[kv_off + 2u] = f32(k4.z);
219+
kv_shmem[kv_off + 3u] = f32(k4.w);
199220
}
200221

201222
workgroupBarrier();
@@ -213,16 +234,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
213234
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
214235
let q_off = q_base + chunk * 4u;
215236
let qv = vec4<f32>(
216-
f32(q_shmem[q_off + 0u]),
217-
f32(q_shmem[q_off + 1u]),
218-
f32(q_shmem[q_off + 2u]),
219-
f32(q_shmem[q_off + 3u]));
237+
q_shmem[q_off + 0u],
238+
q_shmem[q_off + 1u],
239+
q_shmem[q_off + 2u],
240+
q_shmem[q_off + 3u]);
220241
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
221242
let kv = vec4<f32>(
222-
f32(kv_shmem[kv_off + 0u]),
223-
f32(kv_shmem[kv_off + 1u]),
224-
f32(kv_shmem[kv_off + 2u]),
225-
f32(kv_shmem[kv_off + 3u]));
243+
kv_shmem[kv_off + 0u],
244+
kv_shmem[kv_off + 1u],
245+
kv_shmem[kv_off + 2u],
246+
kv_shmem[kv_off + 3u]);
226247
dot_val += dot(qv, kv);
227248
}
228249
#ifdef LOGIT_SOFTCAP
@@ -264,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
264285
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
265286
let v4 = V[v_vec_index];
266287
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
267-
kv_shmem[kv_off + 0u] = v4.x;
268-
kv_shmem[kv_off + 1u] = v4.y;
269-
kv_shmem[kv_off + 2u] = v4.z;
270-
kv_shmem[kv_off + 3u] = v4.w;
288+
kv_shmem[kv_off + 0u] = f32(v4.x);
289+
kv_shmem[kv_off + 1u] = f32(v4.y);
290+
kv_shmem[kv_off + 2u] = f32(v4.z);
291+
kv_shmem[kv_off + 3u] = f32(v4.w);
271292
}
272293

273294
workgroupBarrier();
@@ -288,10 +309,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
288309
let p = p_shmem[subgroup_p_offset + kv_local];
289310
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
290311
let v4 = vec4<f32>(
291-
f32(kv_shmem[kv_off + 0u]),
292-
f32(kv_shmem[kv_off + 1u]),
293-
f32(kv_shmem[kv_off + 2u]),
294-
f32(kv_shmem[kv_off + 3u]));
312+
kv_shmem[kv_off + 0u],
313+
kv_shmem[kv_off + 1u],
314+
kv_shmem[kv_off + 2u],
315+
kv_shmem[kv_off + 3u]);
295316
acc += p * v4;
296317
}
297318
out_regs[reg_idx] = acc;
@@ -324,7 +345,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
324345
continue;
325346
}
326347
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
327-
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
348+
dst[dst_vec_index] = vec4<DST_TYPE>(out_regs[reg_idx] * inv_exp_sum);
328349
}
329350
}
330351
}

ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@ diagnostic(off, subgroup_uniformity);
22
enable f16;
33
enable subgroups;
44

5+
#ifdef DST_F16
6+
#define DST_TYPE f16
7+
#else
8+
#define DST_TYPE f32
9+
#endif
10+
511
// Default values
612
#define HEAD_DIM_V 64
713
#define WG_SIZE 128
@@ -17,7 +23,7 @@ struct Params {
1723
};
1824

1925
@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
20-
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
26+
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
2127
@group(0) @binding(2) var<uniform> params: Params;
2228

2329
const FLOAT_MIN: f32 = -1.0e9;
@@ -72,7 +78,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
7278

7379
if (thread == 0u) {
7480
let dst_vec_index = (row_base + elem_base) >> 2u;
75-
dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
81+
dst[dst_vec_index] = vec4<DST_TYPE>(vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s);
7682
}
7783
}
7884
}

0 commit comments

Comments
 (0)