Skip to content

Commit d0a6dfe

Browse files
yomaytk and reeselevine authored
ggml-webgpu: Add the support of MUL_MAT_ID (ggml-org#21147)
* Add mul_mat_id support to WebGPU * Apply suggestion from @reeselevine --------- Co-authored-by: Reese Levine <reeselevine1@gmail.com>
1 parent 2e1f0a8 commit d0a6dfe

File tree

7 files changed

+1113
-620
lines changed

7 files changed

+1113
-620
lines changed

docs/ops.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ Legend:
6868
| MEAN ||||||||||||
6969
| MUL ||||| 🟡 |||||||
7070
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
71-
| MUL_MAT_ID || 🟡 ||| 🟡 | 🟡 | 🟡 || | 🟡 ||
71+
| MUL_MAT_ID || 🟡 ||| 🟡 | 🟡 | 🟡 || 🟡 | 🟡 ||
7272
| NEG |||| 🟡 |||| 🟡 ||||
7373
| NORM |||||||| 🟡 ||||
7474
| OPT_STEP_ADAMW ||||||||||||

docs/ops/WebGPU.csv

Lines changed: 527 additions & 618 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,26 @@ struct ggml_webgpu_mul_mat_shader_decisions {
658658
uint32_t mul_mat_wg_size;
659659
};
660660

661+
/** MUL_MAT_ID **/
662+
663+
// Cache key for compiled MUL_MAT_ID pipelines: one pipeline is built per
// (src0 type, src1 type) combination and reused across nodes.
struct ggml_webgpu_mul_mat_id_pipeline_key {
    ggml_type src0_type;  // type of src0 (the expert weight matrices)
    ggml_type src1_type;  // type of src1 (the activations; f32 or f16 per the pipeline builder)

    bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
        return src0_type == other.src0_type && src1_type == other.src1_type;
    }
};
671+
672+
// Hash functor for ggml_webgpu_mul_mat_id_pipeline_key so it can serve as an
// unordered_map key; combines both type fields via ggml_webgpu_hash_combine.
struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src0_type);
        ggml_webgpu_hash_combine(seed, key.src1_type);
        return seed;
    }
};
680+
661681
/** Cpy **/
662682

663683
struct ggml_webgpu_cpy_pipeline_key {
@@ -797,7 +817,10 @@ class ggml_webgpu_shader_lib {
797817
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
798818
mul_mat_vec_pipelines; // fast mat-vec (n==1)
799819
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
800-
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
820+
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
821+
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
822+
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
823+
mul_mat_id_pipelines; // src0_type/src1_type
801824

802825
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
803826
set_rows_pipelines;
@@ -1598,6 +1621,115 @@ class ggml_webgpu_shader_lib {
15981621
return mul_mat_legacy_pipelines[key];
15991622
}
16001623

1624+
// Returns the (single) gather pipeline used by MUL_MAT_ID, building and
// caching it on first use. The cache map has exactly one entry, stored under
// the fixed key 1; only the workgroup size (from device limits) is specialized.
webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
    auto cached = mul_mat_id_gather_pipelines.find(1);
    if (cached != mul_mat_id_gather_pipelines.end()) {
        return cached->second;
    }

    // Specialize the shader on the device's maximum workgroup size.
    std::vector<std::string> shader_defines;
    shader_defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
    auto shader_source = preprocessor.preprocess(wgsl_mul_mat_id_gather, shader_defines);

    // Record the dispatch decisions alongside the pipeline for the encoder.
    auto shader_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    shader_decisions->wg_size = context.max_wg_size;

    webgpu_pipeline gather_pipeline = ggml_webgpu_create_pipeline(device, shader_source, "mul_mat_id_gather");
    gather_pipeline.context = shader_decisions;
    mul_mat_id_gather_pipelines[1] = gather_pipeline;
    return gather_pipeline;
}
1641+
1642+
// Returns the main MUL_MAT_ID matmul pipeline for the (src0, src1) type pair
// in `context`, building and caching it on first use.
//
// The shader variant name and preprocessor defines are derived from the types:
//   - src1 must be f32 or f16 (aborts otherwise),
//   - src0 may be f32/f16 (float path) or a quantized type, in which case the
//     shader unpacks blocks from u32 words (U32_DEQUANT_HELPERS).
// Tile and workgroup dimensions use the same compile-time constants as the
// regular MUL_MAT path and are recorded in the pipeline's decisions context.
webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_mul_mat_id_pipeline_key key = {
        .src0_type = context.src0->type,
        .src1_type = context.src1->type,
    };

