@@ -1376,6 +1376,163 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
13761376 return ggml_backend_webgpu_build (ctx->global_ctx , ctx->param_arena , encoder, pipeline, params, entries, wg_x, wg_y);
13771377}
13781378
// Encode GGML_OP_MUL_MAT_ID (mixture-of-experts matrix multiply) as two compute
// passes recorded into `encoder`:
//   1. mul_mat_id_gather.wgsl  — reads the expert-id tensor (src2) and writes three
//      scratch arrays (gathered expert-used slots, gathered token indices, and a
//      per-expert count) so the main pass can process rows grouped by expert.
//   2. mul_mat_id.wgsl         — the actual per-expert matmul of src0 (expert
//      weights, experts along ne[2]) with src1 (activations) into dst.
// The scratch arrays live in dst's GPU buffer, *after* the dst tensor bytes; the
// extra space is reserved by ggml_backend_webgpu_buffer_type_get_alloc_size for
// GGML_OP_MUL_MAT_ID nodes. Each scratch region is aligned to the device's
// minStorageBufferOffsetAlignment so it can be bound at a storage-buffer offset.
// Returns the encoded op built from both pipelines via ggml_backend_webgpu_build_multi.
static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context &       ctx,
                                                wgpu::CommandEncoder & encoder,
                                                ggml_tensor *          src0,
                                                ggml_tensor *          src1,
                                                ggml_tensor *          src2,
                                                ggml_tensor *          dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0        = src0,
        .src1        = src1,
        .src2        = src2,
        .dst         = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };

    // Get or create pipeline
    webgpu_pipeline gather_pipeline, main_pipeline;

    // Parallel lists consumed by ggml_backend_webgpu_build_multi: one entry per pass.
    std::vector<webgpu_pipeline>                   pipelines;
    std::vector<std::vector<uint32_t>>             params_list;
    std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
    std::vector<std::pair<uint32_t, uint32_t>>     workgroups_list;

    gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx);
    main_pipeline   = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx);

    // src0->ne[2] = number of experts; dst->ne[1] = experts used per token;
    // dst->ne[2] = number of tokens.
    const uint32_t param_n_expert      = (uint32_t) src0->ne[2];
    const uint32_t param_n_expert_used = (uint32_t) dst->ne[1];
    const uint32_t param_n_tokens      = (uint32_t) dst->ne[2];

    // params for mul_mat_id_gather.wgsl
    // Misalignment is expressed in elements (bytes / element size) so the shader can
    // index from the aligned binding offset; src2->nb[1] is its row stride in bytes.
    std::vector<uint32_t> gather_params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
        param_n_expert,
        param_n_expert_used,
        param_n_tokens,
        (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
    };

    const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
    // One u32 slot per (expert, src1 batch) pair for each of the two gather arrays.
    // NOTE(review): assumes src1->ne[2] matches the token count used by the shaders —
    // confirm against mul_mat_id_gather.wgsl.
    const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t);

    // Lay the three scratch regions out back-to-back after the dst tensor bytes,
    // each rounded up to the storage-buffer offset alignment required for binding.
    const size_t gathered_expert_used_align_offset = ROUNDUP_POW2(
        dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
    const size_t gathered_tokens_align_offset =
        ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes,
                     ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
    const size_t gathered_count_ids_align_offset =
        ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes,
                     ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);

    // Binding sizes must be multiples of WEBGPU_STORAGE_BUF_BINDING_MULT.
    const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
    const size_t gathered_count_ids_binding_size =
        ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);

    // bind group entries for mul_mat_id_gather.wgsl
    // binding 0: expert ids (src2, read); bindings 1-3: the three scratch regions,
    // bound at their aligned offsets inside dst's buffer (write).
    std::vector<wgpu::BindGroupEntry> gather_entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src2),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src2) },
        { .binding = 1,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = gathered_expert_used_align_offset,
          .size    = gathered_binding_size },
        { .binding = 2,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = gathered_tokens_align_offset,
          .size    = gathered_binding_size },
        { .binding = 3,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = gathered_count_ids_align_offset,
          .size    = gathered_count_ids_binding_size },
    };

    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;

    // Gather pass: one workgroup per expert, folded into a 2D grid if the expert
    // count exceeds the per-dimension workgroup limit.
    const uint32_t gather_total_wg = param_n_expert;
    const uint32_t gather_wg_x     = std::min(gather_total_wg, max_wg_per_dim);
    const uint32_t gather_wg_y     = CEIL_DIV(gather_total_wg, gather_wg_x);

    pipelines.push_back(gather_pipeline);
    params_list.push_back(std::move(gather_params));
    entries_list.push_back(std::move(gather_entries));
    workgroups_list.push_back({ gather_wg_x, gather_wg_y });

    // params for mul_mat_id.wgsl
    // Element-unit misalignments for src0/src1/dst, logical dims, then row (nb[1])
    // and expert/batch (nb[2]) strides for src0 and src1 in elements.
    std::vector<uint32_t> main_params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
        (uint32_t) src0->ne[0],
        (uint32_t) src0->ne[1],
        param_n_expert,
        param_n_expert_used,
        param_n_tokens,
        (uint32_t) src1->ne[1],
        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
    };

    // bind group entries for mul_mat_id.wgsl
    // bindings 0-2: the matmul operands and output; bindings 3-5: the same three
    // scratch regions the gather pass wrote, now read by the main pass.
    std::vector<wgpu::BindGroupEntry> main_entries = {
        { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
        { .binding = 1,
          .buffer  = ggml_webgpu_tensor_buf(src1),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
        { .binding = 3,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = gathered_expert_used_align_offset,
          .size    = gathered_binding_size },
        { .binding = 4,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = gathered_tokens_align_offset,
          .size    = gathered_binding_size },
        { .binding = 5,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = gathered_count_ids_align_offset,
          .size    = gathered_count_ids_binding_size },
    };

    // Calculate workgroup dimensions
    uint32_t wg_x = 1;
    uint32_t wg_y = 1;

    // Tiling decisions (tile sizes, workgroup shape) were made when the pipeline was
    // compiled; retrieve them to size the dispatch consistently with the shader.
    auto * main_decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(main_pipeline.context.get());

    uint32_t wg_m;

    uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m;
    uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n;
    wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
    // Upper bound on N-dimension workgroups: gathered rows partitioned by expert can
    // each leave a partially-filled tile, so add one workgroup per expert that can
    // actually be active. NOTE(review): presumably matches the shader's row
    // partitioning — confirm against mul_mat_id.wgsl.
    uint32_t total_gathered     = dst->ne[1] * dst->ne[2];
    uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered);
    uint32_t max_wg_n           = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
    uint32_t total_wg           = wg_m * max_wg_n;

    compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);

    pipelines.push_back(main_pipeline);
    params_list.push_back(std::move(main_params));
    entries_list.push_back(std::move(main_entries));
    workgroups_list.push_back({ wg_x, wg_y });

    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
                                           entries_list, workgroups_list);
}
1535+
13791536#ifndef __EMSCRIPTEN__
13801537static webgpu_encoded_op ggml_webgpu_flash_attn (webgpu_context & ctx,
13811538 wgpu::CommandEncoder & encoder,
@@ -2638,6 +2795,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context
26382795 return ggml_webgpu_get_rows (ctx, encoder, src0, src1, node);
26392796 case GGML_OP_MUL_MAT:
26402797 return ggml_webgpu_mul_mat (ctx, encoder, src0, src1, node);
2798+ case GGML_OP_MUL_MAT_ID:
2799+ return ggml_webgpu_mul_mat_id (ctx, encoder, src0, src1, src2, node);
26412800 case GGML_OP_FLASH_ATTN_EXT:
26422801#ifndef __EMSCRIPTEN__
26432802 return ggml_webgpu_flash_attn (ctx, encoder, src0, src1, src2, node->src [3 ], node->src [4 ], node);
@@ -3082,6 +3241,20 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
30823241 }
30833242 }
30843243 break ;
3244+ case GGML_OP_MUL_MAT_ID:
3245+ {
3246+ const ggml_tensor * src0 = tensor->src [0 ];
3247+ const ggml_tensor * src1 = tensor->src [1 ];
3248+ if (src0 && src1) {
3249+ const size_t gathered_size = sizeof (uint32_t ) * tensor->src [0 ]->ne [2 ] * tensor->src [1 ]->ne [2 ];
3250+ const size_t gathered_count_ids_size = sizeof (uint32_t ) * tensor->src [0 ]->ne [2 ];
3251+ res = ROUNDUP_POW2 (
3252+ res + gathered_size * 2 + gathered_count_ids_size +
3253+ ctx->webgpu_global_ctx ->capabilities .limits .minStorageBufferOffsetAlignment * 3 ,
3254+ WEBGPU_STORAGE_BUF_BINDING_MULT);
3255+ }
3256+ }
3257+ break ;
30853258 default :
30863259 break ;
30873260 }
@@ -3503,6 +3676,35 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
35033676 }
35043677 break ;
35053678 }
3679+ case GGML_OP_MUL_MAT_ID:
3680+ switch (src1->type ) {
3681+ case GGML_TYPE_F16:
3682+ supports_op |= (src0->type == GGML_TYPE_F16);
3683+ break ;
3684+ case GGML_TYPE_F32:
3685+ switch (src0->type ) {
3686+ case GGML_TYPE_F32:
3687+ case GGML_TYPE_F16:
3688+ case GGML_TYPE_Q4_0:
3689+ case GGML_TYPE_Q4_1:
3690+ case GGML_TYPE_Q5_0:
3691+ case GGML_TYPE_Q5_1:
3692+ case GGML_TYPE_Q8_0:
3693+ case GGML_TYPE_Q2_K:
3694+ case GGML_TYPE_Q3_K:
3695+ case GGML_TYPE_Q4_K:
3696+ case GGML_TYPE_Q5_K:
3697+ case GGML_TYPE_Q6_K:
3698+ supports_op = true ;
3699+ break ;
3700+ default :
3701+ break ;
3702+ }
3703+ break ;
3704+ default :
3705+ break ;
3706+ }
3707+ break ;
35063708 case GGML_OP_FLASH_ATTN_EXT:
35073709 {
35083710#ifndef __EMSCRIPTEN__
0 commit comments