@@ -3,6 +3,7 @@
 
 template <bool apply_silu, size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                    const float * __restrict__ bias,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
                                     const int64_t n_t) {
@@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         w[j] = w_block[tid * stride_w + j];
     }
 
+    float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
+
     for (int64_t i = 0; i < n_t; i++) {
         float sumf = 0.0f;
 
@@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
+        sumf += b;
         y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
 template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+                                               const float * __restrict__ bias,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
                                                const int dst_nb1, const int dst_nb2, const int64_t n_t) {
@@ -97,19 +102,22 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
         w[j] = w_block[tid * stride_w + j];
     }
 
+    float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f;
+
     // Compute from shared memory
     for (int64_t i = 0; i < local_n_t; i++) {
         float sumf = 0.0f;
 #pragma unroll
         for (size_t j = 0; j < d_conv; j++) {
             sumf += smem[tid * n_cols + i + j] * w[j];
         }
+        sumf += b;
         y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
 template <bool apply_silu>
-static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
+static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                               const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
                               const int64_t n_s, cudaStream_t stream) {
@@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
         constexpr int kNC = decltype(NC)::value;
         if (n_t <= 32) {
             const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-            ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+            ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                                    dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             const int64_t split_n_t = 32;
             dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
             const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
             ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
-                src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+                src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         }
     };
 
@@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     }
 }
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // conv_x
     const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
+    const bool fuse_bias = bias_add_node != nullptr;
     const bool fuse_silu = silu_dst != nullptr;
 
+    // bias always comes with silu.
+    GGML_ASSERT(!fuse_bias || fuse_silu);
+
+    // The bias (when fused) is the non-conv operand of the ADD node.
+    const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr;
+
     // When fusing, write to silu_dst (the node downstream references).
     const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
 
@@ -160,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
+    const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr;
     float * dst_d = (float *) out->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(out->type == GGML_TYPE_F32);
+    if (fuse_bias) {
+        GGML_ASSERT(bias->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_is_contiguous(bias));
+        GGML_ASSERT(ggml_nelements(bias) == nr);
+    }
+
     if (fuse_silu) {
-        ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+        ssm_conv_f32_cuda<true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
                                 out->nb[2], nc, nr, n_t, n_s, stream);
     } else {
-        ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+        ssm_conv_f32_cuda<false>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
                                  out->nb[2], nc, nr, n_t, n_s, stream);
     }
 }
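
For reference, with both fusions active the kernels now compute y = silu(conv(x) + b) per channel in a single pass, instead of launching separate kernels for the downstream ADD and SILU nodes. Below is a minimal host-side sketch of those semantics, useful as a CPU check against the CUDA path; it assumes contiguous row-major layouts, and silu_ref / ssm_conv_ref with their parameter layout are illustrative names, not part of this patch.

#include <cmath>
#include <cstdint>

// SiLU: x * sigmoid(x), the scalar op applied when apply_silu is set.
static float silu_ref(float v) {
    return v / (1.0f + std::exp(-v));
}

// x: [nr][n_t + nc - 1] sliding-window input (nc - 1 columns of history, then n_t new columns)
// w: [nr][nc] per-channel conv weights, b: [nr] optional per-channel bias (may be nullptr)
// y: [nr][n_t] output
static void ssm_conv_ref(const float * x, const float * w, const float * b, float * y,
                         int64_t nr, int64_t nc, int64_t n_t, bool apply_silu) {
    for (int64_t r = 0; r < nr; r++) {
        for (int64_t i = 0; i < n_t; i++) {
            float sumf = 0.0f;
            for (int64_t j = 0; j < nc; j++) {
                sumf += x[r * (n_t + nc - 1) + i + j] * w[r * nc + j];
            }
            if (b != nullptr) {
                sumf += b[r];  // fused bias is applied before the activation
            }
            y[r * n_t + i] = apply_silu ? silu_ref(sumf) : sumf;
        }
    }
}

Note the ordering: the bias must land before the activation so the fusion reproduces silu(conv + b) rather than silu(conv) + b, which is why sumf += b sits ahead of the ggml_cuda_op_silu_single call in both kernels, consistent with the GGML_ASSERT(!fuse_bias || fuse_silu) check on the host side.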