Skip to content

Commit 927dada

Browse files
authored
ggml-webgpu: Enables running gpt-oss-20b (ggml-org#22906)
* Enable to run gpt-oss-20b and refactor mulmat-q * disable test-backend-ops in ubuntu-24-webgpu
1 parent 239a497 commit 927dada

10 files changed

Lines changed: 6134 additions & 5824 deletions

File tree

.github/workflows/build.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,8 @@ jobs:
456456
run: |
457457
cd build
458458
# This is using llvmpipe and runs slower than other backends
459-
ctest -L main --verbose --timeout 900
459+
# test-backend-ops is too slow on llvmpipe, skip it
460+
ctest -L main -E test-backend-ops --verbose --timeout 900
460461
461462
ubuntu-24-webgpu-wasm:
462463
runs-on: ${{ 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}

docs/ops.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Legend:
1818
| ACC ||||||| 🟡 |||||
1919
| ADD ||||| 🟡 |||||||
2020
| ADD1 ||||||||||||
21-
| ADD_ID ||||||||| |||
21+
| ADD_ID ||||||||| |||
2222
| ARANGE ||||||||||||
2323
| ARGMAX ||||||||||||
2424
| ARGSORT |||||| 🟡 | 🟡 |||||
@@ -71,7 +71,7 @@ Legend:
7171
| MUL_MAT_HADAMARD ||||||||||||
7272
| MUL_MAT_ID || 🟡 ||| 🟡 | 🟡 | 🟡 || 🟡 | 🟡 ||
7373
| NEG |||| 🟡 |||| 🟡 ||||
74-
| NORM |||||||| 🟡 | |||
74+
| NORM |||||||| 🟡 | |||
7575
| OPT_STEP_ADAMW ||||||||||||
7676
| OPT_STEP_SGD ||||||||||||
7777
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 ||| 🟡 |||| 🟡 |
@@ -118,5 +118,5 @@ Legend:
118118
| TOP_K ||||||| 🟡 | 🟡 ||||
119119
| TRI ||||||||||||
120120
| TRUNC |||| 🟡 ||| 🟡 | 🟡 ||||
121-
| UPSCALE || 🟡 |||| 🟡 ||| |||
121+
| UPSCALE || 🟡 |||| 🟡 ||| |||
122122
| XIELU ||||||||||||

docs/ops/WebGPU.csv

Lines changed: 5737 additions & 5728 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,22 @@ struct ggml_webgpu_binary_pipeline_key_hash {
495495
}
496496
};
497497

498+
/* Add_Id */
499+
500+
struct ggml_webgpu_add_id_pipeline_key {
501+
bool inplace;
502+
503+
bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; }
504+
};
505+
506+
struct ggml_webgpu_add_id_pipeline_key_hash {
507+
size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const {
508+
size_t seed = 0;
509+
ggml_webgpu_hash_combine(seed, key.inplace);
510+
return seed;
511+
}
512+
};
513+
498514
/** Unary **/
499515

