Skip to content

Commit 098705a

Browse files
authored
CUDA: fuse SSM_CONV + ADD(bias) + SILU (#22478)
1 parent 683c5ac commit 098705a

4 files changed

Lines changed: 129 additions & 8 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 34 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3556,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
35563556
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
35573557
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
35583558
const ggml_tensor * silu = cgraph->nodes[node_idx+1];
3559+
if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
3560+
return false;
3561+
}
35593562

35603563
if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
35613564
return false;
@@ -3564,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
35643567
return true;
35653568
}
35663569

3570+
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD
3571+
&& ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
3572+
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
3573+
const ggml_tensor * add = cgraph->nodes[node_idx+1];
3574+
const ggml_tensor * silu = cgraph->nodes[node_idx+2];
3575+
if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
3576+
return false;
3577+
}
3578+
3579+
if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
3580+
return false;
3581+
}
3582+
3583+
// ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias.
3584+
const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0];
3585+
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
3586+
return false;
3587+
}
3588+
if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) {
3589+
return false;
3590+
}
3591+
3592+
return true;
3593+
}
3594+
35673595
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
35683596
&& unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
35693597
const ggml_tensor * unary = cgraph->nodes[node_idx];
@@ -3966,8 +3994,13 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
39663994
return 1;
39673995
}
39683996

3997+
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
3998+
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
3999+
return 2;
4000+
}
4001+
39694002
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
3970-
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]);
4003+
ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]);
39714004
return 1;
39724005
}
39734006

ggml/src/ggml-cuda/ssm-conv.cu

Lines changed: 28 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
template <bool apply_silu, size_t split_d_inner, size_t d_conv>
55
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
6+
const float * __restrict__ bias,
67
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
78
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
89
const int64_t n_t) {
@@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
2728
w[j] = w_block[tid * stride_w + j];
2829
}
2930

31+
float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
32+
3033
for (int64_t i = 0; i < n_t; i++) {
3134
float sumf = 0.0f;
3235

@@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
4245
for (size_t j = 0; j < d_conv; j++) {
4346
sumf += x[(i + j) % d_conv] * w[j];
4447
}
48+
sumf += b;
4549
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
4650
}
4751
}
4852

4953
template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
5054
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
55+
const float * __restrict__ bias,
5156
const int src0_nb0, const int src0_nb1, const int src0_nb2,
5257
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
5358
const int dst_nb1, const int dst_nb2, const int64_t n_t) {
@@ -97,19 +102,22 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
97102
w[j] = w_block[tid * stride_w + j];
98103
}
99104

105+
float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
106+
100107
// Compute from shared memory
101108
for (int64_t i = 0; i < local_n_t; i++) {
102109
float sumf = 0.0f;
103110
#pragma unroll
104111
for (size_t j = 0; j < d_conv; j++) {
105112
sumf += smem[tid * n_cols + i + j] * w[j];
106113
}
114+
sumf += b;
107115
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
108116
}
109117
}
110118

111119
template <bool apply_silu>
112-
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
120+
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1,
113121
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
114122
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
115123
const int64_t n_s, cudaStream_t stream) {
@@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
120128
constexpr int kNC = decltype(NC)::value;
121129
if (n_t <= 32) {
122130
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
123-
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
131+
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
124132
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
125133
} else {
126134
const int64_t split_n_t = 32;
127135
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
128136
const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
129137
ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
130-
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
138+
src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
131139
}
132140
};
133141

@@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
140148
}
141149
}
142150

143-
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
151+
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) {
144152
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
145153
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
154+
const bool fuse_bias = bias_add_node != nullptr;
146155
const bool fuse_silu = silu_dst != nullptr;
147156

157+
// bias always comes with silu.
158+
GGML_ASSERT(!fuse_bias || fuse_silu);
159+
160+
// The bias (when fused) is the non-conv operand of the ADD node.
161+
const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr;
162+
148163
// When fusing, write to silu_dst (the node downstream references).
149164
const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
150165

@@ -160,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g
160175

161176
const float * src0_d = (const float *) src0->data;
162177
const float * src1_d = (const float *) src1->data;
178+
const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr;
163179
float * dst_d = (float *) out->data;
164180
cudaStream_t stream = ctx.stream();
165181

166182
GGML_ASSERT(src0->type == GGML_TYPE_F32);
167183
GGML_ASSERT(out->type == GGML_TYPE_F32);
184+
if (fuse_bias) {
185+
GGML_ASSERT(bias->type == GGML_TYPE_F32);
186+
GGML_ASSERT(ggml_is_contiguous(bias));
187+
GGML_ASSERT(ggml_nelements(bias) == nr);
188+
}
189+
168190
if (fuse_silu) {
169-
ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
191+
ssm_conv_f32_cuda<true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
170192
out->nb[2], nc, nr, n_t, n_s, stream);
171193
} else {
172-
ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
194+
ssm_conv_f32_cuda<false>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
173195
out->nb[2], nc, nr, n_t, n_s, stream);
174196
}
175197
}

