@@ -17,7 +17,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
 #pragma clang diagnostic ignored "-Wpass-failed"
 #endif // __clang__
 template <int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
+__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 2)
 static __global__ void flash_attn_ext_vec(
         const char * __restrict__ Q,
         const char * __restrict__ K,
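For context on the one-line change above: the second argument to `__launch_bounds__` is `minBlocksPerMultiprocessor`, so raising it from 1 to 2 asks the compiler to cap per-thread register usage until two blocks can be resident on an SM at once, trading possible spills for occupancy. A minimal sketch (the thread count is an assumed placeholder, not this kernel's value):

```cuda
// Sketch only: (maxThreadsPerBlock, minBlocksPerMultiprocessor).
// With the second argument = 2, the compiler must keep per-thread register
// use low enough that 2 * 256 threads can hold their registers on one SM.
__launch_bounds__(256 /* assumed nthreads */, 2)
static __global__ void occupancy_demo_kernel() { /* ... */ }
```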
@@ -78,8 +78,18 @@ static __global__ void flash_attn_ext_vec(
     // Turbo3 uses the float Q path (like f16/bf16), not the q8_1 integer path
     constexpr bool K_is_unquantized = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 || type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
     constexpr bool V_is_unquantized = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 || type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
-    constexpr int nthreads_KQ = K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q;
-    constexpr int nthreads_V  = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0) ? nthreads_V_q : 128 / cpy_nb) : nthreads_V_q;
+    constexpr bool K_is_turbo = (type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
+    // The turbo KQ dot product does byte extraction + centroid lookup + scalar multiply, not vectorized f16 loads.
+    // nthreads_KQ = 1: each thread computes a full KQ product alone, which eliminates the warp_reduce_sum
+    // shuffle and halves the KQ loop iterations. Each thread holds the full Q vector in registers.
+    constexpr int nthreads_KQ = K_is_turbo ? 1 : (K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q);
+    constexpr bool V_is_turbo = (type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
+    // Turbo V dequantization is scalar (byte extract + LUT), not vectorized loads.
+    // Halving nthreads_V doubles V_cols_per_iter (2 V rows processed per loop iteration),
+    // reducing loop overhead and improving ILP in the V aggregation phase.
+    // Cutting nthreads_V to an eighth for turbo takes V_cols_per_iter from 4 to 8, processing 8 V positions
+    // per outer loop iteration: the outer loop count halves again, with more ILP from concurrent V rows.
+    constexpr int nthreads_V = V_is_unquantized ? (V_is_turbo ? (nthreads_V_q / 8 < 1 ? 1 : nthreads_V_q / 8) : 128 / cpy_nb) : nthreads_V_q;
 
     static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
     static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
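To make the V_cols_per_iter arithmetic in these comments concrete, a compile-time sketch (the warp size and nthreads_V_q values are illustrative assumptions, not taken from this file):

```cuda
// Sketch: fewer threads per V row means more V positions per warp iteration.
constexpr int kWarpSize = 32;
constexpr int v_cols_per_iter(int nthreads_V) { return kWarpSize / nthreads_V; }
static_assert(v_cols_per_iter(8) == 4,  "assumed baseline quantized path");
static_assert(v_cols_per_iter(4) == 8,  "halved nthreads_V: 8 positions/iter");
static_assert(v_cols_per_iter(1) == 32, "turbo floor of nthreads_V = 1");
```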
@@ -123,6 +133,15 @@ static __global__ void flash_attn_ext_vec(
     __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
 #endif // V_DOT2_F32_F16_AVAILABLE
 
+    // Shared-memory LUT for turbo KQ scoring: precompute Q[d] * centroid[c] once,
+    // then the hot loop does turbo_lut[d][idx] (a shared-memory read, no multiply).
+    // turbo4 is excluded: 16 centroids × D exceeds the shared-memory budget.
+    // Stride = n_centroids + 1 to avoid bank conflicts.
+    constexpr int n_centroids_lut = (D <= 256 && type_K == GGML_TYPE_TURBO3_0) ? 8 :
+                                    (D <= 256 && type_K == GGML_TYPE_TURBO2_0) ? 4 : 0;
+    constexpr int lut_stride = n_centroids_lut > 0 ? n_centroids_lut + 1 : 1;
+    __shared__ half turbo_lut[n_centroids_lut > 0 ? D : 1][lut_stride];
+
     // Sparse V: skip V dequant for positions with negligible attention weights.
     // At long context, most V positions contribute < 1e-6 to the output, so skipping
     // their dequant saves significant compute (especially for quantized V types).
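The `+1` stride is the standard shared-memory padding trick; a generic illustration outside this kernel (float used for clarity, the principle carries over to the `half` LUT above):

```cuda
// Illustration only: with 32 four-byte banks, element [r][c] of a [32][32]
// float tile maps column c of every row to the same bank, so a column-wise
// walk serializes 32 ways. Padding the row to 33 shifts each row by one bank.
// Assumes a single-warp launch of 32 threads.
__global__ void bank_pad_demo(float * out) {
    __shared__ float tile[32][33];            // 33 = 32 + 1 pad column
    tile[threadIdx.x][0] = threadIdx.x;       // column write, now conflict-free
    __syncwarp();
    out[threadIdx.x] = tile[0][threadIdx.x];  // row read, also conflict-free
}
```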
@@ -247,6 +266,20 @@ static __global__ void flash_attn_ext_vec(
 #endif // V_DOT2_F32_F16_AVAILABLE
     }
 
+    // Build the shared-memory LUT: turbo_lut[d][c] = half(Q[d] * scale * centroid[c])
+    if constexpr (n_centroids_lut > 0 && ncols == 1) {
+        const float * centroids_ptr = (type_K == GGML_TYPE_TURBO3_0) ? TURBO_CENTROIDS_3BIT :
+                                                                       TURBO_CENTROIDS_2BIT;
+        const float * Q_f = (const float *)(Q + 0*nb01);
+        for (int d = tid; d < D; d += nthreads) {
+            const float q_val = Q_f[d] * scale;
+            for (int c = 0; c < n_centroids_lut; c++) {
+                turbo_lut[d][c] = __float2half(q_val * centroids_ptr[c]);
+            }
+        }
+        __syncthreads();
+    }
+
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K += blockIdx.y*nthreads*nb11;
     V += blockIdx.y*nthreads*nb21;
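For scale, the footprint follows directly from the declaration above: at the D ≤ 256 cap the turbo3 LUT is 256 × (8+1) × sizeof(half) = 4608 bytes and turbo2 is 256 × 5 × 2 = 2560 bytes, while a turbo4 table would need 256 × 17 × 2 = 8704 bytes, which is the budget concern the earlier comment cites.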
@@ -270,8 +303,50 @@ static __global__ void flash_attn_ext_vec(
 
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
-            sum = warp_reduce_sum<nthreads_KQ>(sum);
+            float sum;
+            if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO3_0) {
+                // LUT scoring: 8 elements per iteration (2 qs bytes + 1 signs byte)
+                const block_turbo3_0 * K_turbo = (const block_turbo3_0 *)(K + i_KQ*nb11);
+                sum = 0.0f;
+                for (int d0 = 0; d0 < D; d0 += 8) {
+                    const int ib = d0 / QK_TURBO3;
+                    const int jj = d0 % QK_TURBO3;
+                    const float norm = __half2float(K_turbo[ib].norm);
+                    const uint8_t qs0 = K_turbo[ib].qs[jj / 4];
+                    const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1];
+                    const uint8_t sgn = K_turbo[ib].signs[jj / 8];
+                    sum += (__half2float(turbo_lut[d0    ][((qs0 >> 0) & 3) | (((sgn >> 0) & 1) << 2)]) +
+                            __half2float(turbo_lut[d0 + 1][((qs0 >> 2) & 3) | (((sgn >> 1) & 1) << 2)]) +
+                            __half2float(turbo_lut[d0 + 2][((qs0 >> 4) & 3) | (((sgn >> 2) & 1) << 2)]) +
+                            __half2float(turbo_lut[d0 + 3][((qs0 >> 6) & 3) | (((sgn >> 3) & 1) << 2)]) +
+                            __half2float(turbo_lut[d0 + 4][((qs1 >> 0) & 3) | (((sgn >> 4) & 1) << 2)]) +
+                            __half2float(turbo_lut[d0 + 5][((qs1 >> 2) & 3) | (((sgn >> 5) & 1) << 2)]) +
+                            __half2float(turbo_lut[d0 + 6][((qs1 >> 4) & 3) | (((sgn >> 6) & 1) << 2)]) +
+                            __half2float(turbo_lut[d0 + 7][((qs1 >> 6) & 3) | (((sgn >> 7) & 1) << 2)])) * norm;
+                }
+            } else if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO2_0) {
+                // LUT scoring for turbo2: 8 elements per iteration (2 qs bytes, no signs)
+                const block_turbo2_0 * K_turbo = (const block_turbo2_0 *)(K + i_KQ*nb11);
+                sum = 0.0f;
+                for (int d0 = 0; d0 < D; d0 += 8) {
+                    const int ib = d0 / QK_TURBO2;
+                    const int jj = d0 % QK_TURBO2;
+                    const float norm = __half2float(K_turbo[ib].norm);
+                    const uint8_t qs0 = K_turbo[ib].qs[jj / 4];
+                    const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1];
+                    sum += (__half2float(turbo_lut[d0    ][(qs0 >> 0) & 3]) +
+                            __half2float(turbo_lut[d0 + 1][(qs0 >> 2) & 3]) +
+                            __half2float(turbo_lut[d0 + 2][(qs0 >> 4) & 3]) +
+                            __half2float(turbo_lut[d0 + 3][(qs0 >> 6) & 3]) +
+                            __half2float(turbo_lut[d0 + 4][(qs1 >> 0) & 3]) +
+                            __half2float(turbo_lut[d0 + 5][(qs1 >> 2) & 3]) +
+                            __half2float(turbo_lut[d0 + 6][(qs1 >> 4) & 3]) +
+                            __half2float(turbo_lut[d0 + 7][(qs1 >> 6) & 3])) * norm;
+                }
+            } else {
+                sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum<nthreads_KQ>(sum);
+            }
 
             if (use_logit_softcap) {
                 sum = logit_softcap*tanhf(sum);
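Each turbo3 index is assembled from a 2-bit magnitude code plus a separate sign bit; a standalone restatement of the unpacking used above (the block layout is inferred from this diff, not taken from a header):

```cuda
#include <cstdint>

// Mirrors the LUT-loop indexing: 4 elements per qs byte, 1 sign bit each.
// idx is a 3-bit centroid selector: bits [1:0] = magnitude code, bit [2] = sign.
static void decode4_turbo3(uint8_t qs, uint8_t signs, uint8_t idx[4]) {
    for (int e = 0; e < 4; ++e) {
        const uint8_t code = (qs >> (2*e)) & 0x3;  // 2-bit magnitude code
        const uint8_t sign = (signs >> e) & 0x1;   // one sign bit per element
        idx[e] = code | (sign << 2);               // centroid index in [0, 8)
    }
}
```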
@@ -295,12 +370,12 @@ static __global__ void flash_attn_ext_vec(
             for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
                 KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
             }
-            const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
+            const float KQ_max_scale = __expf(KQ_max[j] - KQ_max_new[j]);
             KQ_max[j] = KQ_max_new[j];
 
-            KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
+            KQ_reg[j] = __expf(KQ_reg[j] - KQ_max[j]);
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
-            KQ[j*nthreads + tid] = KQ_reg[j];
+            if constexpr (!V_is_turbo) { KQ[j*nthreads + tid] = KQ_reg[j]; }
 
 #ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
@@ -318,7 +393,7 @@ static __global__ void flash_attn_ext_vec(
         }
 
 #ifndef GGML_USE_HIP
-        __syncwarp();
+        if constexpr (!V_is_turbo) { __syncwarp(); }
 #endif // GGML_USE_HIP
 
 #pragma unroll
@@ -329,23 +404,28 @@ static __global__ void flash_attn_ext_vec(
         half2 KQ_k[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
+            if constexpr (V_is_turbo) {
+                const float kq_val = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V));
+                KQ_k[j] = make_half2(__float2half(kq_val), __float2half(kq_val));
+            } else {
+                KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
+            }
         }
 
-        // Sparse V: skip V dequant if all attention weights for this position are negligible
-        // Disabled — per-lane branching causes warp divergence that costs more than the
-        // skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
-        // TODO: revisit with warp-level ballot skip.
-#if 0
-        {
+        // Sparse V: skip V dequant if all attention weights for this position are negligible.
+        // For turbo types, the check is compiled out: at typical decode context lengths
+        // (< ~4K tokens) with threshold 1e-6, no positions are ever skipped, so the
+        // per-position branch is pure overhead (misprediction + comparison cost). This
+        // also dodges the warp-divergence regression on turbo paths that motivated the
+        // April 24 revert (commit f2dc968).
+        if constexpr (!V_is_turbo) {
             bool dominated = true;
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 if (__hgt(__low2half(KQ_k[j]), sparse_v_threshold_h)) { dominated = false; break; }
             }
             if (dominated) { continue; }
         }
-#endif
 
 #pragma unroll
         for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
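The turbo branch above swaps a shared-memory round trip for a register shuffle; the primitive in isolation:

```cuda
// Every lane of the warp receives the value held in lane `src_lane`, without
// touching shared memory. This is how the turbo path fetches KQ_reg[j] from
// the lane that scored KV position k.
__device__ float broadcast_from_lane(float v, int src_lane) {
    return __shfl_sync(0xFFFFFFFF, v, src_lane, 32);
}
```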
@@ -374,35 +454,142 @@ static __global__ void flash_attn_ext_vec(
         float KQ_k[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            KQ_k[j] = KQ[j*nthreads + k];
+            if constexpr (V_is_turbo) {
+                KQ_k[j] = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V));
+            } else {
+                KQ_k[j] = KQ[j*nthreads + k];
+            }
         }
 
-        // Sparse V: skip V dequant if all attention weights for this position are negligible
-        // Disabled — per-lane branching causes warp divergence that costs more than the
-        // skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
-        // TODO: revisit with warp-level ballot skip.
-#if 0
-        {
+        // Sparse V: skip V dequant if all attention weights for this position are negligible.
+        // Compiled out for turbo types: see the half2 path comment above.
+        if constexpr (!V_is_turbo) {
             bool dominated = true;
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 if (KQ_k[j] >= sparse_v_threshold_f) { dominated = false; break; }
             }
             if (dominated) { continue; }
         }
-#endif
+
+        // Turbo V path: precompute scaled centroids once per block to eliminate the
+        // per-element norm multiply. centroid[idx]*norm is computed 8/4/16 times
+        // (once per centroid) instead of D times (once per element).
+        if constexpr (type_V == GGML_TYPE_TURBO3_0) {
+            const block_turbo3_0 * vb = (const block_turbo3_0 *)(V + k*nb21);
+            int prev_ib = -1;
+            float sc[8];
 
 #pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
-            float2 tmp[V_rows_per_thread/2];
-            dequantize_V(V + k*nb21, tmp,
-                2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
+                const int ib = i0 / QK_TURBO3;
+                const int j0 = i0 % QK_TURBO3;
+
+                if (ib != prev_ib) {
+                    prev_ib = ib;
+                    const float norm = __half2float(vb[ib].norm);
 #pragma unroll
-            for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+                    for (int c = 0; c < 8; ++c) { sc[c] = TURBO_CENTROIDS_3BIT[c] * norm; }
+                }
+
+                const uint8_t qs_byte  = vb[ib].qs[j0 / 4];
+                const uint8_t sgn_byte = vb[ib].signs[j0 / 8];
+                const int shift_s = j0 % 8;
+
+                const uint8_t idx0 = ((qs_byte >> 0) & 0x3) | (((sgn_byte >> (shift_s+0)) & 0x1) << 2);
+                const uint8_t idx1 = ((qs_byte >> 2) & 0x3) | (((sgn_byte >> (shift_s+1)) & 0x1) << 2);
+                const uint8_t idx2 = ((qs_byte >> 4) & 0x3) | (((sgn_byte >> (shift_s+2)) & 0x1) << 2);
+                const uint8_t idx3 = ((qs_byte >> 6) & 0x3) | (((sgn_byte >> (shift_s+3)) & 0x1) << 2);
+
 #pragma unroll
                 for (int j = 0; j < ncols; ++j) {
-                    VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
-                    VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
+                }
+            }
+        } else if constexpr (type_V == GGML_TYPE_TURBO2_0) {
+            const block_turbo2_0 * vb = (const block_turbo2_0 *)(V + k*nb21);
+            int prev_ib = -1;
+            float sc[4];
+
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
+                const int ib = i0 / QK_TURBO2;
+                const int j0 = i0 % QK_TURBO2;
+
+                if (ib != prev_ib) {
+                    prev_ib = ib;
+                    const float norm = __half2float(vb[ib].norm);
+#pragma unroll
+                    for (int c = 0; c < 4; ++c) { sc[c] = TURBO_CENTROIDS_2BIT[c] * norm; }
+                }
+
+                const uint8_t qs_byte = vb[ib].qs[j0 / 4];
+
+                const uint8_t idx0 = (qs_byte >> 0) & 0x3;
+                const uint8_t idx1 = (qs_byte >> 2) & 0x3;
+                const uint8_t idx2 = (qs_byte >> 4) & 0x3;
+                const uint8_t idx3 = (qs_byte >> 6) & 0x3;
+
+#pragma unroll
+                for (int j = 0; j < ncols; ++j) {
+                    VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
+                }
+            }
+        } else if constexpr (type_V == GGML_TYPE_TURBO4_0) {
+            const block_turbo4_0 * vb = (const block_turbo4_0 *)(V + k*nb21);
+            int prev_ib = -1;
+            float sc[16];
+
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
+                const int ib = i0 / QK_TURBO4;
+                const int j0 = i0 % QK_TURBO4;
+
+                if (ib != prev_ib) {
+                    prev_ib = ib;
+                    const float norm = __half2float(vb[ib].norm);
+#pragma unroll
+                    for (int c = 0; c < 16; ++c) { sc[c] = TURBO_CENTROIDS_4BIT[c] * norm; }
+                }
+
+                const uint8_t qs_byte0 = vb[ib].qs[j0 / 2];
+                const uint8_t qs_byte1 = vb[ib].qs[j0 / 2 + 1];
+
+                const uint8_t idx0 = (qs_byte0 >> 0) & 0xF;
+                const uint8_t idx1 = (qs_byte0 >> 4) & 0xF;
+                const uint8_t idx2 = (qs_byte1 >> 0) & 0xF;
+                const uint8_t idx3 = (qs_byte1 >> 4) & 0xF;
+
+#pragma unroll
+                for (int j = 0; j < ncols; ++j) {
+                    VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
+                }
+            }
+        } else {
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                float2 tmp[V_rows_per_thread/2];
+                dequantize_V(V + k*nb21, tmp,
+                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+                    for (int j = 0; j < ncols; ++j) {
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
+                    }
                 }
             }
         }
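The prev_ib cache amortizes the centroid rescale across each block; a host-side restatement of the turbo2 case (block size, struct layout, and centroid values are placeholders inferred from this diff, not from a header):

```cuda
#include <cstdint>

// Hypothetical CPU analogue: centroid*norm is computed 4 times per block
// (once per centroid) instead of once per element.
constexpr int QK_TURBO2_DEMO = 32;                                    // assumed block size
static const float kCentroids2Bit[4] = {-1.0f, -0.33f, 0.33f, 1.0f}; // placeholder values

struct block_turbo2_demo { float norm; uint8_t qs[QK_TURBO2_DEMO / 4]; }; // 2 bits/element

static void dequant_block_turbo2(const block_turbo2_demo & b, float out[QK_TURBO2_DEMO]) {
    float sc[4];
    for (int c = 0; c < 4; ++c) { sc[c] = kCentroids2Bit[c] * b.norm; } // once per block
    for (int j = 0; j < QK_TURBO2_DEMO; ++j) {
        out[j] = sc[(b.qs[j / 4] >> (2 * (j % 4))) & 0x3];             // LUT read only
    }
}
```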
@@ -422,10 +609,10 @@ static __global__ void flash_attn_ext_vec(
         }
 
         const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
-        const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
+        const float KQ_max_scale = __expf(KQ_max[j] - kqmax_new_j);
         KQ_max[j] = kqmax_new_j;
 
-        KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
+        KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? __expf(sink - KQ_max[j]) : 0.0f);
 
 #ifdef V_DOT2_F32_F16_AVAILABLE
         const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
@@ -471,7 +658,7 @@ static __global__ void flash_attn_ext_vec(
 
         float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
         kqmax_new = warp_reduce_max(kqmax_new);
-        const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
+        const float kqmax_scale = __expf(KQ_max[j_VKQ] - kqmax_new);
         KQ_max[j_VKQ] = kqmax_new;
 
 #ifdef V_DOT2_F32_F16_AVAILABLE
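On the recurring expf → __expf swap in this commit: `__expf` is CUDA's fast-math intrinsic, which runs on the special function unit with a larger ULP error than `expf`. That trade is benign here because every result feeds an online softmax that is later renormalized by `KQ_sum`. In isolation:

```cuda
// __expf trades accuracy for throughput; the running-max subtraction keeps the
// argument <= 0, so the result stays in (0, 1] and cannot overflow.
__device__ float softmax_weight(float kq, float kq_max) {
    return __expf(kq - kq_max);
}
```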