500516
struct ggml_webgpu_unary_pipeline_key {
@@ -1058,7 +1074,9 @@ class ggml_webgpu_shader_lib {
10581074
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
10591075
pad_pipelines; // circular/non-circular
10601076
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
1061-
binary_pipelines; // type/op/inplace/overlap
1077+
binary_pipelines; // type/op/inplace/overlap/src_overlap
1078+
std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash>
1079+
add_id_pipelines; // inplace
10621080
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
10631081
concat_pipelines; // type
10641082
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
@@ -1433,6 +1451,7 @@ class ggml_webgpu_shader_lib {
14331451
case GGML_TYPE_IQ3_S:
14341452
case GGML_TYPE_IQ1_S:
14351453
case GGML_TYPE_IQ4_NL:
1454+
case GGML_TYPE_MXFP4:
14361455
{
14371456
// Quantized types using u32 buffers for portability.
14381457
defines.push_back("SRC_TYPE=u32");
@@ -1451,6 +1470,7 @@ class ggml_webgpu_shader_lib {
14511470
defines.push_back(type_upper + "_SCALE_MIN");
14521471
defines.push_back(type_upper + "_TABLES");
14531472
defines.push_back(type_upper + "_GRID");
1473+
defines.push_back(type_upper + "_LUT");
14541474

14551475
variant += "_";
14561476
variant += type_str;
@@ -1460,7 +1480,7 @@ class ggml_webgpu_shader_lib {
14601480
if (key.src_type == GGML_TYPE_Q1_0) {
14611481
defines.push_back("BLOCK_SIZE=128u");
14621482
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
1463-
key.src_type == GGML_TYPE_IQ4_NL) {
1483+
key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) {
14641484
defines.push_back("BLOCK_SIZE=32u");
14651485
} else if (key.src_type >= GGML_TYPE_Q2_K) {
14661486
defines.push_back("BLOCK_SIZE=256u");
@@ -1774,6 +1794,9 @@ class ggml_webgpu_shader_lib {
17741794
defines.push_back(type_upper + "_GRID");
17751795
defines.push_back(type_upper + "_TABLES");
17761796
break;
1797+
case GGML_TYPE_MXFP4:
1798+
defines.push_back(type_upper + "_LUT");
1799+
break;
17771800
default:
17781801
break;
17791802
}
@@ -1908,6 +1931,9 @@ class ggml_webgpu_shader_lib {
19081931
defines.push_back(type_upper + "_GRID");
19091932
defines.push_back(type_upper + "_TABLES");
19101933
break;
1934+
case GGML_TYPE_MXFP4:
1935+
defines.push_back(type_upper + "_LUT");
1936+
break;
19111937
default:
19121938
break;
19131939
}
@@ -2042,6 +2068,7 @@ class ggml_webgpu_shader_lib {
20422068
case GGML_TYPE_IQ3_S:
20432069
case GGML_TYPE_IQ1_S:
20442070
case GGML_TYPE_IQ4_NL:
2071+
case GGML_TYPE_MXFP4:
20452072
{
20462073
// Quantized types using u32 buffers for portability.
20472074
defines.push_back("SRC0_TYPE=u32");
@@ -2169,6 +2196,9 @@ class ggml_webgpu_shader_lib {
21692196
defines.push_back(type_upper + "_GRID");
21702197
defines.push_back(type_upper + "_TABLES");
21712198
break;
2199+
case GGML_TYPE_MXFP4:
2200+
defines.push_back(type_upper + "_LUT");
2201+
break;
21722202
default:
21732203
break;
21742204
}
@@ -2286,6 +2316,9 @@ class ggml_webgpu_shader_lib {
22862316
defines.push_back(type_upper + "_GRID");
22872317
defines.push_back(type_upper + "_TABLES");
22882318
break;
2319+
case GGML_TYPE_MXFP4:
2320+
defines.push_back(type_upper + "_LUT");
2321+
break;
22892322
default:
22902323
break;
22912324
}
@@ -2503,6 +2536,37 @@ class ggml_webgpu_shader_lib {
25032536
return binary_pipelines[key];
25042537
}
25052538

2539+
webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
2540+
ggml_webgpu_add_id_pipeline_key key = {};
2541+
key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst);
2542+
2543+
auto it = add_id_pipelines.find(key);
2544+
if (it != add_id_pipelines.end()) {
2545+
return it->second;
2546+
}
2547+
2548+
std::vector<std::string> defines;
2549+
std::string variant = "add_id";
2550+
const char * shader_src = wgsl_add_id;
2551+
2552+
if (key.inplace) {
2553+
defines.push_back("INPLACE");
2554+
variant += "_inplace";
2555+
}
2556+
2557+
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
2558+
2559+
auto processed = preprocessor.preprocess(shader_src, defines);
2560+
auto pipeline_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
2561+
pipeline_decisions->wg_size = context.max_wg_size;
2562+
pipeline_decisions->inplace = key.inplace;
2563+
2564+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2565+
pipeline.context = pipeline_decisions;
2566+
add_id_pipelines[key] = pipeline;
2567+
return pipeline;
2568+
}
2569+
25062570
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
25072571
ggml_webgpu_concat_pipeline_key key = {};
25082572
key.type = context.dst->type;

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,8 +1411,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
14111411
case GGML_TYPE_Q3_K:
14121412
case GGML_TYPE_Q2_K:
14131413
case GGML_TYPE_Q1_0:
1414-
use_fast = true;
1415-
break;
14161414
case GGML_TYPE_IQ1_S:
14171415
case GGML_TYPE_IQ1_M:
14181416
case GGML_TYPE_IQ2_XXS:
@@ -1422,6 +1420,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
14221420
case GGML_TYPE_IQ3_S:
14231421
case GGML_TYPE_IQ4_NL:
14241422
case GGML_TYPE_IQ4_XS:
1423+
case GGML_TYPE_MXFP4:
14251424
use_fast = true;
14261425
break;
14271426
default:
@@ -2145,6 +2144,56 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
21452144
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
21462145
}
21472146

2147+
static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx,
2148+
ggml_tensor * src0,
2149+
ggml_tensor * src1,
2150+
ggml_tensor * src2,
2151+
ggml_tensor * dst) {
2152+
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
2153+
shader_lib_ctx.src0 = src0;
2154+
shader_lib_ctx.src1 = src1;
2155+
shader_lib_ctx.src2 = src2;
2156+
shader_lib_ctx.dst = dst;
2157+
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
2158+
2159+
webgpu_pipeline pipeline = ctx->shader_lib->get_add_id_pipeline(shader_lib_ctx);
2160+
2161+
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
2162+
2163+
std::vector<uint32_t> params = {
2164+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
2165+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
2166+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
2167+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2168+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
2169+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
2170+
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
2171+
(uint32_t) (src2->nb[0] / ggml_type_size(src2->type)),
2172+
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
2173+
(uint32_t) dst->ne[0],
2174+
(uint32_t) dst->ne[1],
2175+
(uint32_t) dst->ne[2],
2176+
};
2177+
2178+
std::vector<wgpu::BindGroupEntry> entries;
2179+
2180+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
2181+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
2182+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2));
2183+
2184+
if (!decisions->inplace) {
2185+
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, dst));
2186+
}
2187+
2188+
uint32_t wg_x = 1;
2189+
uint32_t wg_y = 1;
2190+
uint32_t total_wg = ggml_nrows(dst);
2191+
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2192+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
2193+
2194+
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
2195+
}
2196+
21482197
static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
21492198
ggml_tensor * src0,
21502199
ggml_tensor * src1,
@@ -2918,6 +2967,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx,
29182967
case GGML_OP_MUL:
29192968
case GGML_OP_DIV:
29202969
return ggml_webgpu_binary_op(ctx, src0, src1, node);
2970+
case GGML_OP_ADD_ID:
2971+
return ggml_webgpu_add_id(ctx, src0, src1, src2, node);
29212972
case GGML_OP_CONCAT:
29222973
return ggml_webgpu_concat(ctx, src0, src1, node);
29232974
case GGML_OP_REPEAT:
@@ -3867,6 +3918,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
38673918
case GGML_TYPE_IQ1_M:
38683919
case GGML_TYPE_IQ4_NL:
38693920
case GGML_TYPE_IQ4_XS:
3921+
case GGML_TYPE_MXFP4:
38703922
return true;
38713923
default:
38723924
return false;
@@ -3905,6 +3957,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
39053957
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
39063958
(src1->type == op->type);
39073959
break;
3960+
case GGML_OP_ADD_ID:
3961+
supports_op = src0->type == GGML_TYPE_F32;
3962+
break;
39083963
case GGML_OP_CONCAT:
39093964
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
39103965
break;
@@ -3962,6 +4017,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
39624017
case GGML_TYPE_IQ1_M:
39634018
case GGML_TYPE_IQ4_NL:
39644019
case GGML_TYPE_IQ4_XS:
4020+
case GGML_TYPE_MXFP4:
39654021
supports_op = true;
39664022
break;
39674023
default:
@@ -4001,6 +4057,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
40014057
case GGML_TYPE_IQ3_S:
40024058
case GGML_TYPE_IQ4_NL:
40034059
case GGML_TYPE_IQ4_XS:
4060+
case GGML_TYPE_MXFP4:
40044061
supports_op = true;
40054062
break;
40064063
default:
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
struct Params {
2+
offset_src0: u32,
3+
offset_src1: u32,
4+
offset_ids: u32,
5+
offset_dst: u32,
6+
7+
nb01: u32,
8+
nb02: u32,
9+
nb11: u32,
10+
nb20: u32,
11+
nb21: u32,
12+
13+
ne0: u32,
14+
ne1: u32,
15+
ne2: u32,
16+
};
17+
18+
@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // [n_embd, n_experts_used, n_token]
19+
@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // [n_embd, n_experts]
20+
@group(0) @binding(2) var<storage, read_write> ids: array<i32>; // [n_experts_used, n_token]
21+
22+
#ifdef INPLACE
23+
24+
@group(0) @binding(3)
25+
var<uniform> params: Params;
26+
27+
#else
28+
29+
@group(0) @binding(3)
30+
var<storage, read_write> dst: array<f32>;
31+
32+
@group(0) @binding(4)
33+
var<uniform> params: Params;
34+
35+
#endif
36+
37+
@compute @workgroup_size(WG_SIZE)
38+
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
39+
@builtin(num_workgroups) num_wg: vec3<u32>,
40+
@builtin(local_invocation_id) local_id: vec3<u32>) {
41+
42+
let wg_linear = wg_id.x + wg_id.y * num_wg.x;
43+
44+
if (wg_linear < params.ne1 * params.ne2) {
45+
let thread_id = local_id.x;
46+
let i2 = wg_linear / params.ne1;
47+
let i1 = wg_linear % params.ne1;
48+
49+
let i11 = u32(ids[params.offset_ids + i1 * params.nb20 + i2 * params.nb21]);
50+
51+
let src0_row = params.offset_src0 + i1 * params.nb01 + i2 * params.nb02;
52+
let src1_row = params.offset_src1 + i11 * params.nb11;
53+
let dst_row = params.offset_dst + i1 * params.ne0 + i2 * (params.ne0 * params.ne1);
54+
55+
for (var i = thread_id;i < params.ne0; i += WG_SIZE) {
56+
#ifdef INPLACE
57+
src0[src0_row + i] = src0[src0_row + i] + src1[src1_row + i];
58+
#else
59+
dst[dst_row + i] = src0[src0_row + i] + src1[src1_row + i];
60+
#endif
61+
}
62+
}
63+
64+
}

ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,3 +896,10 @@ const kvalues_iq4nl = array<i32, 16>(
896896
);
897897

898898
#endif
899+
900+
#ifdef MXFP4_LUT
901+
const kvalues_mxfp4 = array<i32, 16>(
902+
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
903+
);
904+
#endif
905+

0 commit comments

Comments
 (0)