Skip to content

Commit 3fbadb0

Browse files
authored
vulkan: fuse SSM_CONV + BIAS + SILU (#22653)
1 parent 1a68ec9 commit 3fbadb0

2 files changed

Lines changed: 129 additions & 9 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 118 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,8 @@ struct vk_device_struct {
854854
vk_pipeline pipeline_ssm_scan_f32_d128;
855855
vk_pipeline pipeline_ssm_scan_f32_d256;
856856
vk_pipeline pipeline_ssm_conv_f32;
857+
vk_pipeline pipeline_ssm_conv_silu_f32;
858+
vk_pipeline pipeline_ssm_conv_bias_silu_f32;
857859
vk_pipeline pipeline_opt_step_adamw_f32;
858860
vk_pipeline pipeline_opt_step_sgd_f32;
859861
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@@ -4900,7 +4902,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
49004902
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
49014903
}
49024904

4903-
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
4905+
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1);
4906+
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1);
4907+
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1);
49044908

49054909
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
49064910

@@ -9936,7 +9940,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
99369940
return nullptr;
99379941
case GGML_OP_SSM_CONV:
99389942
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9939-
return ctx->device->pipeline_ssm_conv_f32;
9943+
switch (ctx->num_additional_fused_ops) {
9944+
case 0: return ctx->device->pipeline_ssm_conv_f32;
9945+
case 1: return ctx->device->pipeline_ssm_conv_silu_f32;
9946+
case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32;
9947+
default: return nullptr;
9948+
}
99409949
}
99419950
return nullptr;
99429951
case GGML_OP_OPT_STEP_ADAMW:
@@ -10877,11 +10886,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
1087710886
pc, elements);
1087810887
}
1087910888

10880-
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10881-
const ggml_tensor * src0 = dst->src[0];
10882-
const ggml_tensor * src1 = dst->src[1];
10889+
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
10890+
ggml_tensor * conv = cgraph->nodes[node_idx];
10891+
const ggml_tensor * src0 = conv->src[0];
10892+
const ggml_tensor * src1 = conv->src[1];
10893+
10894+
// Pick the destination tensor (last node in the fused chain) and the optional bias.
10895+
// Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu.
10896+
ggml_tensor * dst = conv;
10897+
const ggml_tensor * bias = nullptr;
10898+
10899+
if (ctx->num_additional_fused_ops == 1) {
10900+
dst = cgraph->nodes[node_idx + 1]; // silu
10901+
} else if (ctx->num_additional_fused_ops == 2) {
10902+
ggml_tensor * add = cgraph->nodes[node_idx + 1];
10903+
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
10904+
dst = cgraph->nodes[node_idx + 2]; // silu
10905+
}
1088310906

10884-
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
10907+
// The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused.
10908+
const ggml_tensor * src2 = bias ? bias : src0;
10909+
10910+
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, {
1088510911
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
1088610912
(uint32_t)src1->nb[1],
1088710913
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
@@ -13556,7 +13582,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1355613582
break;
1355713583

1355813584
case GGML_OP_SSM_CONV:
13559-
ggml_vk_ssm_conv(ctx, compute_ctx, node);
13585+
ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx);
1356013586

1356113587
break;
1356213588

@@ -14453,6 +14479,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
1445314479
return true;
1445414480
}
1445514481

