@@ -45,8 +45,13 @@ static __global__ void rope_norm(const T * x,
4545 D * dst,
4646 const int ne0,
4747 const int ne1,
48+ const int ne2,
4849 const int nb01,
4950 const int nb02,
51+ const int nb03,
52+ const int nb11,
53+ const int nb12,
54+ const int nb13,
5055 const int n_dims,
5156 const int32_t * pos,
5257 const float freq_scale,
@@ -65,17 +70,17 @@ static __global__ void rope_norm(const T * x,
6570
6671 const int row_dst = blockDim .x *blockIdx .x + threadIdx .x ;
6772
68- const int row_x = row_dst % ne1;
69- const int channel_x = row_dst / ne1;
70-
71- int idst = row_dst * ne0 + i0;
72- const int ix = channel_x*nb02 + row_x*nb01 + i0;
73+ const uint32_t i3 = row_dst / (ne1*ne2);
74+ const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
75+ const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
7376
77+ int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13;
78+ const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03;
7479 // Fusion optimization: ROPE + VIEW + SET_ROWS.
7580 // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
7681 if (set_rows_stride != 0 ) {
77- idst = row_x * ne0 + i0;
78- idst += row_indices[channel_x ] * set_rows_stride;
82+ idst = i1 * nb11 + i0;
83+ idst += row_indices[i2 ] * set_rows_stride;
7984 }
8085
8186 const auto & store_coaelsced = [&](float x0, float x1) {
@@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
9297 return ;
9398 }
9499
95- const float theta_base = pos[channel_x ]*powf (theta_scale, i0/2 .0f );
100+ const float theta_base = pos[i2 ]*powf (theta_scale, i0/2 .0f );
96101
97102 const float freq_factor = has_ff ? freq_factors[i0/2 ] : 1 .0f ;
98103
@@ -327,8 +332,13 @@ static void rope_norm_cuda(const T * x,
327332 D * dst,
328333 const int ne0,
329334 const int ne1,
335+ const int ne2,
330336 const int nb01,
331337 const int nb02,
338+ const int nb03,
339+ const int nb11,
340+ const int nb12,
341+ const int nb13,
332342 const int n_dims,
333343 const int nr,
334344 const int32_t * pos,
@@ -343,19 +353,19 @@ static void rope_norm_cuda(const T * x,
343353 cudaStream_t stream) {
344354 GGML_ASSERT (ne0 % 2 == 0 );
345355 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
346- const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
356+ const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
347357 const dim3 block_nums (nr, n_blocks_x, 1 );
348358
349- const float theta_scale = powf (freq_base, -2 .0f / n_dims);
359+ const float theta_scale = powf (freq_base, -2 .0f / n_dims);
350360
351361 if (freq_factors == nullptr ) {
352362 rope_norm<forward, false ><<<block_nums, block_dims, 0 , stream>>> (
353- x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale ,
354- freq_factors, row_indices, set_rows_stride);
363+ x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
364+ corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
355365 } else {
356366 rope_norm<forward, true ><<<block_nums, block_dims, 0 , stream>>> (
357- x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale ,
358- freq_factors, row_indices, set_rows_stride);
367+ x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
368+ corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
359369 }
360370}
361371
@@ -622,17 +632,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
622632 }
623633 } else {
624634 if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
625- rope_norm_cuda<forward, float , float >((const float *) src0_d, (float *) dst_d, ne00, ne01, nb01, nb02, n_dims,
626- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
627- freq_factors, row_indices, set_rows_stride, stream);
635+ rope_norm_cuda<forward, float , float >((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02,
636+ nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
637+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
638+ set_rows_stride, stream);
628639 } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
629- rope_norm_cuda<forward, float , half>((const float *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims,
630- nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
631- freq_factors, row_indices, set_rows_stride, stream);
640+ rope_norm_cuda<forward, float , half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
641+ nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
642+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
643+ set_rows_stride, stream);
632644 } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
633- rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr,
634- pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
635- freq_factors, row_indices, set_rows_stride, stream);
645+ rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
646+ nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
647+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
648+ set_rows_stride, stream);
636649 } else {
637650 GGML_ABORT (" fatal error" );
638651 }
0 commit comments