    auto it = mul_mat_id_pipelines.find(key);
    if (it != mul_mat_id_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string variant = "mul_mat_id";
    defines.push_back("MUL_MAT_ID");

    // src1 type: only float activations are supported.
    switch (context.src1->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC1_INNER_TYPE=f32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC1_INNER_TYPE=f16");
            break;
        default:
            // Fixed: message previously said "mul_mat fast shader" (copy-paste
            // from the MUL_MAT path); this is the MUL_MAT_ID pipeline.
            GGML_ABORT("Unsupported src1 type for mul_mat_id shader");
    }

    // src0 type: float types take the direct path; anything else is treated as
    // a quantized type dequantized from u32 words inside the shader.
    switch (context.src0->type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC0_INNER_TYPE=f32");
            defines.push_back("FLOAT");
            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
            defines.push_back("INIT_SRC1_SHMEM_FLOAT");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC0_INNER_TYPE=f16");
            defines.push_back("FLOAT");
            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
            defines.push_back("INIT_SRC1_SHMEM_FLOAT");
            variant += "_f16";
            break;
        default:
            {
                // Quantized src0: the type name (e.g. "q4_0") selects the
                // shared-memory init macro and the variant suffix.
                const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
                const char *                    src0_name   = src0_traits->type_name;

                std::string type_upper = src0_name;
                std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

                defines.push_back("BYTE_HELPERS");
                defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
                defines.push_back("U32_DEQUANT_HELPERS");
                defines.push_back("SRC0_INNER_TYPE=u32");

                variant += std::string("_") + src0_name;
                break;
            }
    }

    defines.push_back("SCALAR");

    // Tiles — same compile-time tiling constants as the regular MUL_MAT path.
    defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
    defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
    defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");

    defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
    defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");

    // variant suffix for src1 type
    variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");

    auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);

    // Record the tiling decisions so the encoder can size its dispatch.
    auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
    decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
    decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
    decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
    decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
    decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
    decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    mul_mat_id_pipelines[key] = pipeline;
    return mul_mat_id_pipelines[key];
}
1732+
16011733
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
16021734
const bool is_unary = context.dst->op == GGML_OP_UNARY;
16031735
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,163 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
13761376
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
13771377
}
13781378

