Skip to content

Commit fb83cc9

Browse files
authored
CUDA: Fix ssm_scan_f32 data-races (ggml-org#24360)
* Add missing __syncthreads() before reusing cub_temp_storage: __syncthreads() is required before the TempStorage smem is allowed to be reused: https://nvidia.github.io/cccl/unstable/cub/api/classcub_1_1BlockLoad.html#_CPPv4I0EN3cub9BlockLoad4LoadEv20RandomAccessIteratorRA14ItemsPerThread_1Ti * Add one more missing __syncthreads(): we could also double-buffer, but the alternative is to simply ensure all threads have read smem* before writing to it again in the next loop iteration * Remove unused smem from ssm_scan_f32
1 parent 039e20a commit fb83cc9

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1)
6767
__shared__ CubTempStorage cub_temp_storage;
6868

6969
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
70+
__syncthreads();
7071
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
7172
#else
7273
const int stride_s0 = src0_nb2 / sizeof(float);
@@ -105,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1)
105106
regs0[n] = state;
106107
}
107108
y_block[i * stride_y + threadIdx.x] = sumf;
109+
__syncthreads();
108110
}
109111

110112
#ifdef USE_CUB
@@ -249,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
249251
GGML_ASSERT(head_dim == 1);
250252
GGML_ASSERT(n_group == 1);
251253
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
252-
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
253254
if (d_state == 16) {
254-
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream);
255+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
255256
switch (n_tok)
256257
{
257258
case 1:

0 commit comments

Comments
 (0)