Skip to content

Commit a95a11e

Browse files
authored
ggml-webgpu: Improve performance of mat-vec and mat-mat for MUL_MAT_ID (ggml-org#22464)
* Add mat-vec fast path of MUL_MAT_ID. * Add shared accumulation vec logic and the other types supports. * Add i-quant mat-mat for MUL_MAT_ID and fix some parts * Remove n_experts from shader_lib_context.
1 parent 5cbfb18 commit a95a11e

5 files changed

Lines changed: 1780 additions & 1295 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 160 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
664664
}
665665
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
666666
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
667-
size_t bytes_per_kv = 0;
667+
size_t bytes_per_kv = 0;
668668
if (!key.kv_direct) {
669669
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
670670
}
@@ -701,10 +701,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
701701
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
702702
const bool kv_vec_type_supported =
703703
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
704-
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
705-
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
706-
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
707-
(context.src2->type == K->type);
704+
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) &&
705+
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
706+
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
707+
(context.src2->type == K->type);
708708
const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
709709
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
710710
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
@@ -862,9 +862,12 @@ struct ggml_webgpu_mul_mat_shader_decisions {
862862
struct ggml_webgpu_mul_mat_id_pipeline_key {
863863
ggml_type src0_type;
864864
ggml_type src1_type;
865+
uint32_t n_experts;
866+
int vectorized;
865867

866868
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
867-
return src0_type == other.src0_type && src1_type == other.src1_type;
869+
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
870+
vectorized == other.vectorized;
868871
}
869872
};
870873

