Skip to content

Commit 0fa61d7

Browse files
authored
Merge pull request #115 from TheTom/pr53-auto-asymmetric
Cherry-pick signalnine PR #53: auto-asymmetric GQA + turbo VEC FA opts
2 parents cde3e1a + 157f27f commit 0fa61d7

2 files changed

Lines changed: 248 additions & 35 deletions

File tree

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 222 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
1717
#pragma clang diagnostic ignored "-Wpass-failed"
1818
#endif // __clang__
1919
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
20-
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
20+
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 2)
2121
static __global__ void flash_attn_ext_vec(
2222
const char * __restrict__ Q,
2323
const char * __restrict__ K,
@@ -78,8 +78,18 @@ static __global__ void flash_attn_ext_vec(
7878
// Turbo3 uses the float Q path (like f16/bf16), not q8_1 integer path
7979
constexpr bool K_is_unquantized = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 || type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
8080
constexpr bool V_is_unquantized = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 || type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
81-
constexpr int nthreads_KQ = K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q;
82-
constexpr int nthreads_V = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0) ? nthreads_V_q : 128 / cpy_nb) : nthreads_V_q;
81+
constexpr bool K_is_turbo = (type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
82+
// Turbo KQ dot does byte extraction + centroid lookup + scalar mul, not vectorized f16 loads.
83+
// nthreads_KQ=1: each thread computes a full KQ product alone — eliminates warp_reduce_sum
84+
// shuffle and halves KQ loop iterations. Each thread holds full Q vector in registers.
85+
constexpr int nthreads_KQ = K_is_turbo ? 1 : (K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q);
86+
constexpr bool V_is_turbo = (type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
87+
// Turbo V dequant is scalar (byte extract + LUT), not vectorized loads.
88+
// For turbo, nthreads_V is divided by 8 (clamped to >= 1), raising V_cols_per_iter
89+
// from 4 to 8 so each outer-loop iteration processes 8 V positions per warp,
90+
// cutting the outer-loop count and improving ILP from concurrently processed
91+
// V rows in the V aggregation phase.
92+
constexpr int nthreads_V = V_is_unquantized ? (V_is_turbo ? (nthreads_V_q / 8 < 1 ? 1 : nthreads_V_q / 8) : 128 / cpy_nb) : nthreads_V_q;
8393

8494
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
8595
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
@@ -123,6 +133,15 @@ static __global__ void flash_attn_ext_vec(
123133
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
124134
#endif // V_DOT2_F32_F16_AVAILABLE
125135

136+
// Shared-memory LUT for turbo KQ scoring: precompute Q[d] * centroid[c] once,
137+
// then the hot loop does turbo_lut[d][idx] (shmem read, no multiply).
138+
// turbo4 excluded: 16 centroids × D exceeds shmem budget.
139+
// Stride = n_centroids+1 to avoid bank conflicts.
140+
constexpr int n_centroids_lut = (D <= 256 && type_K == GGML_TYPE_TURBO3_0) ? 8 :
141+
(D <= 256 && type_K == GGML_TYPE_TURBO2_0) ? 4 : 0;
142+
constexpr int lut_stride = n_centroids_lut > 0 ? n_centroids_lut + 1 : 1;
143+
__shared__ half turbo_lut[n_centroids_lut > 0 ? D : 1][lut_stride];
144+
126145
// Sparse V: skip V dequant for positions with negligible attention weights.
127146
// At long context, most V positions contribute < 1e-6 to the output — skipping
128147
// their dequant saves significant compute (especially for quantized V types).
@@ -247,6 +266,20 @@ static __global__ void flash_attn_ext_vec(
247266
#endif // V_DOT2_F32_F16_AVAILABLE
248267
}
249268

269+
// Build shared-memory LUT: turbo_lut[d][c] = half(Q[d] * scale * centroid[c])
270+
if constexpr (n_centroids_lut > 0 && ncols == 1) {
271+
const float * centroids_ptr = (type_K == GGML_TYPE_TURBO3_0) ? TURBO_CENTROIDS_3BIT :
272+
TURBO_CENTROIDS_2BIT;
273+
const float * Q_f = (const float *)(Q + 0*nb01);
274+
for (int d = tid; d < D; d += nthreads) {
275+
const float q_val = Q_f[d] * scale;
276+
for (int c = 0; c < n_centroids_lut; c++) {
277+
turbo_lut[d][c] = __float2half(q_val * centroids_ptr[c]);
278+
}
279+
}
280+
__syncthreads();
281+
}
282+
250283
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
251284
K += blockIdx.y*nthreads * nb11;
252285
V += blockIdx.y*nthreads * nb21;
@@ -270,8 +303,50 @@ static __global__ void flash_attn_ext_vec(
270303

271304
#pragma unroll
272305
for (int j = 0; j < ncols; ++j) {
273-
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
274-
sum = warp_reduce_sum<nthreads_KQ>(sum);
306+
float sum;
307+
if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO3_0) {
308+
// LUT scoring: 8 elements per iteration (2 qs bytes + 1 signs byte)
309+
const block_turbo3_0 * K_turbo = (const block_turbo3_0 *)(K + i_KQ*nb11);
310+
sum = 0.0f;
311+
for (int d0 = 0; d0 < D; d0 += 8) {
312+
const int ib = d0 / QK_TURBO3;
313+
const int jj = d0 % QK_TURBO3;
314+
const float norm = __half2float(K_turbo[ib].norm);
315+
const uint8_t qs0 = K_turbo[ib].qs[jj / 4];
316+
const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1];
317+
const uint8_t sgn = K_turbo[ib].signs[jj / 8];
318+
sum += (__half2float(turbo_lut[d0 ][((qs0>>0)&3)|((sgn>>0&1)<<2)]) +
319+
__half2float(turbo_lut[d0+1][((qs0>>2)&3)|((sgn>>1&1)<<2)]) +
320+
__half2float(turbo_lut[d0+2][((qs0>>4)&3)|((sgn>>2&1)<<2)]) +
321+
__half2float(turbo_lut[d0+3][((qs0>>6)&3)|((sgn>>3&1)<<2)]) +
322+
__half2float(turbo_lut[d0+4][((qs1>>0)&3)|((sgn>>4&1)<<2)]) +
323+
__half2float(turbo_lut[d0+5][((qs1>>2)&3)|((sgn>>5&1)<<2)]) +
324+
__half2float(turbo_lut[d0+6][((qs1>>4)&3)|((sgn>>6&1)<<2)]) +
325+
__half2float(turbo_lut[d0+7][((qs1>>6)&3)|((sgn>>7&1)<<2)])) * norm;
326+
}
327+
} else if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO2_0) {
328+
// LUT scoring for turbo2: 8 elements per iteration (2 qs bytes, no signs)
329+
const block_turbo2_0 * K_turbo = (const block_turbo2_0 *)(K + i_KQ*nb11);
330+
sum = 0.0f;
331+
for (int d0 = 0; d0 < D; d0 += 8) {
332+
const int ib = d0 / QK_TURBO2;
333+
const int jj = d0 % QK_TURBO2;
334+
const float norm = __half2float(K_turbo[ib].norm);
335+
const uint8_t qs0 = K_turbo[ib].qs[jj / 4];
336+
const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1];
337+
sum += (__half2float(turbo_lut[d0 ][(qs0>>0)&3]) +
338+
__half2float(turbo_lut[d0+1][(qs0>>2)&3]) +
339+
__half2float(turbo_lut[d0+2][(qs0>>4)&3]) +
340+
__half2float(turbo_lut[d0+3][(qs0>>6)&3]) +
341+
__half2float(turbo_lut[d0+4][(qs1>>0)&3]) +
342+
__half2float(turbo_lut[d0+5][(qs1>>2)&3]) +
343+
__half2float(turbo_lut[d0+6][(qs1>>4)&3]) +
344+
__half2float(turbo_lut[d0+7][(qs1>>6)&3])) * norm;
345+
}
346+
} else {
347+
sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
348+
sum = warp_reduce_sum<nthreads_KQ>(sum);
349+
}
275350

276351
if (use_logit_softcap) {
277352
sum = logit_softcap*tanhf(sum);
@@ -295,12 +370,12 @@ static __global__ void flash_attn_ext_vec(
295370
for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
296371
KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
297372
}
298-
const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
373+
const float KQ_max_scale = __expf(KQ_max[j] - KQ_max_new[j]);
299374
KQ_max[j] = KQ_max_new[j];
300375

301-
KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
376+
KQ_reg[j] = __expf(KQ_reg[j] - KQ_max[j]);
302377
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
303-
KQ[j*nthreads + tid] = KQ_reg[j];
378+
if constexpr (!V_is_turbo) { KQ[j*nthreads + tid] = KQ_reg[j]; }
304379

305380
#ifdef V_DOT2_F32_F16_AVAILABLE
306381
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
@@ -318,7 +393,7 @@ static __global__ void flash_attn_ext_vec(
318393
}
319394

320395
#ifndef GGML_USE_HIP
321-
__syncwarp();
396+
if constexpr (!V_is_turbo) { __syncwarp(); }
322397
#endif // GGML_USE_HIP
323398

324399
#pragma unroll
@@ -329,23 +404,28 @@ static __global__ void flash_attn_ext_vec(
329404
half2 KQ_k[ncols];
330405
#pragma unroll
331406
for (int j = 0; j < ncols; ++j) {
332-
KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
407+
if constexpr (V_is_turbo) {
408+
const float kq_val = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V));
409+
KQ_k[j] = make_half2(__float2half(kq_val), __float2half(kq_val));
410+
} else {
411+
KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
412+
}
333413
}
334414

335-
// Sparse V: skip V dequant if all attention weights for this position are negligible
336-
// Disabled — per-lane branching causes warp divergence that costs more than the
337-
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
338-
// TODO: revisit with warp-level ballot skip.
339-
#if 0
340-
{
415+
// Sparse V: skip V dequant if all attention weights for this position are negligible.
416+
// For turbo types, the check is compiled out: at typical decode context lengths
417+
// (< ~4K tokens) with threshold 1e-6, no positions are ever skipped, so the
418+
// per-position branch is pure overhead (misprediction + comparison cost). This
419+
// also dodges the warp-divergence regression on turbo paths that motivated the
420+
// April 24 revert (commit f2dc968).
421+
if constexpr (!V_is_turbo) {
341422
bool dominated = true;
342423
#pragma unroll
343424
for (int j = 0; j < ncols; ++j) {
344425
if (__hgt(__low2half(KQ_k[j]), sparse_v_threshold_h)) { dominated = false; break; }
345426
}
346427
if (dominated) { continue; }
347428
}
348-
#endif
349429

350430
#pragma unroll
351431
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
@@ -374,35 +454,142 @@ static __global__ void flash_attn_ext_vec(
374454
float KQ_k[ncols];
375455
#pragma unroll
376456
for (int j = 0; j < ncols; ++j) {
377-
KQ_k[j] = KQ[j*nthreads + k];
457+
if constexpr (V_is_turbo) {
458+
KQ_k[j] = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V));
459+
} else {
460+
KQ_k[j] = KQ[j*nthreads + k];
461+
}
378462
}
379463

380-
// Sparse V: skip V dequant if all attention weights for this position are negligible
381-
// Disabled — per-lane branching causes warp divergence that costs more than the
382-
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
383-
// TODO: revisit with warp-level ballot skip.
384-
#if 0
385-
{
464+
// Sparse V: skip V dequant if all attention weights for this position are negligible.
465+
// Compiled out for turbo types — see half2 path comment above.
466+
if constexpr (!V_is_turbo) {
386467
bool dominated = true;
387468
#pragma unroll
388469
for (int j = 0; j < ncols; ++j) {
389470
if (KQ_k[j] >= sparse_v_threshold_f) { dominated = false; break; }
390471
}
391472
if (dominated) { continue; }
392473
}
393-
#endif
474+
475+
// Turbo V path: precompute scaled centroids once per block to eliminate
476+
// per-element norm multiply. centroid[idx]*norm is computed 8/4/16 times
477+
// (once per centroid) instead of D times (once per element).
478+
if constexpr (type_V == GGML_TYPE_TURBO3_0) {
479+
const block_turbo3_0 * vb = (const block_turbo3_0 *)(V + k*nb21);
480+
int prev_ib = -1;
481+
float sc[8];
394482

395483
#pragma unroll
396-
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
397-
float2 tmp[V_rows_per_thread/2];
398-
dequantize_V(V + k*nb21, tmp,
399-
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
484+
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
485+
const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
486+
const int ib = i0 / QK_TURBO3;
487+
const int j0 = i0 % QK_TURBO3;
488+
489+
if (ib != prev_ib) {
490+
prev_ib = ib;
491+
const float norm = __half2float(vb[ib].norm);
400492
#pragma unroll
401-
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
493+
for (int c = 0; c < 8; ++c) { sc[c] = TURBO_CENTROIDS_3BIT[c] * norm; }
494+
}
495+
496+
const uint8_t qs_byte = vb[ib].qs[j0 / 4];
497+
const uint8_t sgn_byte = vb[ib].signs[j0 / 8];
498+
const int shift_s = j0 % 8;
499+
500+
const uint8_t idx0 = ((qs_byte >> 0) & 0x3) | (((sgn_byte >> (shift_s+0)) & 0x1) << 2);
501+
const uint8_t idx1 = ((qs_byte >> 2) & 0x3) | (((sgn_byte >> (shift_s+1)) & 0x1) << 2);
502+
const uint8_t idx2 = ((qs_byte >> 4) & 0x3) | (((sgn_byte >> (shift_s+2)) & 0x1) << 2);
503+
const uint8_t idx3 = ((qs_byte >> 6) & 0x3) | (((sgn_byte >> (shift_s+3)) & 0x1) << 2);
504+
402505
#pragma unroll
403506
for (int j = 0; j < ncols; ++j) {
404-
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
405-
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
507+
VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
508+
VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
509+
VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
510+
VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
511+
}
512+
}
513+
} else if constexpr (type_V == GGML_TYPE_TURBO2_0) {
514+
const block_turbo2_0 * vb = (const block_turbo2_0 *)(V + k*nb21);
515+
int prev_ib = -1;
516+
float sc[4];
517+
518+
#pragma unroll
519+
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
520+
const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
521+
const int ib = i0 / QK_TURBO2;
522+
const int j0 = i0 % QK_TURBO2;
523+
524+
if (ib != prev_ib) {
525+
prev_ib = ib;
526+
const float norm = __half2float(vb[ib].norm);
527+
#pragma unroll
528+
for (int c = 0; c < 4; ++c) { sc[c] = TURBO_CENTROIDS_2BIT[c] * norm; }
529+
}
530+
531+
const uint8_t qs_byte = vb[ib].qs[j0 / 4];
532+
533+
const uint8_t idx0 = (qs_byte >> 0) & 0x3;
534+
const uint8_t idx1 = (qs_byte >> 2) & 0x3;
535+
const uint8_t idx2 = (qs_byte >> 4) & 0x3;
536+
const uint8_t idx3 = (qs_byte >> 6) & 0x3;
537+
538+
#pragma unroll
539+
for (int j = 0; j < ncols; ++j) {
540+
VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
541+
VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
542+
VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
543+
VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
544+
}
545+
}
546+
} else if constexpr (type_V == GGML_TYPE_TURBO4_0) {
547+
const block_turbo4_0 * vb = (const block_turbo4_0 *)(V + k*nb21);
548+
int prev_ib = -1;
549+
float sc[16];
550+
551+
#pragma unroll
552+
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
553+
const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
554+
const int ib = i0 / QK_TURBO4;
555+
const int j0 = i0 % QK_TURBO4;
556+
557+
if (ib != prev_ib) {
558+
prev_ib = ib;
559+
const float norm = __half2float(vb[ib].norm);
560+
#pragma unroll
561+
for (int c = 0; c < 16; ++c) { sc[c] = TURBO_CENTROIDS_4BIT[c] * norm; }
562+
}
563+
564+
const uint8_t qs_byte0 = vb[ib].qs[j0 / 2];
565+
const uint8_t qs_byte1 = vb[ib].qs[j0 / 2 + 1];
566+
567+
const uint8_t idx0 = (qs_byte0 >> 0) & 0xF;
568+
const uint8_t idx1 = (qs_byte0 >> 4) & 0xF;
569+
const uint8_t idx2 = (qs_byte1 >> 0) & 0xF;
570+
const uint8_t idx3 = (qs_byte1 >> 4) & 0xF;
571+
572+
#pragma unroll
573+
for (int j = 0; j < ncols; ++j) {
574+
VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
575+
VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
576+
VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
577+
VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
578+
}
579+
}
580+
} else {
581+
#pragma unroll
582+
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
583+
float2 tmp[V_rows_per_thread/2];
584+
dequantize_V(V + k*nb21, tmp,
585+
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
586+
#pragma unroll
587+
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
588+
#pragma unroll
589+
for (int j = 0; j < ncols; ++j) {
590+
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
591+
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
592+
}
406593
}
407594
}
408595
}
@@ -422,10 +609,10 @@ static __global__ void flash_attn_ext_vec(
422609
}
423610

424611
const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
425-
const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
612+
const float KQ_max_scale = __expf(KQ_max[j] - kqmax_new_j);
426613
KQ_max[j] = kqmax_new_j;
427614

428-
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
615+
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? __expf(sink - KQ_max[j]) : 0.0f);
429616

430617
#ifdef V_DOT2_F32_F16_AVAILABLE
431618
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
@@ -471,7 +658,7 @@ static __global__ void flash_attn_ext_vec(
471658

472659
float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
473660
kqmax_new = warp_reduce_max(kqmax_new);
474-
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
661+
const float kqmax_scale = __expf(KQ_max[j_VKQ] - kqmax_new);
475662
KQ_max[j_VKQ] = kqmax_new;
476663

477664
#ifdef V_DOT2_F32_F16_AVAILABLE

0 commit comments

Comments
 (0)