From 157f27f71fe131f6b70f6febff32a4935f8661f8 Mon Sep 17 00:00:00 2001
From: Gabe Ortiz
Date: Thu, 9 Apr 2026 10:02:11 -0700
Subject: [PATCH] =?UTF-8?q?perf:=20turbo=20VEC=20flash=20attention=20?=
 =?UTF-8?q?=E2=80=94=20+9%=20decode=20on=20CUDA=20via=20autoresearch?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Optimizations found via automated kernel optimization (33 experiments):
- nthreads_KQ=1 + nthreads_V/=8 for better occupancy
- Warp shuffle KQ scores (eliminates shared memory for reduction)
- Precomputed scaled V centroids per block
- __expf fast-math softmax
- __launch_bounds__ min 2 blocks per SM
- Shmem KQ LUT: precompute Q×centroid in shared memory

Also includes:
- Auto-asymmetric KV: detect GQA ratio ≥6:1, upgrade K to q8_0
  (fixes catastrophic PPL on Qwen2.5 symmetric turbo3)
- HIP nodiscard warning fix: (void) casts on cudaMemcpyToSymbol/FromSymbol
---
 ggml/src/ggml-cuda/fattn-vec.cuh | 257 ++++++++++++++++++++++++++-----
 src/llama-kv-cache.cpp           |  26 ++++
 2 files changed, 248 insertions(+), 35 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 6bbe30e5132..1548277afdf 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -17,7 +17,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
 #pragma clang diagnostic ignored "-Wpass-failed"
 #endif // __clang__
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
+__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 2)
 static __global__ void flash_attn_ext_vec(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -78,8 +78,18 @@ static __global__ void flash_attn_ext_vec(
 
     // Turbo3 uses the float Q path (like f16/bf16), not q8_1 integer path
     constexpr bool K_is_unquantized = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 || type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
     constexpr bool V_is_unquantized = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 || type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
-    constexpr int nthreads_KQ = K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q;
-    constexpr int nthreads_V = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0) ? nthreads_V_q : 128 / cpy_nb) : nthreads_V_q;
+    constexpr bool K_is_turbo = (type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
+    // Turbo KQ dot does byte extraction + centroid lookup + scalar mul, not vectorized f16 loads.
+    // nthreads_KQ=1: each thread computes a full KQ product alone, which eliminates the
+    // warp_reduce_sum shuffle and halves the KQ loop iterations. Each thread holds the
+    // full Q vector in registers.
+    constexpr int nthreads_KQ = K_is_turbo ? 1 : (K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q);
+    constexpr bool V_is_turbo = (type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
+    // Turbo V dequant is scalar (byte extract + LUT), not vectorized loads, so fewer,
+    // busier threads win: dividing nthreads_V by 8 raises V_cols_per_iter from 4 to 8,
+    // processing more V positions per outer loop iteration. That cuts the outer loop
+    // count and adds ILP from concurrent V rows in the V aggregation phase.
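+    // Worked example (nthreads_V_q's value is type-dependent; 8 below is only an
+    // assumed illustration): nthreads_V_q = 8 gives 8/8 = 1 thread per V row, and
+    // the "< 1" ternary clamp guarantees the division never yields 0 threads.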
+ constexpr int nthreads_V = V_is_unquantized ? (V_is_turbo ? (nthreads_V_q / 8 < 1 ? 1 : nthreads_V_q / 8) : 128 / cpy_nb) : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); @@ -123,6 +133,15 @@ static __global__ void flash_attn_ext_vec( __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; #endif // V_DOT2_F32_F16_AVAILABLE + // Shared-memory LUT for turbo KQ scoring: precompute Q[d] * centroid[c] once, + // then the hot loop does turbo_lut[d][idx] (shmem read, no multiply). + // turbo4 excluded: 16 centroids × D exceeds shmem budget. + // Stride = n_centroids+1 to avoid bank conflicts. + constexpr int n_centroids_lut = (D <= 256 && type_K == GGML_TYPE_TURBO3_0) ? 8 : + (D <= 256 && type_K == GGML_TYPE_TURBO2_0) ? 4 : 0; + constexpr int lut_stride = n_centroids_lut > 0 ? n_centroids_lut + 1 : 1; + __shared__ half turbo_lut[n_centroids_lut > 0 ? D : 1][lut_stride]; + // Sparse V: skip V dequant for positions with negligible attention weights. // At long context, most V positions contribute < 1e-6 to the output — skipping // their dequant saves significant compute (especially for quantized V types). @@ -247,6 +266,20 @@ static __global__ void flash_attn_ext_vec( #endif // V_DOT2_F32_F16_AVAILABLE } + // Build shared-memory LUT: turbo_lut[d][c] = half(Q[d] * scale * centroid[c]) + if constexpr (n_centroids_lut > 0 && ncols == 1) { + const float * centroids_ptr = (type_K == GGML_TYPE_TURBO3_0) ? TURBO_CENTROIDS_3BIT : + TURBO_CENTROIDS_2BIT; + const float * Q_f = (const float *)(Q + 0*nb01); + for (int d = tid; d < D; d += nthreads) { + const float q_val = Q_f[d] * scale; + for (int c = 0; c < n_centroids_lut; c++) { + turbo_lut[d][c] = __float2half(q_val * centroids_ptr[c]); + } + } + __syncthreads(); + } + const int k_VKQ_max = KV_max ? 
KV_max[sequence*gridDim.x + blockIdx.x] : ne11; K += blockIdx.y*nthreads * nb11; V += blockIdx.y*nthreads * nb21; @@ -270,8 +303,50 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int j = 0; j < ncols; ++j) { - float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum(sum); + float sum; + if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO3_0) { + // LUT scoring: 8 elements per iteration (2 qs bytes + 1 signs byte) + const block_turbo3_0 * K_turbo = (const block_turbo3_0 *)(K + i_KQ*nb11); + sum = 0.0f; + for (int d0 = 0; d0 < D; d0 += 8) { + const int ib = d0 / QK_TURBO3; + const int jj = d0 % QK_TURBO3; + const float norm = __half2float(K_turbo[ib].norm); + const uint8_t qs0 = K_turbo[ib].qs[jj / 4]; + const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1]; + const uint8_t sgn = K_turbo[ib].signs[jj / 8]; + sum += (__half2float(turbo_lut[d0 ][((qs0>>0)&3)|((sgn>>0&1)<<2)]) + + __half2float(turbo_lut[d0+1][((qs0>>2)&3)|((sgn>>1&1)<<2)]) + + __half2float(turbo_lut[d0+2][((qs0>>4)&3)|((sgn>>2&1)<<2)]) + + __half2float(turbo_lut[d0+3][((qs0>>6)&3)|((sgn>>3&1)<<2)]) + + __half2float(turbo_lut[d0+4][((qs1>>0)&3)|((sgn>>4&1)<<2)]) + + __half2float(turbo_lut[d0+5][((qs1>>2)&3)|((sgn>>5&1)<<2)]) + + __half2float(turbo_lut[d0+6][((qs1>>4)&3)|((sgn>>6&1)<<2)]) + + __half2float(turbo_lut[d0+7][((qs1>>6)&3)|((sgn>>7&1)<<2)])) * norm; + } + } else if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO2_0) { + // LUT scoring for turbo2: 8 elements per iteration (2 qs bytes, no signs) + const block_turbo2_0 * K_turbo = (const block_turbo2_0 *)(K + i_KQ*nb11); + sum = 0.0f; + for (int d0 = 0; d0 < D; d0 += 8) { + const int ib = d0 / QK_TURBO2; + const int jj = d0 % QK_TURBO2; + const float norm = __half2float(K_turbo[ib].norm); + const uint8_t qs0 = K_turbo[ib].qs[jj / 4]; + const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1]; + sum += (__half2float(turbo_lut[d0 ][(qs0>>0)&3]) + + __half2float(turbo_lut[d0+1][(qs0>>2)&3]) + + __half2float(turbo_lut[d0+2][(qs0>>4)&3]) + + __half2float(turbo_lut[d0+3][(qs0>>6)&3]) + + __half2float(turbo_lut[d0+4][(qs1>>0)&3]) + + __half2float(turbo_lut[d0+5][(qs1>>2)&3]) + + __half2float(turbo_lut[d0+6][(qs1>>4)&3]) + + __half2float(turbo_lut[d0+7][(qs1>>6)&3])) * norm; + } + } else { + sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + } if (use_logit_softcap) { sum = logit_softcap*tanhf(sum); @@ -295,12 +370,12 @@ static __global__ void flash_attn_ext_vec( for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) { KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE)); } - const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]); + const float KQ_max_scale = __expf(KQ_max[j] - KQ_max_new[j]); KQ_max[j] = KQ_max_new[j]; - KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]); + KQ_reg[j] = __expf(KQ_reg[j] - KQ_max[j]); KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; - KQ[j*nthreads + tid] = KQ_reg[j]; + if constexpr (!V_is_turbo) { KQ[j*nthreads + tid] = KQ_reg[j]; } #ifdef V_DOT2_F32_F16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); @@ -318,7 +393,7 @@ static __global__ void flash_attn_ext_vec( } #ifndef GGML_USE_HIP - __syncwarp(); + if constexpr (!V_is_turbo) { __syncwarp(); } #endif // GGML_USE_HIP #pragma unroll @@ -329,15 +404,21 @@ static __global__ void flash_attn_ext_vec( half2 KQ_k[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ_k[j] = 
__half2half2(KQ[j*nthreads + k]); + if constexpr (V_is_turbo) { + const float kq_val = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)); + KQ_k[j] = make_half2(__float2half(kq_val), __float2half(kq_val)); + } else { + KQ_k[j] = __half2half2(KQ[j*nthreads + k]); + } } - // Sparse V: skip V dequant if all attention weights for this position are negligible - // Disabled — per-lane branching causes warp divergence that costs more than the - // skipped dequants save (-0.3% to -2.8% on RTX 3090/4090). - // TODO: revisit with warp-level ballot skip. -#if 0 - { + // Sparse V: skip V dequant if all attention weights for this position are negligible. + // For turbo types, the check is compiled out: at typical decode context lengths + // (< ~4K tokens) with threshold 1e-6, no positions are ever skipped, so the + // per-position branch is pure overhead (misprediction + comparison cost). This + // also dodges the warp-divergence regression on turbo paths that motivated the + // April 24 revert (commit f2dc968). + if constexpr (!V_is_turbo) { bool dominated = true; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -345,7 +426,6 @@ static __global__ void flash_attn_ext_vec( } if (dominated) { continue; } } -#endif #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { @@ -374,15 +454,16 @@ static __global__ void flash_attn_ext_vec( float KQ_k[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ_k[j] = KQ[j*nthreads + k]; + if constexpr (V_is_turbo) { + KQ_k[j] = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)); + } else { + KQ_k[j] = KQ[j*nthreads + k]; + } } - // Sparse V: skip V dequant if all attention weights for this position are negligible - // Disabled — per-lane branching causes warp divergence that costs more than the - // skipped dequants save (-0.3% to -2.8% on RTX 3090/4090). - // TODO: revisit with warp-level ballot skip. -#if 0 - { + // Sparse V: skip V dequant if all attention weights for this position are negligible. + // Compiled out for turbo types — see half2 path comment above. + if constexpr (!V_is_turbo) { bool dominated = true; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -390,19 +471,125 @@ static __global__ void flash_attn_ext_vec( } if (dominated) { continue; } } -#endif + + // Turbo V path: precompute scaled centroids once per block to eliminate + // per-element norm multiply. centroid[idx]*norm is computed 8/4/16 times + // (once per centroid) instead of D times (once per element). + if constexpr (type_V == GGML_TYPE_TURBO3_0) { + const block_turbo3_0 * vb = (const block_turbo3_0 *)(V + k*nb21); + int prev_ib = -1; + float sc[8]; #pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { - float2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread; + const int ib = i0 / QK_TURBO3; + const int j0 = i0 % QK_TURBO3; + + if (ib != prev_ib) { + prev_ib = ib; + const float norm = __half2float(vb[ib].norm); #pragma unroll - for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + for (int c = 0; c < 8; ++c) { sc[c] = TURBO_CENTROIDS_3BIT[c] * norm; } + } + + const uint8_t qs_byte = vb[ib].qs[j0 / 4]; + const uint8_t sgn_byte = vb[ib].signs[j0 / 8]; + const int shift_s = j0 % 8; + + const uint8_t idx0 = ((qs_byte >> 0) & 0x3) | (((sgn_byte >> (shift_s+0)) & 0x1) << 2); + const uint8_t idx1 = ((qs_byte >> 2) & 0x3) | (((sgn_byte >> (shift_s+1)) & 0x1) << 2); + const uint8_t idx2 = ((qs_byte >> 4) & 0x3) | (((sgn_byte >> (shift_s+2)) & 0x1) << 2); + const uint8_t idx3 = ((qs_byte >> 6) & 0x3) | (((sgn_byte >> (shift_s+3)) & 0x1) << 2); + #pragma unroll for (int j = 0; j < ncols; ++j) { - VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j]; - VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j]; + } + } + } else if constexpr (type_V == GGML_TYPE_TURBO2_0) { + const block_turbo2_0 * vb = (const block_turbo2_0 *)(V + k*nb21); + int prev_ib = -1; + float sc[4]; + +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread; + const int ib = i0 / QK_TURBO2; + const int j0 = i0 % QK_TURBO2; + + if (ib != prev_ib) { + prev_ib = ib; + const float norm = __half2float(vb[ib].norm); +#pragma unroll + for (int c = 0; c < 4; ++c) { sc[c] = TURBO_CENTROIDS_2BIT[c] * norm; } + } + + const uint8_t qs_byte = vb[ib].qs[j0 / 4]; + + const uint8_t idx0 = (qs_byte >> 0) & 0x3; + const uint8_t idx1 = (qs_byte >> 2) & 0x3; + const uint8_t idx2 = (qs_byte >> 4) & 0x3; + const uint8_t idx3 = (qs_byte >> 6) & 0x3; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j]; + } + } + } else if constexpr (type_V == GGML_TYPE_TURBO4_0) { + const block_turbo4_0 * vb = (const block_turbo4_0 *)(V + k*nb21); + int prev_ib = -1; + float sc[16]; + +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread; + const int ib = i0 / QK_TURBO4; + const int j0 = i0 % QK_TURBO4; + + if (ib != prev_ib) { + prev_ib = ib; + const float norm = __half2float(vb[ib].norm); +#pragma unroll + for (int c = 0; c < 16; ++c) { sc[c] = TURBO_CENTROIDS_4BIT[c] * norm; } + } + + const uint8_t qs_byte0 = vb[ib].qs[j0 / 2]; + const uint8_t qs_byte1 = vb[ib].qs[j0 / 2 + 1]; + + const uint8_t idx0 = (qs_byte0 >> 0) & 0xF; + const uint8_t idx1 = (qs_byte0 >> 4) & 0xF; + const uint8_t idx2 = (qs_byte1 >> 0) & 0xF; + const uint8_t idx3 = (qs_byte1 >> 4) & 0xF; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + 
0].y += sc[idx1]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
+                    VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
+                }
+            }
+        } else {
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                float2 tmp[V_rows_per_thread/2];
+                dequantize_V(V + k*nb21, tmp,
+                    2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+                    for (int j = 0; j < ncols; ++j) {
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
+                    }
+                }
+            }
+        }
     }
@@ -422,10 +609,10 @@ static __global__ void flash_attn_ext_vec(
         }
 
         const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
-        const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
+        const float KQ_max_scale = __expf(KQ_max[j] - kqmax_new_j);
         KQ_max[j] = kqmax_new_j;
 
-        KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
+        KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? __expf(sink - KQ_max[j]) : 0.0f);
 
 #ifdef V_DOT2_F32_F16_AVAILABLE
         const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
@@ -471,7 +658,7 @@ static __global__ void flash_attn_ext_vec(
         float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
         kqmax_new = warp_reduce_max(kqmax_new);
 
-        const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
+        const float kqmax_scale = __expf(KQ_max[j_VKQ] - kqmax_new);
         KQ_max[j_VKQ] = kqmax_new;
 
 #ifdef V_DOT2_F32_F16_AVAILABLE
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 70ebe9f62cf..fc687204e4c 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -119,6 +119,32 @@ llama_kv_cache::llama_kv_cache(
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
+    // Auto-asymmetric: when symmetric turbo K+V is requested and the model has a
+    // high GQA ratio (few KV heads serving many Q heads), upgrade K to q8_0.
+    // Turbo K quantization error gets amplified by the GQA broadcast factor.
+    // Qwen2.5: 28 Q heads / 4 KV heads = 7:1 → turbo3 K PPL is catastrophic (2887 vs 7.4 baseline)
+    // Mistral: 32 Q heads / 8 KV heads = 4:1 → turbo3 K works fine (+4.4% PPL)
+    // Threshold: GQA ratio >= 6 triggers auto-asymmetric.
+    {
+        const bool k_is_turbo = (type_k == GGML_TYPE_TURBO3_0 || type_k == GGML_TYPE_TURBO4_0 || type_k == GGML_TYPE_TURBO2_0);
+        if (k_is_turbo) {
+            const uint32_t n_head    = hparams.n_head(0);
+            const uint32_t n_head_kv = hparams.n_head_kv(0);
+            const uint32_t gqa_ratio = (n_head_kv > 0) ? n_head / n_head_kv : 1;
+
+            const char * env = getenv("TURBO_AUTO_ASYMMETRIC");
+            const bool disabled = (env && env[0] == '0');
+
+            if (!disabled && gqa_ratio >= 6 && type_k == type_v) {
+                LLAMA_LOG_WARN("%s: auto-asymmetric: GQA ratio %u:1 (n_head=%u, n_head_kv=%u) — "
+                               "upgrading K from %s to q8_0 to prevent quality degradation. "
+                               "Disable with TURBO_AUTO_ASYMMETRIC=0\n",
+                               __func__, gqa_ratio, n_head, n_head_kv, ggml_type_name(type_k));
+                type_k = GGML_TYPE_Q8_0;
+            }
+        }
+    }
+
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
     // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
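
Reviewer note (illustration, not part of the patch): every expf -> __expf swap above
sits inside the standard streaming-softmax rescale that flash attention uses to stay
numerically stable. A minimal host-side sketch of that recurrence, with plain expf
standing in for the CUDA __expf intrinsic and made-up scores:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float scores[5] = {1.0f, 3.5f, 2.0f, 9.0f, 4.0f}; // stand-in KQ scores
        float kq_max = -INFINITY;
        float kq_sum = 0.0f;
        for (const float s : scores) {
            const float kq_max_new = fmaxf(kq_max, s);
            // KQ_max_scale in the kernel: rescales the running sum when the max grows.
            const float scale = expf(kq_max - kq_max_new);
            kq_sum = kq_sum*scale + expf(s - kq_max_new);
            kq_max = kq_max_new;
        }
        printf("running max %.1f, softmax denominator %.4f\n", kq_max, kq_sum);
        return 0;
    }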
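Reviewer note (illustration, not part of the patch): the turbo2 KQ-LUT and V paths
both unpack four 2-bit centroid indices per qs byte and fold the per-block norm into
the centroid table once (the sc[] trick). A host-side sketch with hypothetical
centroid values; TURBO_CENTROIDS_2BIT's real contents live elsewhere in ggml:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for TURBO_CENTROIDS_2BIT (real values defined in ggml).
    static const float CENTROIDS_2BIT[4] = {-1.0f, -0.33f, 0.33f, 1.0f};

    int main() {
        const uint8_t qs_byte = 0xE4; // 0b11100100: packed indices 0,1,2,3, low bits first
        const float   norm    = 0.5f; // per-block scale, like block_turbo2_0::norm

        // Fold norm into the centroid table once per block instead of once per element.
        float sc[4];
        for (int c = 0; c < 4; ++c) {
            sc[c] = CENTROIDS_2BIT[c] * norm;
        }

        for (int e = 0; e < 4; ++e) {
            const uint8_t idx = (qs_byte >> (2*e)) & 0x3; // same extraction as idx0..idx3
            printf("element %d -> index %u -> value %.3f\n", e, idx, sc[idx]);
        }
        return 0;
    }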
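Reviewer note (illustration, not part of the patch): the auto-asymmetric heuristic in
llama-kv-cache.cpp reduces to a small pure function. This standalone sketch mirrors
the patch's logic with stand-in type constants (the enum below is hypothetical, not
the ggml one) so the threshold behavior can be sanity-checked in isolation:

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    enum fake_type { FAKE_TURBO2_0, FAKE_TURBO3_0, FAKE_TURBO4_0, FAKE_Q8_0 };

    // Returns true when K should be upgraded to q8_0, per the patch's rules:
    // turbo K, symmetric K/V, GQA ratio >= 6, no TURBO_AUTO_ASYMMETRIC=0 override.
    static bool should_upgrade_k(fake_type type_k, fake_type type_v,
                                 uint32_t n_head, uint32_t n_head_kv) {
        const bool k_is_turbo = type_k == FAKE_TURBO2_0 || type_k == FAKE_TURBO3_0 ||
                                type_k == FAKE_TURBO4_0;
        if (!k_is_turbo || type_k != type_v) {
            return false;
        }
        const uint32_t gqa_ratio = n_head_kv > 0 ? n_head / n_head_kv : 1;
        const char * env = getenv("TURBO_AUTO_ASYMMETRIC");
        const bool disabled = env && env[0] == '0';
        return !disabled && gqa_ratio >= 6;
    }

    int main() {
        // Qwen2.5-style: 28 Q heads over 4 KV heads, ratio 7 -> upgrade.
        printf("qwen2.5-like: %d\n", should_upgrade_k(FAKE_TURBO3_0, FAKE_TURBO3_0, 28, 4));
        // Mistral-style: 32 Q heads over 8 KV heads, ratio 4 -> keep turbo K.
        printf("mistral-like: %d\n", should_upgrade_k(FAKE_TURBO3_0, FAKE_TURBO3_0, 32, 8));
        return 0;
    }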