1379+
// Encodes a MUL_MAT_ID (mixture-of-experts matmul) node as two dispatches:
//   1. a gather pass over src2 (the expert-id tensor) that writes, per expert,
//      the selected (token, expert-slot) pairs plus per-expert counts, and
//   2. the main matmul pass that consumes those gathered lists.
// The gather outputs live in scratch space carved out of dst's buffer, placed
// after the dst tensor data itself (the extra room is reserved by the
// MUL_MAT_ID case in ggml_backend_webgpu_buffer_type_get_alloc_size).
static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
                                                wgpu::CommandEncoder & encoder,
                                                ggml_tensor * src0,
                                                ggml_tensor * src1,
                                                ggml_tensor * src2,
                                                ggml_tensor * dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src0,
        .src1 = src1,
        .src2 = src2,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };

    // Get or create pipeline
    webgpu_pipeline gather_pipeline, main_pipeline;

    // Both dispatches are submitted together via ggml_backend_webgpu_build_multi.
    std::vector<webgpu_pipeline> pipelines;
    std::vector<std::vector<uint32_t>> params_list;
    std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
    std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;

    gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx);
    main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx);

    // NOTE(review): these narrow int64 ne[] values to u32 — assumes expert
    // count / token count fit in 32 bits (true for current models).
    const uint32_t param_n_expert = (uint32_t) src0->ne[2];
    const uint32_t param_n_expert_used = (uint32_t) dst->ne[1];
    const uint32_t param_n_tokens = (uint32_t) dst->ne[2];

    // params for mul_mat_id_gather.wgsl
    std::vector<uint32_t> gather_params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
        param_n_expert,
        param_n_expert_used,
        param_n_tokens,
        (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),  // src2 row stride, in elements
    };

    const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
    // One u32 per (expert, token) pair for each of the two gathered index lists.
    const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t);

    // Lay out the three scratch regions after the dst tensor data, each aligned
    // to the device's minimum storage-buffer offset alignment:
    //   [dst data][expert_used list][tokens list][per-expert counts]
    const size_t gathered_expert_used_align_offset = ROUNDUP_POW2(
        dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
    const size_t gathered_tokens_align_offset =
        ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes,
                     ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
    const size_t gathered_count_ids_align_offset =
        ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes,
                     ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);

    // Binding sizes must be rounded up to the storage-buffer binding multiple.
    const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
    const size_t gathered_count_ids_binding_size =
        ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);

    // bind group entries for mul_mat_id_gather.wgsl
    std::vector<wgpu::BindGroupEntry> gather_entries = {
        { .binding = 0,
          .buffer = ggml_webgpu_tensor_buf(src2),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
          .size = ggml_webgpu_tensor_binding_size(ctx, src2) },
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_expert_used_align_offset,
          .size = gathered_binding_size },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_tokens_align_offset,
          .size = gathered_binding_size },
        { .binding = 3,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_count_ids_align_offset,
          .size = gathered_count_ids_binding_size },
    };

    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;

    // One workgroup per expert, folded into 2D if it exceeds the per-dimension limit.
    const uint32_t gather_total_wg = param_n_expert;
    const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
    const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);

    pipelines.push_back(gather_pipeline);
    params_list.push_back(std::move(gather_params));
    entries_list.push_back(std::move(gather_entries));
    workgroups_list.push_back({ gather_wg_x, gather_wg_y });

    // params for mul_mat_id.wgsl
    std::vector<uint32_t> main_params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        param_n_expert,
        param_n_expert_used,
        param_n_tokens,
        (uint32_t) src1->ne[1],
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // src0 row stride, in elements
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // src1 row stride, in elements
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),  // src0 expert stride, in elements
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),  // src1 batch stride, in elements
    };

    // bind group entries for mul_mat_id.wgsl
    // Bindings 3-5 re-expose the gather pass's scratch regions (same buffer,
    // same offsets) as inputs to the matmul.
    std::vector<wgpu::BindGroupEntry> main_entries = {
        { .binding = 0,
          .buffer = ggml_webgpu_tensor_buf(src0),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(src1),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
        { .binding = 3,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_expert_used_align_offset,
          .size = gathered_binding_size },
        { .binding = 4,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_tokens_align_offset,
          .size = gathered_binding_size },
        { .binding = 5,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = gathered_count_ids_align_offset,
          .size = gathered_count_ids_binding_size },
    };

    // Calculate workgroup dimensions
    uint32_t wg_x = 1;
    uint32_t wg_y = 1;

    auto * main_decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(main_pipeline.context.get());

    uint32_t wg_m;

    // Tile footprint of one workgroup in each output dimension.
    uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m;
    uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n;
    wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
    uint32_t total_gathered = dst->ne[1] * dst->ne[2];
    // Worst-case N workgroups: each active expert's row group may need one
    // extra partially-filled tile, hence the "+ max_active_experts" slack.
    uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered);
    uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
    uint32_t total_wg = wg_m * max_wg_n;

    compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);

    pipelines.push_back(main_pipeline);
    params_list.push_back(std::move(main_params));
    entries_list.push_back(std::move(main_entries));
    workgroups_list.push_back({ wg_x, wg_y });

    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
                                           entries_list, workgroups_list);
}
1535+
13791536
#ifndef __EMSCRIPTEN__
13801537
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
13811538
wgpu::CommandEncoder & encoder,
@@ -2638,6 +2795,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context
26382795
return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node);
26392796
case GGML_OP_MUL_MAT:
26402797
return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node);
2798+
case GGML_OP_MUL_MAT_ID:
2799+
return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node);
26412800
case GGML_OP_FLASH_ATTN_EXT:
26422801
#ifndef __EMSCRIPTEN__
26432802
return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node);
@@ -3082,6 +3241,20 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
30823241
}
30833242
}
30843243
break;
3244+
case GGML_OP_MUL_MAT_ID:
3245+
{
3246+
const ggml_tensor * src0 = tensor->src[0];
3247+
const ggml_tensor * src1 = tensor->src[1];
3248+
if (src0 && src1) {
3249+
const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2];
3250+
const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2];
3251+
res = ROUNDUP_POW2(
3252+
res + gathered_size * 2 + gathered_count_ids_size +
3253+
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3,
3254+
WEBGPU_STORAGE_BUF_BINDING_MULT);
3255+
}
3256+
}
3257+
break;
30853258
default:
30863259
break;
30873260
}
@@ -3503,6 +3676,35 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
35033676
}
35043677
break;
35053678
}
3679+
case GGML_OP_MUL_MAT_ID:
3680+
switch (src1->type) {
3681+
case GGML_TYPE_F16:
3682+
supports_op |= (src0->type == GGML_TYPE_F16);
3683+
break;
3684+
case GGML_TYPE_F32:
3685+
switch (src0->type) {
3686+
case GGML_TYPE_F32:
3687+
case GGML_TYPE_F16:
3688+
case GGML_TYPE_Q4_0:
3689+
case GGML_TYPE_Q4_1:
3690+
case GGML_TYPE_Q5_0:
3691+
case GGML_TYPE_Q5_1:
3692+
case GGML_TYPE_Q8_0:
3693+
case GGML_TYPE_Q2_K:
3694+
case GGML_TYPE_Q3_K:
3695+
case GGML_TYPE_Q4_K:
3696+
case GGML_TYPE_Q5_K:
3697+
case GGML_TYPE_Q6_K:
3698+
supports_op = true;
3699+
break;
3700+
default:
3701+
break;
3702+
}
3703+
break;
3704+
default:
3705+
break;
3706+
}
3707+
break;
35063708
case GGML_OP_FLASH_ATTN_EXT:
35073709
{
35083710
#ifndef __EMSCRIPTEN__

0 commit comments

Comments
 (0)