14482+
// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2.
14483+
static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
14484+
int node_idx, int num_extra) {
14485+
const ggml_tensor * conv = cgraph->nodes[node_idx];
14486+
if (conv->op != GGML_OP_SSM_CONV) {
14487+
return false;
14488+
}
14489+
14490+
const ggml_tensor * silu = nullptr;
14491+
const ggml_tensor * bias = nullptr;
14492+
14493+
if (num_extra == 1) {
14494+
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) {
14495+
return false;
14496+
}
14497+
silu = cgraph->nodes[node_idx + 1];
14498+
} else if (num_extra == 2) {
14499+
if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) {
14500+
return false;
14501+
}
14502+
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
14503+
silu = cgraph->nodes[node_idx + 2];
14504+
bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
14505+
14506+
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
14507+
return false;
14508+
}
14509+
// bias must be channel-wise (one element per channel of the conv output)
14510+
if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) {
14511+
return false;
14512+
}
14513+
if (add->type != GGML_TYPE_F32) {
14514+
return false;
14515+
}
14516+
// The shader doesn't apply per-tensor offsets, so reject misaligned bias.
14517+
if (get_misalign_bytes(ctx, bias) != 0) {
14518+
return false;
14519+
}
14520+
} else {
14521+
return false;
14522+
}
14523+
14524+
if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) {
14525+
return false;
14526+
}
14527+
if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
14528+
return false;
14529+
}
14530+
// The shader writes to the fused dst using its own strides, but the push constants don't
14531+
// carry a per-tensor offset, so the binding must be naturally aligned.
14532+
if (get_misalign_bytes(ctx, silu) != 0) {
14533+
return false;
14534+
}
14535+
return true;
14536+
}
14537+
1445614538
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
1445714539
int node_idx, topk_moe_mode mode) {
1445814540

@@ -14869,6 +14951,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
1486914951
// they are overwritten, and one workgroup per row. So close enough.
1487014952
op_srcs_fused_elementwise[0] = true;
1487114953
op_srcs_fused_elementwise[1] = true;
14954+
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) {
14955+
ctx->num_additional_fused_ops = 2;
14956+
fusion_string = "SSM_CONV_BIAS_SILU";
14957+
// ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs.
14958+
// The downstream add and silu are elementwise on the conv output.
14959+
op_srcs_fused_elementwise[0] = false;
14960+
op_srcs_fused_elementwise[1] = true;
14961+
op_srcs_fused_elementwise[2] = true;
14962+
} else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) {
14963+
ctx->num_additional_fused_ops = 1;
14964+
fusion_string = "SSM_CONV_SILU";
14965+
op_srcs_fused_elementwise[0] = false;
14966+
op_srcs_fused_elementwise[1] = true;
1487214967
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
1487314968
ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
1487414969
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
@@ -15200,7 +15295,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1520015295
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
1520115296
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
1520215297
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
15203-
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
15298+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) &&
15299+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) &&
15300+
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) {
1520415301
ok = false;
1520515302
break;
1520615303
}
@@ -15283,6 +15380,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1528315380
}
1528415381
}
1528515382
}
15383+
// SSM_CONV + ADD + UNARY: pull the consuming UNARY forward
15384+
if (j > 0 &&
15385+
graph->nodes[j]->op == GGML_OP_ADD &&
15386+
graph->nodes[j-1]->op == GGML_OP_SSM_CONV) {
15387+
for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
15388+
if (graph->nodes[k]->op == GGML_OP_UNARY &&
15389+
graph->nodes[k]->src[0] == graph->nodes[j]) {
15390+
current_set.push_back(k);
15391+
used[k] = true;
15392+
break;
15393+
}
15394+
}
15395+
}
1528615396
}
1528715397
}
1528815398
// Second pass grabs view nodes.

ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66

77
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
88
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
9+
layout(constant_id = 2) const bool APPLY_BIAS = false;
10+
layout(constant_id = 3) const bool APPLY_SILU = false;
911

1012
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
1113

1214
layout(binding = 0) readonly buffer Src0 { float src0[]; };
1315
layout(binding = 1) readonly buffer Src1 { float src1[]; };
14-
layout(binding = 2) buffer Dst { float dst[]; };
16+
layout(binding = 2) readonly buffer Bias { float bias[]; };
17+
layout(binding = 3) buffer Dst { float dst[]; };
1518

1619
layout(push_constant) uniform PushConstants {
1720
uint nb01; uint nb02;
@@ -45,6 +48,13 @@ void main() {
4548
}
4649
}
4750

51+
if (APPLY_BIAS) {
52+
sum += bias[i1];
53+
}
54+
if (APPLY_SILU) {
55+
sum = sum / (1.0f + exp(-sum));
56+
}
57+
4858
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
4959
dst[dst_idx] = sum;
5060
}

0 commit comments

Comments
 (0)