@@ -854,6 +854,8 @@ struct vk_device_struct {
854854 vk_pipeline pipeline_ssm_scan_f32_d128;
855855 vk_pipeline pipeline_ssm_scan_f32_d256;
856856 vk_pipeline pipeline_ssm_conv_f32;
857+ vk_pipeline pipeline_ssm_conv_silu_f32;
858+ vk_pipeline pipeline_ssm_conv_bias_silu_f32;
857859 vk_pipeline pipeline_opt_step_adamw_f32;
858860 vk_pipeline pipeline_opt_step_sgd_f32;
859861 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@@ -4900,7 +4902,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
49004902 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
49014903 }
49024904
4903- ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
4905+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1);
4906+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1);
4907+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1);
49044908
49054909 ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
49064910
@@ -9936,7 +9940,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
99369940 return nullptr;
99379941 case GGML_OP_SSM_CONV:
99389942 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9939- return ctx->device->pipeline_ssm_conv_f32;
9943+ switch (ctx->num_additional_fused_ops) {
9944+ case 0: return ctx->device->pipeline_ssm_conv_f32;
9945+ case 1: return ctx->device->pipeline_ssm_conv_silu_f32;
9946+ case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32;
9947+ default: return nullptr;
9948+ }
99409949 }
99419950 return nullptr;
99429951 case GGML_OP_OPT_STEP_ADAMW:
@@ -10877,11 +10886,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
1087710886 pc, elements);
1087810887}
1087910888
10880- static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10881- const ggml_tensor * src0 = dst->src[0];
10882- const ggml_tensor * src1 = dst->src[1];
10889+ static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
10890+ ggml_tensor * conv = cgraph->nodes[node_idx];
10891+ const ggml_tensor * src0 = conv->src[0];
10892+ const ggml_tensor * src1 = conv->src[1];
10893+
10894+ // Pick the destination tensor (last node in the fused chain) and the optional bias.
10895+ // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu.
10896+ ggml_tensor * dst = conv;
10897+ const ggml_tensor * bias = nullptr;
10898+
10899+ if (ctx->num_additional_fused_ops == 1) {
10900+ dst = cgraph->nodes[node_idx + 1]; // silu
10901+ } else if (ctx->num_additional_fused_ops == 2) {
10902+ ggml_tensor * add = cgraph->nodes[node_idx + 1];
10903+ bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
10904+ dst = cgraph->nodes[node_idx + 2]; // silu
10905+ }
1088310906
10884- ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
10907+ // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused.
10908+ const ggml_tensor * src2 = bias ? bias : src0;
10909+
10910+ ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, {
1088510911 (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
1088610912 (uint32_t)src1->nb[1],
1088710913 (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
@@ -13556,7 +13582,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1355613582 break;
1355713583
1355813584 case GGML_OP_SSM_CONV:
13559- ggml_vk_ssm_conv(ctx, compute_ctx, node );
13585+ ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx );
1356013586
1356113587 break;
1356213588
@@ -14453,6 +14479,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
1445314479 return true;
1445414480}
1445514481
14482+ // Returns true if the nodes at node_idx form SSM_CONV, UNARY(SILU) (num_extra == 1) or SSM_CONV, ADD, UNARY(SILU) (num_extra == 2) and pass every check below; any other num_extra returns false.
14483+ static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
14484+ int node_idx, int num_extra) {
14485+ const ggml_tensor * conv = cgraph->nodes[node_idx];
14486+ if (conv->op != GGML_OP_SSM_CONV) { // the chain must start at an SSM_CONV node
14487+ return false;
14488+ }
14489+
14490+ const ggml_tensor * silu = nullptr; // last node of the chain; assigned in either branch below
14491+ const ggml_tensor * bias = nullptr; // the ADD operand that is not conv; only assigned when num_extra == 2
14492+
14493+ if (num_extra == 1) {
14494+ if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) { // generic fusion check (defined elsewhere) for the two-op sequence
14495+ return false;
14496+ }
14497+ silu = cgraph->nodes[node_idx + 1];
14498+ } else if (num_extra == 2) {
14499+ if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) { // same check for the three-op sequence
14500+ return false;
14501+ }
14502+ const ggml_tensor * add = cgraph->nodes[node_idx + 1];
14503+ silu = cgraph->nodes[node_idx + 2];
14504+ bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; // ADD operand order is not fixed: take whichever source is not the conv output
14505+
14506+ if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { // only contiguous F32 bias tensors are accepted
14507+ return false;
14508+ }
14509+ // bias must be channel-wise (one element per channel of the conv output): exactly conv->ne[0] elements, all along dim 0
14510+ if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) {
14511+ return false;
14512+ }
14513+ if (add->type != GGML_TYPE_F32) { // the intermediate ADD result must be F32 as well
14514+ return false;
14515+ }
14516+ // The shader doesn't apply per-tensor offsets, so reject a misaligned bias.
14517+ if (get_misalign_bytes(ctx, bias) != 0) {
14518+ return false;
14519+ }
14520+ } else {
14521+ return false; // only one or two extra fused ops are recognized
14522+ }
14523+
14524+ if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) { // the trailing UNARY node must specifically be SILU
14525+ return false;
14526+ }
14527+ if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { // both the conv node and the final output must be F32
14528+ return false;
14529+ }
14530+ // The shader writes to the fused dst (the silu node) using its own strides, but the push constants
14531+ // don't carry a per-tensor offset, so the dst binding must be naturally aligned.
14532+ if (get_misalign_bytes(ctx, silu) != 0) {
14533+ return false;
14534+ }
14535+ return true;
14536+ }
14537+
1445614538static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
1445714539 int node_idx, topk_moe_mode mode) {
1445814540
@@ -14869,6 +14951,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1486914951 // they are overwritten, and one workgroup per row. So close enough.
1487014952 op_srcs_fused_elementwise[0] = true;
1487114953 op_srcs_fused_elementwise[1] = true;
14954+ } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) {
14955+ ctx->num_additional_fused_ops = 2;
14956+ fusion_string = "SSM_CONV_BIAS_SILU";
14957+ // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs.
14958+ // The downstream add and silu are elementwise on the conv output.
14959+ op_srcs_fused_elementwise[0] = false;
14960+ op_srcs_fused_elementwise[1] = true;
14961+ op_srcs_fused_elementwise[2] = true;
14962+ } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) {
14963+ ctx->num_additional_fused_ops = 1;
14964+ fusion_string = "SSM_CONV_SILU";
14965+ op_srcs_fused_elementwise[0] = false;
14966+ op_srcs_fused_elementwise[1] = true;
1487214967 } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
1487314968 ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
1487414969 ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
@@ -15200,7 +15295,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1520015295 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
1520115296 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
1520215297 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
15203- !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
15298+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) &&
15299+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) &&
15300+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) {
1520415301 ok = false;
1520515302 break;
1520615303 }
@@ -15283,6 +15380,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1528315380 }
1528415381 }
1528515382 }
15383+ // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward
15384+ if (j > 0 &&
15385+ graph->nodes[j]->op == GGML_OP_ADD &&
15386+ graph->nodes[j-1]->op == GGML_OP_SSM_CONV) {
15387+ for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
15388+ if (graph->nodes[k]->op == GGML_OP_UNARY &&
15389+ graph->nodes[k]->src[0] == graph->nodes[j]) {
15390+ current_set.push_back(k);
15391+ used[k] = true;
15392+ break;
15393+ }
15394+ }
15395+ }
1528615396 }
1528715397 }
1528815398 // Second pass grabs view nodes.
0 commit comments