Skip to content

Commit e4611f3

Browse files
committed
Fix rope_norm
1 parent 99b7b15 commit e4611f3

1 file changed

Lines changed: 36 additions & 23 deletions

File tree

ggml/src/ggml-cuda/rope.cu

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@ static __global__ void rope_norm(const T * x,
4545
D * dst,
4646
const int ne0,
4747
const int ne1,
48+
const int ne2,
4849
const int nb01,
4950
const int nb02,
51+
const int nb03,
52+
const int nb11,
53+
const int nb12,
54+
const int nb13,
5055
const int n_dims,
5156
const int32_t * pos,
5257
const float freq_scale,
@@ -65,17 +70,17 @@ static __global__ void rope_norm(const T * x,
6570

6671
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
6772

68-
const int row_x = row_dst % ne1;
69-
const int channel_x = row_dst / ne1;
70-
71-
int idst = row_dst * ne0 + i0;
72-
const int ix = channel_x*nb02 + row_x*nb01 + i0;
73+
const uint32_t i3 = row_dst / (ne1*ne2);
74+
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
75+
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
7376

77+
int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13;
78+
const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03;
7479
// Fusion optimization: ROPE + VIEW + SET_ROWS.
7580
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
7681
if (set_rows_stride != 0) {
77-
idst = row_x * ne0 + i0;
78-
idst += row_indices[channel_x] * set_rows_stride;
82+
idst = i1 * nb11 + i0;
83+
idst += row_indices[i2] * set_rows_stride;
7984
}
8085

8186
const auto & store_coaelsced = [&](float x0, float x1) {
@@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
9297
return;
9398
}
9499

95-
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
100+
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
96101

97102
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
98103

@@ -327,8 +332,13 @@ static void rope_norm_cuda(const T * x,
327332
D * dst,
328333
const int ne0,
329334
const int ne1,
335+
const int ne2,
330336
const int nb01,
331337
const int nb02,
338+
const int nb03,
339+
const int nb11,
340+
const int nb12,
341+
const int nb13,
332342
const int n_dims,
333343
const int nr,
334344
const int32_t * pos,
@@ -343,19 +353,19 @@ static void rope_norm_cuda(const T * x,
343353
cudaStream_t stream) {
344354
GGML_ASSERT(ne0 % 2 == 0);
345355
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
346-
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
356+
const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
347357
const dim3 block_nums(nr, n_blocks_x, 1);
348358

349-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
359+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
350360

351361
if (freq_factors == nullptr) {
352362
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
353-
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
354-
freq_factors, row_indices, set_rows_stride);
363+
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
364+
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
355365
} else {
356366
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
357-
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
358-
freq_factors, row_indices, set_rows_stride);
367+
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
368+
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
359369
}
360370
}
361371

@@ -622,17 +632,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
622632
}
623633
} else {
624634
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
625-
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, nb01, nb02, n_dims,
626-
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
627-
freq_factors, row_indices, set_rows_stride, stream);
635+
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02,
636+
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
637+
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
638+
set_rows_stride, stream);
628639
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
629-
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims,
630-
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
631-
freq_factors, row_indices, set_rows_stride, stream);
640+
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
641+
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
642+
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
643+
set_rows_stride, stream);
632644
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
633-
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr,
634-
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
635-
freq_factors, row_indices, set_rows_stride, stream);
645+
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
646+
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
647+
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
648+
set_rows_stride, stream);
636649
} else {
637650
GGML_ABORT("fatal error");
638651
}

0 commit comments

Comments
 (0)