257 changes: 222 additions & 35 deletions ggml/src/ggml-cuda/fattn-vec.cuh
@@ -17,7 +17,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 2)
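// The second __launch_bounds__ argument is minBlocksPerMultiprocessor: it asks the compiler to
// cap register use so that at least 2 blocks can be resident per SM, presumably for better
// latency hiding.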
static __global__ void flash_attn_ext_vec(
const char * __restrict__ Q,
const char * __restrict__ K,
@@ -78,8 +78,18 @@ static __global__ void flash_attn_ext_vec(
// Turbo3 uses the float Q path (like f16/bf16), not q8_1 integer path
constexpr bool K_is_unquantized = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 || type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
constexpr bool V_is_unquantized = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 || type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
constexpr int nthreads_KQ = K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0) ? nthreads_V_q : 128 / cpy_nb) : nthreads_V_q;
constexpr bool K_is_turbo = (type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0);
// Turbo KQ dot does byte extraction + centroid lookup + scalar mul, not vectorized f16 loads.
// nthreads_KQ=1: each thread computes a full KQ product alone — eliminates warp_reduce_sum
// shuffle and halves KQ loop iterations. Each thread holds full Q vector in registers.
constexpr int nthreads_KQ = K_is_turbo ? 1 : (K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q);
constexpr bool V_is_turbo = (type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0);
// Turbo V dequant is scalar (byte extract + LUT), not vectorized loads, so fewer threads per
// V row pay off: cut nthreads_V to nthreads_V_q/8 (clamped to 1) so V_cols_per_iter grows and
// each outer loop iteration covers more V positions, reducing loop overhead and improving ILP
// from concurrent V rows in the V aggregation phase.
constexpr int nthreads_V = V_is_unquantized ? (V_is_turbo ? (nthreads_V_q / 8 < 1 ? 1 : nthreads_V_q / 8) : 128 / cpy_nb) : nthreads_V_q;

static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
@@ -123,6 +133,15 @@ static __global__ void flash_attn_ext_vec(
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#endif // V_DOT2_F32_F16_AVAILABLE

// Shared-memory LUT for turbo KQ scoring: precompute Q[d] * centroid[c] once,
// then the hot loop does turbo_lut[d][idx] (shmem read, no multiply).
// turbo4 excluded: 16 centroids × D exceeds shmem budget.
// Stride = n_centroids+1 to avoid bank conflicts.
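    // Worst-case footprint at D = 256 (half = 2 bytes): turbo3 256*9*2 = 4.5 KiB,
    // turbo2 256*5*2 = 2.5 KiB; turbo4 would need 256*17*2 = 8.5 KiB.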
constexpr int n_centroids_lut = (D <= 256 && type_K == GGML_TYPE_TURBO3_0) ? 8 :
(D <= 256 && type_K == GGML_TYPE_TURBO2_0) ? 4 : 0;
constexpr int lut_stride = n_centroids_lut > 0 ? n_centroids_lut + 1 : 1;
__shared__ half turbo_lut[n_centroids_lut > 0 ? D : 1][lut_stride];

// Sparse V: skip V dequant for positions with negligible attention weights.
// At long context, most V positions contribute < 1e-6 to the output — skipping
// their dequant saves significant compute (especially for quantized V types).
@@ -247,6 +266,20 @@ static __global__ void flash_attn_ext_vec(
#endif // V_DOT2_F32_F16_AVAILABLE
}

// Build shared-memory LUT: turbo_lut[d][c] = half(Q[d] * scale * centroid[c])
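    // The table folds in column 0's Q values, so it is only valid for ncols == 1 (single-query decode).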
if constexpr (n_centroids_lut > 0 && ncols == 1) {
const float * centroids_ptr = (type_K == GGML_TYPE_TURBO3_0) ? TURBO_CENTROIDS_3BIT :
TURBO_CENTROIDS_2BIT;
const float * Q_f = (const float *)(Q + 0*nb01);
for (int d = tid; d < D; d += nthreads) {
const float q_val = Q_f[d] * scale;
for (int c = 0; c < n_centroids_lut; c++) {
turbo_lut[d][c] = __float2half(q_val * centroids_ptr[c]);
}
}
__syncthreads();
}

const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
K += blockIdx.y*nthreads * nb11;
V += blockIdx.y*nthreads * nb21;
@@ -270,8 +303,50 @@

#pragma unroll
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum<nthreads_KQ>(sum);
float sum;
if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO3_0) {
// LUT scoring: 8 elements per iteration (2 qs bytes + 1 signs byte)
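                        // Each LUT index packs the 2-bit magnitude with the sign bit in bit 2,
                        // addressing the 8 entries written from TURBO_CENTROIDS_3BIT above.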
const block_turbo3_0 * K_turbo = (const block_turbo3_0 *)(K + i_KQ*nb11);
sum = 0.0f;
for (int d0 = 0; d0 < D; d0 += 8) {
const int ib = d0 / QK_TURBO3;
const int jj = d0 % QK_TURBO3;
const float norm = __half2float(K_turbo[ib].norm);
const uint8_t qs0 = K_turbo[ib].qs[jj / 4];
const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1];
const uint8_t sgn = K_turbo[ib].signs[jj / 8];
sum += (__half2float(turbo_lut[d0 ][((qs0>>0)&3)|((sgn>>0&1)<<2)]) +
__half2float(turbo_lut[d0+1][((qs0>>2)&3)|((sgn>>1&1)<<2)]) +
__half2float(turbo_lut[d0+2][((qs0>>4)&3)|((sgn>>2&1)<<2)]) +
__half2float(turbo_lut[d0+3][((qs0>>6)&3)|((sgn>>3&1)<<2)]) +
__half2float(turbo_lut[d0+4][((qs1>>0)&3)|((sgn>>4&1)<<2)]) +
__half2float(turbo_lut[d0+5][((qs1>>2)&3)|((sgn>>5&1)<<2)]) +
__half2float(turbo_lut[d0+6][((qs1>>4)&3)|((sgn>>6&1)<<2)]) +
__half2float(turbo_lut[d0+7][((qs1>>6)&3)|((sgn>>7&1)<<2)])) * norm;
}
} else if constexpr (n_centroids_lut > 0 && ncols == 1 && type_K == GGML_TYPE_TURBO2_0) {
// LUT scoring for turbo2: 8 elements per iteration (2 qs bytes, no signs)
const block_turbo2_0 * K_turbo = (const block_turbo2_0 *)(K + i_KQ*nb11);
sum = 0.0f;
for (int d0 = 0; d0 < D; d0 += 8) {
const int ib = d0 / QK_TURBO2;
const int jj = d0 % QK_TURBO2;
const float norm = __half2float(K_turbo[ib].norm);
const uint8_t qs0 = K_turbo[ib].qs[jj / 4];
const uint8_t qs1 = K_turbo[ib].qs[jj / 4 + 1];
sum += (__half2float(turbo_lut[d0 ][(qs0>>0)&3]) +
__half2float(turbo_lut[d0+1][(qs0>>2)&3]) +
__half2float(turbo_lut[d0+2][(qs0>>4)&3]) +
__half2float(turbo_lut[d0+3][(qs0>>6)&3]) +
__half2float(turbo_lut[d0+4][(qs1>>0)&3]) +
__half2float(turbo_lut[d0+5][(qs1>>2)&3]) +
__half2float(turbo_lut[d0+6][(qs1>>4)&3]) +
__half2float(turbo_lut[d0+7][(qs1>>6)&3])) * norm;
}
} else {
sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum<nthreads_KQ>(sum);
}

if (use_logit_softcap) {
sum = logit_softcap*tanhf(sum);
@@ -295,12 +370,12 @@ static __global__ void flash_attn_ext_vec(
for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
}
const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
const float KQ_max_scale = __expf(KQ_max[j] - KQ_max_new[j]);
KQ_max[j] = KQ_max_new[j];

KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
KQ_reg[j] = __expf(KQ_reg[j] - KQ_max[j]);
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
KQ[j*nthreads + tid] = KQ_reg[j];
if constexpr (!V_is_turbo) { KQ[j*nthreads + tid] = KQ_reg[j]; }

#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
@@ -318,7 +393,7 @@ static __global__ void flash_attn_ext_vec(
}

#ifndef GGML_USE_HIP
__syncwarp();
if constexpr (!V_is_turbo) { __syncwarp(); }
#endif // GGML_USE_HIP

#pragma unroll
@@ -329,23 +404,28 @@ static __global__ void flash_attn_ext_vec(
half2 KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
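                    // Turbo V keeps the attention weights in registers: the KQ value for this
                    // position is broadcast lane-to-lane with __shfl_sync instead of being staged
                    // through the shared KQ[] array (the store and __syncwarp above are skipped).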
if constexpr (V_is_turbo) {
const float kq_val = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V));
KQ_k[j] = make_half2(__float2half(kq_val), __float2half(kq_val));
} else {
KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
}
}

// Sparse V: skip V dequant if all attention weights for this position are negligible
// Disabled — per-lane branching causes warp divergence that costs more than the
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
// TODO: revisit with warp-level ballot skip.
#if 0
{
// Sparse V: skip V dequant if all attention weights for this position are negligible.
// For turbo types, the check is compiled out: at typical decode context lengths
// (< ~4K tokens) with threshold 1e-6, no positions are ever skipped, so the
// per-position branch is pure overhead (misprediction + comparison cost). This
// also dodges the warp-divergence regression on turbo paths that motivated the
// April 24 revert (commit f2dc968).
if constexpr (!V_is_turbo) {
bool dominated = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (__hgt(__low2half(KQ_k[j]), sparse_v_threshold_h)) { dominated = false; break; }
}
if (dominated) { continue; }
}
#endif

#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
@@ -374,35 +454,142 @@ static __global__ void flash_attn_ext_vec(
float KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_k[j] = KQ[j*nthreads + k];
if constexpr (V_is_turbo) {
KQ_k[j] = __shfl_sync(0xFFFFFFFF, KQ_reg[j], k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V));
} else {
KQ_k[j] = KQ[j*nthreads + k];
}
}

// Sparse V: skip V dequant if all attention weights for this position are negligible
// Disabled — per-lane branching causes warp divergence that costs more than the
// skipped dequants save (-0.3% to -2.8% on RTX 3090/4090).
// TODO: revisit with warp-level ballot skip.
#if 0
{
// Sparse V: skip V dequant if all attention weights for this position are negligible.
// Compiled out for turbo types — see half2 path comment above.
if constexpr (!V_is_turbo) {
bool dominated = true;
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (KQ_k[j] >= sparse_v_threshold_f) { dominated = false; break; }
}
if (dominated) { continue; }
}
#endif

            // Turbo V path: cache the norm-scaled centroids once per quantization block to eliminate
            // the per-element norm multiply: centroid[idx]*norm is computed 8/4/16 times per block
            // (once per centroid) instead of once per element.
if constexpr (type_V == GGML_TYPE_TURBO3_0) {
const block_turbo3_0 * vb = (const block_turbo3_0 *)(V + k*nb21);
int prev_ib = -1;
float sc[8];

#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
float2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
const int ib = i0 / QK_TURBO3;
const int j0 = i0 % QK_TURBO3;

if (ib != prev_ib) {
prev_ib = ib;
const float norm = __half2float(vb[ib].norm);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
for (int c = 0; c < 8; ++c) { sc[c] = TURBO_CENTROIDS_3BIT[c] * norm; }
}

const uint8_t qs_byte = vb[ib].qs[j0 / 4];
const uint8_t sgn_byte = vb[ib].signs[j0 / 8];
const int shift_s = j0 % 8;
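                    // idx below = 2-bit magnitude | (sign bit << 2): it selects one of the 8
                    // norm-scaled centroids cached in sc[]; the sign bits for these four
                    // elements start at bit shift_s of sgn_byte.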

const uint8_t idx0 = ((qs_byte >> 0) & 0x3) | (((sgn_byte >> (shift_s+0)) & 0x1) << 2);
const uint8_t idx1 = ((qs_byte >> 2) & 0x3) | (((sgn_byte >> (shift_s+1)) & 0x1) << 2);
const uint8_t idx2 = ((qs_byte >> 4) & 0x3) | (((sgn_byte >> (shift_s+2)) & 0x1) << 2);
const uint8_t idx3 = ((qs_byte >> 6) & 0x3) | (((sgn_byte >> (shift_s+3)) & 0x1) << 2);

#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
}
}
} else if constexpr (type_V == GGML_TYPE_TURBO2_0) {
const block_turbo2_0 * vb = (const block_turbo2_0 *)(V + k*nb21);
int prev_ib = -1;
float sc[4];

#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
const int ib = i0 / QK_TURBO2;
const int j0 = i0 % QK_TURBO2;

if (ib != prev_ib) {
prev_ib = ib;
const float norm = __half2float(vb[ib].norm);
#pragma unroll
for (int c = 0; c < 4; ++c) { sc[c] = TURBO_CENTROIDS_2BIT[c] * norm; }
}

const uint8_t qs_byte = vb[ib].qs[j0 / 4];

const uint8_t idx0 = (qs_byte >> 0) & 0x3;
const uint8_t idx1 = (qs_byte >> 2) & 0x3;
const uint8_t idx2 = (qs_byte >> 4) & 0x3;
const uint8_t idx3 = (qs_byte >> 6) & 0x3;

#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
}
}
} else if constexpr (type_V == GGML_TYPE_TURBO4_0) {
const block_turbo4_0 * vb = (const block_turbo4_0 *)(V + k*nb21);
int prev_ib = -1;
float sc[16];

#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i0 = 2*i_VKQ_0 + (threadIdx.x % nthreads_V)*V_rows_per_thread;
const int ib = i0 / QK_TURBO4;
const int j0 = i0 % QK_TURBO4;

if (ib != prev_ib) {
prev_ib = ib;
const float norm = __half2float(vb[ib].norm);
#pragma unroll
for (int c = 0; c < 16; ++c) { sc[c] = TURBO_CENTROIDS_4BIT[c] * norm; }
}

const uint8_t qs_byte0 = vb[ib].qs[j0 / 2];
const uint8_t qs_byte1 = vb[ib].qs[j0 / 2 + 1];

const uint8_t idx0 = (qs_byte0 >> 0) & 0xF;
const uint8_t idx1 = (qs_byte0 >> 4) & 0xF;
const uint8_t idx2 = (qs_byte1 >> 0) & 0xF;
const uint8_t idx3 = (qs_byte1 >> 4) & 0xF;

#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + 0].x += sc[idx0]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 0].y += sc[idx1]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 1].x += sc[idx2]*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + 1].y += sc[idx3]*KQ_k[j];
}
}
} else {
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
float2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
}
}
}
}
@@ -422,10 +609,10 @@ static __global__ void flash_attn_ext_vec(
}

const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
const float KQ_max_scale = __expf(KQ_max[j] - kqmax_new_j);
KQ_max[j] = kqmax_new_j;

KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? __expf(sink - KQ_max[j]) : 0.0f);

#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
@@ -471,7 +658,7 @@ static __global__ void flash_attn_ext_vec(

float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
kqmax_new = warp_reduce_max(kqmax_new);
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
const float kqmax_scale = __expf(KQ_max[j_VKQ] - kqmax_new);
KQ_max[j_VKQ] = kqmax_new;

#ifdef V_DOT2_F32_F16_AVAILABLE