@@ -873,6 +876,8 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
873876
size_t seed = 0;
874877
ggml_webgpu_hash_combine(seed, key.src0_type);
875878
ggml_webgpu_hash_combine(seed, key.src1_type);
879+
ggml_webgpu_hash_combine(seed, key.n_experts);
880+
ggml_webgpu_hash_combine(seed, key.vectorized);
876881
return seed;
877882
}
878883
};
@@ -1023,6 +1028,8 @@ class ggml_webgpu_shader_lib {
10231028
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
10241029
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
10251030
mul_mat_id_pipelines; // src0_type/src1_type
1031+
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
1032+
mul_mat_id_vec_pipelines; // src0_type/src1_type
10261033

10271034
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
10281035
set_rows_pipelines;
@@ -1516,7 +1523,7 @@ class ggml_webgpu_shader_lib {
15161523
key.type = context.dst->type;
15171524
key.d_state = (int) context.src0->ne[0];
15181525
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
1519-
ggml_webgpu_tensor_overlap(context.src1, context.src5);
1526+
ggml_webgpu_tensor_overlap(context.src1, context.src5);
15201527

15211528
auto it = ssm_scan_pipelines.find(key);
15221529
if (it != ssm_scan_pipelines.end()) {
@@ -1633,10 +1640,10 @@ class ggml_webgpu_shader_lib {
16331640
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
16341641
key.src0_type = context.src0->type;
16351642
key.src1_type = context.src1->type;
1636-
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
1643+
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
16371644
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1638-
1 :
1639-
0;
1645+
1 :
1646+
0;
16401647

16411648
auto it = mul_mat_vec_pipelines.find(key);
16421649
if (it != mul_mat_vec_pipelines.end()) {
@@ -2012,6 +2019,11 @@ class ggml_webgpu_shader_lib {
20122019
ggml_webgpu_mul_mat_id_pipeline_key key = {};
20132020
key.src0_type = context.src0->type;
20142021
key.src1_type = context.src1->type;
2022+
key.n_experts = context.src0->ne[2];
2023+
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
2024+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2025+
1 :
2026+
0;
20152027

20162028
auto it = mul_mat_id_pipelines.find(key);
20172029
if (it != mul_mat_id_pipelines.end()) {
@@ -2041,14 +2053,12 @@ class ggml_webgpu_shader_lib {
20412053
switch (context.src0->type) {
20422054
case GGML_TYPE_F32:
20432055
defines.push_back("SRC0_INNER_TYPE=f32");
2044-
defines.push_back("FLOAT");
20452056
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
20462057
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
20472058
variant += "_f32";
20482059
break;
20492060
case GGML_TYPE_F16:
20502061
defines.push_back("SRC0_INNER_TYPE=f16");
2051-
defines.push_back("FLOAT");
20522062
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
20532063
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
20542064
variant += "_f16";
@@ -2064,12 +2074,32 @@ class ggml_webgpu_shader_lib {
20642074
defines.push_back("U32_DEQUANT_HELPERS");
20652075
defines.push_back("SRC0_INNER_TYPE=u32");
20662076

2077+
switch (context.src0->type) {
2078+
case GGML_TYPE_IQ1_S:
2079+
case GGML_TYPE_IQ1_M:
2080+
case GGML_TYPE_IQ4_NL:
2081+
case GGML_TYPE_IQ4_XS:
2082+
defines.push_back(type_upper + "_GRID");
2083+
break;
2084+
case GGML_TYPE_IQ2_XXS:
2085+
case GGML_TYPE_IQ2_XS:
2086+
case GGML_TYPE_IQ2_S:
2087+
case GGML_TYPE_IQ3_XXS:
2088+
case GGML_TYPE_IQ3_S:
2089+
defines.push_back(type_upper + "_GRID");
2090+
defines.push_back(type_upper + "_TABLES");
2091+
break;
2092+
default:
2093+
break;
2094+
}
2095+
20672096
variant += std::string("_") + src0_name;
20682097
break;
20692098
}
20702099
}
20712100

2072-
defines.push_back("SCALAR");
2101+
// VEC/SCALAR controls
2102+
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
20732103

20742104
// mul_mat_id is register-tile only.
20752105
const uint32_t tile_k =
@@ -2102,6 +2132,123 @@ class ggml_webgpu_shader_lib {
21022132
return mul_mat_id_pipelines[key];
21032133
}
21042134

2135+
webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
2136+
ggml_webgpu_mul_mat_id_pipeline_key key = {};
2137+
key.src0_type = context.src0->type;
2138+
key.src1_type = context.src1->type;
2139+
key.n_experts = context.src0->ne[2];
2140+
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
2141+
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
2142+
1 :
2143+
0;
2144+
2145+
auto it = mul_mat_id_vec_pipelines.find(key);
2146+
if (it != mul_mat_id_vec_pipelines.end()) {
2147+
return it->second;
2148+
}
2149+
2150+
std::vector<std::string> defines;
2151+
std::string variant = "mul_mat_id_vec";
2152+
const char * shader_src = wgsl_mul_mat_id_vec;
2153+
2154+
// src1 type
2155+
switch (context.src1->type) {
2156+
case GGML_TYPE_F32:
2157+
defines.push_back("SRC1_INNER_TYPE=f32");
2158+
break;
2159+
case GGML_TYPE_F16:
2160+
defines.push_back("SRC1_INNER_TYPE=f16");
2161+
break;
2162+
default:
2163+
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
2164+
}
2165+
2166+
// src0 type
2167+
switch (context.src0->type) {
2168+
case GGML_TYPE_F32:
2169+
defines.push_back("SRC0_INNER_TYPE=f32");
2170+
defines.push_back("MUL_ACC_FLOAT");
2171+
variant += "_f32";
2172+
break;
2173+
case GGML_TYPE_F16:
2174+
defines.push_back("SRC0_INNER_TYPE=f16");
2175+
defines.push_back("MUL_ACC_FLOAT");
2176+
variant += "_f16";
2177+
break;
2178+
default:
2179+
{
2180+
// Quantized types: use helpers but accumulate in f16
2181+
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
2182+
std::string src0_name = src0_traits->type_name;
2183+
std::string type_upper = src0_name;
2184+
variant += "_" + src0_name;
2185+
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
2186+
2187+
defines.push_back("BYTE_HELPERS");
2188+
defines.push_back("MUL_ACC_" + type_upper);
2189+
defines.push_back("U32_DEQUANT_HELPERS");
2190+
defines.push_back("SRC0_INNER_TYPE=u32");
2191+
switch (context.src0->type) {
2192+
case GGML_TYPE_IQ1_S:
2193+
case GGML_TYPE_IQ1_M:
2194+
case GGML_TYPE_IQ2_S:
2195+
case GGML_TYPE_IQ3_S:
2196+
case GGML_TYPE_IQ4_NL:
2197+
case GGML_TYPE_IQ4_XS:
2198+
defines.push_back(type_upper + "_GRID");
2199+
break;
2200+
case GGML_TYPE_IQ2_XXS:
2201+
case GGML_TYPE_IQ2_XS:
2202+
case GGML_TYPE_IQ3_XXS:
2203+
defines.push_back(type_upper + "_GRID");
2204+
defines.push_back(type_upper + "_TABLES");
2205+
break;
2206+
default:
2207+
break;
2208+
}
2209+
break;
2210+
}
2211+
}
2212+
2213+
// VEC/SCALAR controls
2214+
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
2215+
2216+
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2217+
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
2218+
2219+
if (key.src0_type == GGML_TYPE_Q1_0) {
2220+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
2221+
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
2222+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
2223+
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
2224+
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
2225+
}
2226+
2227+
// variant suffix for src1 type
2228+
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
2229+
2230+
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
2231+
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
2232+
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
2233+
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
2234+
if (key.vectorized) {
2235+
variant += "_vectorized";
2236+
}
2237+
2238+
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
2239+
2240+
auto processed = preprocessor.preprocess(shader_src, defines);
2241+
2242+
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
2243+
decisions->wg_size = wg_size;
2244+
decisions->outputs_per_wg = outputs_per_wg;
2245+
2246+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
2247+
pipeline.context = decisions;
2248+
mul_mat_id_vec_pipelines[key] = pipeline;
2249+
return mul_mat_id_vec_pipelines[key];
2250+
}
2251+
21052252
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
21062253
const bool is_unary = context.dst->op == GGML_OP_UNARY;
21072254
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
14041404
case GGML_TYPE_Q5_0:
14051405
case GGML_TYPE_Q5_1:
14061406
case GGML_TYPE_Q8_0:
1407-
case GGML_TYPE_Q8_1:
14081407
case GGML_TYPE_Q6_K:
14091408
case GGML_TYPE_Q4_K:
14101409
case GGML_TYPE_Q5_K:
@@ -1527,11 +1526,74 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
15271526
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
15281527
}
15291528

1529+
static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx,
1530+
ggml_tensor * src0,
1531+
ggml_tensor * src1,
1532+
ggml_tensor * src2,
1533+
ggml_tensor * dst) {
1534+
const uint32_t param_n_expert = (uint32_t) src0->ne[2];
1535+
const uint32_t param_n_expert_used = (uint32_t) dst->ne[1];
1536+
1537+
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
1538+
shader_lib_ctx.src0 = src0;
1539+
shader_lib_ctx.src1 = src1;
1540+
shader_lib_ctx.src2 = src2;
1541+
shader_lib_ctx.dst = dst;
1542+
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
1543+
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
1544+
1545+
webgpu_pipeline pipeline = ctx->shader_lib->get_mul_mat_id_vec_pipeline(shader_lib_ctx);
1546+
1547+
std::vector<uint32_t> params = {
1548+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1549+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1550+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
1551+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1552+
(uint32_t) src0->ne[0],
1553+
(uint32_t) src0->ne[1],
1554+
param_n_expert,
1555+
param_n_expert_used,
1556+
(uint32_t) src1->ne[1],
1557+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1558+
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1559+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1560+
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1561+
};
1562+
1563+
std::vector<wgpu::BindGroupEntry> entries = {
1564+
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0),
1565+
ggml_webgpu_tensor_binding_size(ctx, src0)),
1566+
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1),
1567+
ggml_webgpu_tensor_binding_size(ctx, src1)),
1568+
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2),
1569+
ggml_webgpu_tensor_binding_size(ctx, src2)),
1570+
ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst),
1571+
ggml_webgpu_tensor_binding_size(ctx, dst)),
1572+
};
1573+
1574+
uint32_t wg_x = 1;
1575+
uint32_t wg_y = 1;
1576+
1577+
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
1578+
1579+
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1580+
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
1581+
uint32_t total_wg = output_groups * param_n_expert_used;
1582+
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
1583+
1584+
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
1585+
}
1586+
15301587
static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
15311588
ggml_tensor * src0,
15321589
ggml_tensor * src1,
15331590
ggml_tensor * src2,
15341591
ggml_tensor * dst) {
1592+
// we can use mat-vec fast path
1593+
if (dst->ne[2] == 1) {
1594+
return ggml_webgpu_mul_mat_id_vec(ctx, src0, src1, src2, dst);
1595+
}
1596+
15351597
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
15361598
shader_lib_ctx.src0 = src0;
15371599
shader_lib_ctx.src1 = src1;
@@ -3879,6 +3941,15 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
38793941
case GGML_TYPE_Q4_K:
38803942
case GGML_TYPE_Q5_K:
38813943
case GGML_TYPE_Q6_K:
3944+
case GGML_TYPE_IQ1_S:
3945+
case GGML_TYPE_IQ1_M:
3946+
case GGML_TYPE_IQ2_XXS:
3947+
case GGML_TYPE_IQ2_XS:
3948+
case GGML_TYPE_IQ2_S:
3949+
case GGML_TYPE_IQ3_XXS:
3950+
case GGML_TYPE_IQ3_S:
3951+
case GGML_TYPE_IQ4_NL:
3952+
case GGML_TYPE_IQ4_XS:
38823953
supports_op = true;
38833954
break;
38843955
default:

0 commit comments

Comments
 (0)