ggml/src/ggml-cuda/ssm-conv.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
#include "common.cuh"
22

3-
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
3+
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr);

tests/test-backend-ops.cpp

Lines changed: 66 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3579,6 +3579,49 @@ struct test_ssm_conv : public test_case {
35793579
}
35803580
};
35813581

3582+
// GGML_OP_SSM_CONV + GGML_OP_ADD (channel-wise bias, optional) + GGML_OP_UNARY(SILU) (fused operation)
3583+
struct test_ssm_conv_bias_silu : public test_case {
3584+
const ggml_type type;
3585+
const std::array<int64_t, 4> ne_a;
3586+
const std::array<int64_t, 4> ne_b;
3587+
const bool fuse_bias;
3588+
3589+
std::string op_desc(ggml_tensor * t) override {
3590+
GGML_UNUSED(t);
3591+
return "SSM_CONV_BIAS_SILU";
3592+
}
3593+
3594+
bool run_whole_graph() override { return true; }
3595+
3596+
std::string vars() override {
3597+
return VARS_TO_STR4(type, ne_a, ne_b, fuse_bias);
3598+
}
3599+
3600+
test_ssm_conv_bias_silu(ggml_type type, std::array<int64_t, 4> ne_a, std::array<int64_t, 4> ne_b,
3601+
bool fuse_bias)
3602+
: type(type), ne_a(ne_a), ne_b(ne_b), fuse_bias(fuse_bias) {}
3603+
3604+
ggml_tensor * build_graph(ggml_context * ctx) override {
3605+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
3606+
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
3607+
ggml_set_name(a, "a");
3608+
ggml_set_name(b, "b");
3609+
3610+
ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
3611+
3612+
if (fuse_bias) {
3613+
ggml_tensor * bias = ggml_new_tensor_1d(ctx, type, out->ne[0]);
3614+
ggml_set_name(bias, "bias");
3615+
out = ggml_add(ctx, out, bias);
3616+
}
3617+
3618+
out = ggml_silu(ctx, out);
3619+
3620+
ggml_set_name(out, "out");
3621+
return out;
3622+
}
3623+
};
3624+
35823625
// GGML_OP_SSM_SCAN
35833626
struct test_ssm_scan : public test_case {
35843627
const ggml_type type;
@@ -7977,6 +8020,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
79778020
}
79788021
}
79798022

8023+
// fused ssm_conv + (optional) bias_add + silu. The bias-only graph (no silu) is intentionally
8024+
// not tested since there's no fusion for that pattern in ggml_cuda_can_fuse.
8025+
for (int64_t d_conv : {3, 4, 9}) {
8026+
for (int64_t d_inner : {1024, 1536, 2048}) {
8027+
for (bool fuse_bias : {false, true}) {
8028+
// short token path (n_t <= 32)
8029+
test_cases.emplace_back(new test_ssm_conv_bias_silu(
8030+
GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
8031+
test_cases.emplace_back(new test_ssm_conv_bias_silu(
8032+
GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
8033+
test_cases.emplace_back(new test_ssm_conv_bias_silu(
8034+
GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
8035+
// long token path (n_t > 32)
8036+
test_cases.emplace_back(new test_ssm_conv_bias_silu(
8037+
GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
8038+
test_cases.emplace_back(new test_ssm_conv_bias_silu(
8039+
GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias));
8040+
}
8041+
}
8042+
}
8043+
79808044
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
79818045
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
79828046
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
@@ -8993,6 +9057,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
89939057
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
89949058
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
89959059
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
9060+
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // prefill
9061+
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // generate
89969062
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
89979063
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
89989064

0 commit comments

Comments (0)