@@ -43,9 +43,9 @@ static __device__ void rope_yarn(
4343template <bool forward, bool has_ff, typename T, typename D>
4444static __global__ void rope_norm (const T * x,
4545 D * dst,
46- const int ne0 ,
47- const int ne1 ,
48- const int ne2 ,
46+ const int ne00 ,
47+ const int ne01 ,
48+ const int ne02 ,
4949 const int nb01,
5050 const int nb02,
5151 const int nb03,
@@ -64,15 +64,15 @@ static __global__ void rope_norm(const T * x,
6464 const int set_rows_stride) {
6565 const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
6666
67- if (i0 >= ne0 ) {
67+ if (i0 >= ne00 ) {
6868 return ;
6969 }
7070
7171 const int row_dst = blockDim .x *blockIdx .x + threadIdx .x ;
7272
73- const uint32_t i3 = row_dst / (ne1*ne2 );
74- const uint32_t i2 = (row_dst - i3 * ne1 * ne2 ) / ne1 ;
75- const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1 ;
73+ const uint32_t i3 = row_dst / (ne01 * ne02 );
74+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02 ) / ne01 ;
75+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01 ;
7676
7777 int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13;
7878 const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03;
@@ -115,9 +115,9 @@ static __global__ void rope_norm(const T * x,
115115template <bool forward, bool has_ff, typename T, typename D>
116116static __global__ void rope_neox (const T * x,
117117 D * dst,
118- const int ne0 ,
119- const int ne1 ,
120- const int ne2 ,
118+ const int ne00 ,
119+ const int ne01 ,
120+ const int ne02 ,
121121 const int nb01,
122122 const int nb02,
123123 const int nb03,
@@ -136,15 +136,15 @@ static __global__ void rope_neox(const T * x,
136136 const int set_rows_stride) {
137137 const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
138138
139- if (i0 >= ne0 ) {
139+ if (i0 >= ne00 ) {
140140 return ;
141141 }
142142
143143 const int row_dst = blockDim .x *blockIdx .x + threadIdx .x ;
144144
145- const uint32_t i3 = row_dst / (ne1*ne2 );
146- const uint32_t i2 = (row_dst - i3 * ne1 * ne2 ) / ne1 ;
147- const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1 ;
145+ const uint32_t i3 = row_dst / (ne01 * ne02 );
146+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02 ) / ne01 ;
147+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01 ;
148148
149149 int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
150150 const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03;
@@ -182,9 +182,9 @@ static __global__ void rope_neox(const T * x,
182182template <bool forward, bool has_ff, typename T>
183183static __global__ void rope_multi (const T * x,
184184 T * dst,
185- const int ne0 ,
186- const int ne1 ,
187- const int ne2 ,
185+ const int ne00 ,
186+ const int ne01 ,
187+ const int ne02 ,
188188 const int nb01,
189189 const int nb02,
190190 const int nb03,
@@ -203,15 +203,15 @@ static __global__ void rope_multi(const T * x,
203203 const bool is_imrope) {
204204 const int i0 = 2 * (blockDim .y * blockIdx .y + threadIdx .y );
205205
206- if (i0 >= ne0 ) {
206+ if (i0 >= ne00 ) {
207207 return ;
208208 }
209209
210210 const int row_dst = blockDim .x *blockIdx .x + threadIdx .x ;
211211
212- const uint32_t i3 = row_dst / (ne1*ne2 );
213- const uint32_t i2 = (row_dst - i3 * ne1 * ne2 ) / ne1 ;
214- const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1 ;
212+ const uint32_t i3 = row_dst / (ne01 * ne02 );
213+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02 ) / ne01 ;
214+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01 ;
215215
216216 int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
217217 const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03;
@@ -230,23 +230,23 @@ static __global__ void rope_multi(const T * x,
230230 float theta_base = 0.0 ;
231231 if (is_imrope) {
232232 if (sector % 3 == 1 && sector < 3 * sections.v [1 ]) { // h
233- theta_base = pos[i2 + ne2 * 1 ] * powf (theta_scale, i0 / 2 .0f );
233+ theta_base = pos[i2 + ne02 * 1 ] * powf (theta_scale, i0 / 2 .0f );
234234 } else if (sector % 3 == 2 && sector < 3 * sections.v [2 ]) { // w
235- theta_base = pos[i2 + ne2 * 2 ] * powf (theta_scale, i0 / 2 .0f );
235+ theta_base = pos[i2 + ne02 * 2 ] * powf (theta_scale, i0 / 2 .0f );
236236 } else if (sector % 3 == 0 && sector < 3 * sections.v [0 ]) { // t
237237 theta_base = pos[i2] * powf (theta_scale, i0 / 2 .0f );
238238 } else {
239- theta_base = pos[i2 + ne2 * 3 ] * powf (theta_scale, i0 / 2 .0f );
239+ theta_base = pos[i2 + ne02 * 3 ] * powf (theta_scale, i0 / 2 .0f );
240240 }
241241 } else {
242242 if (sector < sections.v [0 ]) {
243243 theta_base = pos[i2] * powf (theta_scale, i0 / 2 .0f );
244244 } else if (sector >= sections.v [0 ] && sector < sec_w) {
245- theta_base = pos[i2 + ne2 * 1 ] * powf (theta_scale, i0 / 2 .0f );
245+ theta_base = pos[i2 + ne02 * 1 ] * powf (theta_scale, i0 / 2 .0f );
246246 } else if (sector >= sec_w && sector < sec_w + sections.v [2 ]) {
247- theta_base = pos[i2 + ne2 * 2 ] * powf (theta_scale, i0 / 2 .0f );
247+ theta_base = pos[i2 + ne02 * 2 ] * powf (theta_scale, i0 / 2 .0f );
248248 } else if (sector >= sec_w + sections.v [2 ]) {
249- theta_base = pos[i2 + ne2 * 3 ] * powf (theta_scale, i0 / 2 .0f );
249+ theta_base = pos[i2 + ne02 * 3 ] * powf (theta_scale, i0 / 2 .0f );
250250 }
251251 }
252252
@@ -267,9 +267,9 @@ static __global__ void rope_multi(const T * x,
267267template <bool forward, bool has_ff, typename T>
268268static __global__ void rope_vision (const T * x,
269269 T * dst,
270- const int ne0 ,
271- const int ne1 ,
272- const int ne2 ,
270+ const int ne00 ,
271+ const int ne01 ,
272+ const int ne02 ,
273273 const int nb01,
274274 const int nb02,
275275 const int nb03,
@@ -287,15 +287,15 @@ static __global__ void rope_vision(const T * x,
287287 const mrope_sections sections) {
288288 const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
289289
290- if (i0 >= ne0 ) {
290+ if (i0 >= ne00 ) {
291291 return ;
292292 }
293293
294294 const int row_dst = blockDim .x *blockIdx .x + threadIdx .x ;
295295
296- const uint32_t i3 = row_dst / (ne1 * ne2 );
297- const uint32_t i2 = (row_dst - i3 * ne1 * ne2 ) / ne1 ;
298- const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1 ;
296+ const uint32_t i3 = row_dst / (ne01 * ne02 );
297+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02 ) / ne01 ;
298+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01 ;
299299
300300 int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
301301 const int ix = i0 / 2 + +i1 * nb01 + i2 * nb02 + i3 * nb03;
@@ -310,7 +310,7 @@ static __global__ void rope_vision(const T * x,
310310 theta_base = pos[i2] * powf (theta_scale, p);
311311 } else if (sector >= sections.v [0 ] && sector < sec_w) {
312312 const int p = sector - sections.v [0 ];
313- theta_base = pos[i2 + ne2 ] * powf (theta_scale, p);
313+ theta_base = pos[i2 + ne02 ] * powf (theta_scale, p);
314314 }
315315
316316 const float freq_factor = has_ff ? freq_factors[i0/2 ] : 1 .0f ;
@@ -330,9 +330,9 @@ static __global__ void rope_vision(const T * x,
330330template <bool forward, typename T, typename D>
331331static void rope_norm_cuda (const T * x,
332332 D * dst,
333- const int ne0 ,
334- const int ne1 ,
335- const int ne2 ,
333+ const int ne00 ,
334+ const int ne01 ,
335+ const int ne02 ,
336336 const int nb01,
337337 const int nb02,
338338 const int nb03,
@@ -351,30 +351,30 @@ static void rope_norm_cuda(const T * x,
351351 const int64_t * row_indices,
352352 const int set_rows_stride,
353353 cudaStream_t stream) {
354- GGML_ASSERT (ne0 % 2 == 0 );
354+ GGML_ASSERT (ne00 % 2 == 0 );
355355 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
356- const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
356+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
357357 const dim3 block_nums (nr, n_blocks_x, 1 );
358358
359359 const float theta_scale = powf (freq_base, -2 .0f / n_dims);
360360
361361 if (freq_factors == nullptr ) {
362362 rope_norm<forward, false ><<<block_nums, block_dims, 0 , stream>>> (
363- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
364- corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
363+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
364+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
365365 } else {
366366 rope_norm<forward, true ><<<block_nums, block_dims, 0 , stream>>> (
367- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
368- corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
367+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
368+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
369369 }
370370}
371371
372372template <bool forward, typename T, typename D>
373373static void rope_neox_cuda (const T * x,
374374 D * dst,
375- const int ne0 ,
376- const int ne1 ,
377- const int ne2 ,
375+ const int ne00 ,
376+ const int ne01 ,
377+ const int ne02 ,
378378 const int nb01,
379379 const int nb02,
380380 const int nb03,
@@ -393,30 +393,30 @@ static void rope_neox_cuda(const T * x,
393393 const int64_t * row_indices,
394394 const int set_rows_stride,
395395 cudaStream_t stream) {
396- GGML_ASSERT (ne0 % 2 == 0 );
396+ GGML_ASSERT (ne00 % 2 == 0 );
397397 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
398- const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
398+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
399399 const dim3 block_nums (nr, n_blocks_x, 1 );
400400
401401 const float theta_scale = powf (freq_base, -2 .0f / n_dims);
402402
403403 if (freq_factors == nullptr ) {
404404 rope_neox<forward, false ><<<block_nums, block_dims, 0 , stream>>> (
405- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
406- corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
405+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
406+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
407407 } else {
408408 rope_neox<forward, true ><<<block_nums, block_dims, 0 , stream>>> (
409- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
410- corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
409+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
410+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
411411 }
412412}
413413
414414template <bool forward, typename T>
415415static void rope_multi_cuda (const T * x,
416416 T * dst,
417- const int ne0 ,
418- const int ne1 ,
419- const int ne2 ,
417+ const int ne00 ,
418+ const int ne01 ,
419+ const int ne02 ,
420420 const int nb01,
421421 const int nb02,
422422 const int nb03,
@@ -435,30 +435,30 @@ static void rope_multi_cuda(const T * x,
435435 const mrope_sections sections,
436436 const bool is_imrope,
437437 cudaStream_t stream) {
438- GGML_ASSERT (ne0 % 2 == 0 );
438+ GGML_ASSERT (ne00 % 2 == 0 );
439439 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
440- const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
440+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
441441 const dim3 block_nums (nr, n_blocks_x, 1 );
442442
443443 const float theta_scale = powf (freq_base, -2 .0f / n_dims);
444444
445445 if (freq_factors == nullptr ) {
446446 rope_multi<forward, false , T><<<block_nums, block_dims, 0 , stream>>> (
447- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
448- corr_dims, theta_scale, freq_factors, sections, is_imrope);
447+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
448+ attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
449449 } else {
450450 rope_multi<forward, true , T><<<block_nums, block_dims, 0 , stream>>> (
451- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
452- corr_dims, theta_scale, freq_factors, sections, is_imrope);
451+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
452+ attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
453453 }
454454}
455455
456456template <bool forward, typename T>
457457static void rope_vision_cuda (const T * x,
458458 T * dst,
459- const int ne0 ,
460- const int ne1 ,
461- const int ne2 ,
459+ const int ne00 ,
460+ const int ne01 ,
461+ const int ne02 ,
462462 const int nb01,
463463 const int nb02,
464464 const int nb03,
@@ -476,9 +476,9 @@ static void rope_vision_cuda(const T * x,
476476 const float * freq_factors,
477477 const mrope_sections sections,
478478 cudaStream_t stream) {
479- GGML_ASSERT (ne0 % 2 == 0 );
479+ GGML_ASSERT (ne00 % 2 == 0 );
480480 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
481- const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
481+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 * CUDA_ROPE_BLOCK_SIZE);
482482 const dim3 block_nums (nr, n_blocks_x, 1 );
483483 // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
484484 // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
@@ -487,12 +487,12 @@ static void rope_vision_cuda(const T * x,
487487
488488 if (freq_factors == nullptr ) {
489489 rope_vision<forward, false , T><<<block_nums, block_dims, 0 , stream>>> (
490- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
491- corr_dims, theta_scale, freq_factors, sections);
490+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
491+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
492492 } else {
493493 rope_vision<forward, true , T><<<block_nums, block_dims, 0 , stream>>> (
494- x, dst, ne0, ne1, ne2 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor ,
495- corr_dims, theta_scale, freq_factors, sections);
494+ x, dst, ne00, ne01, ne02 , nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
495+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
496496 }
497497}
498498
0 commit comments