diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 861d84467095..2bbe37846dee 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -1298,27 +1298,49 @@ void launch_fattn( const int nsm = ggml_cuda_info().devices[id].nsm; #ifdef GGML_USE_HIP - // HIP/ROCm: bypass the memory pool for f16 temp buffers. - // The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently. - // For quantized KV dequant, this means the f16 temp buffer stays allocated, - // consuming more VRAM than the quantized KV compression saves — causing OOM. - // Using raw alloc+free ensures the memory is released after the kernel completes. + // HIP/ROCm: the f16 dequant temp buffers for quantized KV need two strategies + // depending on whether the current launch is captured into a HIP graph: + // * NOT capturing (large prefill, eager): raw cudaMalloc/cudaFree so the + // (potentially multi-GB) buffer is freed immediately. The no-VMM legacy + // pool would otherwise retain it at peak size -> negate KV compression. + // * Capturing (decode/speculative captured into a graph): raw cudaMalloc/ + // cudaFree/cudaStreamSynchronize are forbidden ("operation not permitted + // when stream is capturing") -> use the pool instead. ggml re-runs warmup + // eagerly on tensor-size changes, so the pool buffer is sized before capture. + // Small batches (Q->ne[1] <= 8) always route through the pool (decode/spec + // cases that get captured). Patch: gemma4-turboquant-rdna4 #0001 part B. + cudaStreamCaptureStatus fa_capture_status = cudaStreamCaptureStatusNone; + CUDA_CHECK(cudaStreamIsCapturing(main_stream, &fa_capture_status)); + const bool fa_use_pool = (fa_capture_status != cudaStreamCaptureStatusNone) || (Q->ne[1] <= 8); + struct hip_f16_alloc { - half * ptr = nullptr; - cudaStream_t stream; - hip_f16_alloc(cudaStream_t s) : stream(s) {} + half * ptr = nullptr; + ggml_cuda_pool * mem_pool = nullptr; // non-null => allocate from the pool (graph-safe) + size_t pool_size = 0; + cudaStream_t stream; + hip_f16_alloc(cudaStream_t s, ggml_cuda_pool * p) : mem_pool(p), stream(s) {} ~hip_f16_alloc() { - if (ptr) { - cudaStreamSynchronize(stream); - cudaFree(ptr); + if (!ptr) { + return; + } + if (mem_pool) { + // Pool free is plain bookkeeping (no CUDA calls) -> safe during capture. + mem_pool->free(ptr, pool_size); + } else { + (void) cudaStreamSynchronize(stream); + (void) cudaFree(ptr); } } void alloc(size_t nelements) { - CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half))); + if (mem_pool) { + ptr = (half *) mem_pool->alloc(nelements * sizeof(half), &pool_size); + } else { + CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half))); + } } }; - hip_f16_alloc K_f16(main_stream); - hip_f16_alloc V_f16(main_stream); + hip_f16_alloc K_f16(main_stream, fa_use_pool ? &pool : nullptr); + hip_f16_alloc V_f16(main_stream, fa_use_pool ? &pool : nullptr); #else ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 0f7efdef01db..0455b9da54c2 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1,6 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-wmma-f16.cuh" +#include "turbo-quant.cuh" // nbatch_fa == number of KQ rows to process per iteration // nbatch_K == number of K columns to load in parallel for KQ calculation @@ -466,12 +467,164 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( ggml_cuda_unroll<5>{}(load); } +// ---- Inline-dequant tile loads for turbo3 (GGML_TYPE_TURBO3_0) ---- +// +// These mirror the thread-tiling loop structure of the f16 flash_attn_tile_load_tile +// overloads above EXACTLY (same i/j coverage, same J_padding stride) so the resulting +// tile_KV layout is byte-for-byte compatible with what flash_attn_tile_iter_KQ / _iter +// consume. The only difference vs the f16 path: instead of memcpy-ing half2 from a +// pre-materialized f16 KV tensor, we decode turbo3 blocks straight out of the original +// quantized KV tensor (no full-KV->f16 dequant tax). +// +// Layout reminder (turbo3): each KV row has `row_ncols` (=DKQ/DV=256) elements packed as +// row_ncols/QK_TURBO3 (=2) blocks of QK_TURBO3 (=128) elements. Row base in BYTES is +// KV_quant_base + row_index*row_stride_bytes. For an element index `e` within the row: +// ib = e / QK_TURBO3; jb = e % QK_TURBO3; val = turbo3_dequant_element(&blk[ib], jb, norm) +// where norm = __half2float(blk[ib].norm) (per-block, hoisted). +// +// `col0` is the element-column offset of this tile within the row (k_KQ_0 for the K +// batches; 0 for the V batch since V covers the full DV columns). +// +// half2-dst overload (FAST_FP16 path; used for K and V tiles in the fp16 kernel): +template +static __device__ __forceinline__ void flash_attn_tile_load_tile_turbo3( + const char * const __restrict__ KV_quant_base, half2 * const __restrict__ tile_KV, + const int row_stride_bytes, const int col0, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + __align__(16) half2 tmp_h2[cpy_ne]; + const bool in_bounds = !oob_check || i < i_sup; + const block_turbo3_0 * blk = in_bounds + ? (const block_turbo3_0 *) (KV_quant_base + (size_t) i * row_stride_bytes) + : nullptr; +#pragma unroll + for (int l = 0; l < cpy_ne; ++l) { + if (in_bounds) { + // half2 slot (j+l) holds element columns col0 + 2*(j+l) and col0 + 2*(j+l)+1 + const int e0 = col0 + 2*(j + l); + const int e1 = e0 + 1; + const int ib0 = e0 / QK_TURBO3; + const int jb0 = e0 % QK_TURBO3; + const int ib1 = e1 / QK_TURBO3; + const int jb1 = e1 % QK_TURBO3; + const float norm0 = __half2float(blk[ib0].norm); + const float norm1 = __half2float(blk[ib1].norm); + tmp_h2[l] = make_half2( + turbo3_dequant_element(&blk[ib0], jb0, norm0), + turbo3_dequant_element(&blk[ib1], jb1, norm1)); + } else { + tmp_h2[l] = make_half2(0.0f, 0.0f); + } + } + ggml_cuda_memcpy_1(tile_KV + i*(J/2 + J_padding) + j, tmp_h2); + } + } + } + }; + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<7>{}(load); +} + +// float-dst overload (non-FAST_FP16 path): +template +static __device__ __forceinline__ void flash_attn_tile_load_tile_turbo3( + const char * const __restrict__ KV_quant_base, float * const __restrict__ tile_KV, + const int row_stride_bytes, const int col0, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + // `j` here is in half2 units (matches the f16 float-dst overload), so + // the element column for half2 slot (j+l) is col0 + 2*(j+l). + const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); + + __align__(16) float2 tmp_f2[cpy_ne/2]; + const bool in_bounds = !oob_check || i < i_sup; + const block_turbo3_0 * blk = in_bounds + ? (const block_turbo3_0 *) (KV_quant_base + (size_t) i * row_stride_bytes) + : nullptr; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + if (in_bounds) { + const int e0 = col0 + 2*(j + l); + const int e1 = e0 + 1; + const int ib0 = e0 / QK_TURBO3; + const int jb0 = e0 % QK_TURBO3; + const int ib1 = e1 / QK_TURBO3; + const int jb1 = e1 % QK_TURBO3; + const float norm0 = __half2float(blk[ib0].norm); + const float norm1 = __half2float(blk[ib1].norm); + tmp_f2[l].x = turbo3_dequant_element(&blk[ib0], jb0, norm0); + tmp_f2[l].y = turbo3_dequant_element(&blk[ib1], jb1, norm1); + } else { + tmp_f2[l] = make_float2(0.0f, 0.0f); + } + } + ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<5>{}(load); +} + // Function that performs a single iteration in for the KQ matrix multiplication: template + bool use_logit_softcap, bool oob_check, ggml_type type_K = GGML_TYPE_F16, typename T_vec_dot> static __device__ __forceinline__ void flash_attn_tile_iter_KQ( T_vec_dot * const Q_tmp, const half2 * const __restrict__ K_h2, + const char * const __restrict__ K_quant, + const int K_row_stride_bytes, T_vec_dot * const KV_tmp, const int stride_K2, const int k_VKQ_0, @@ -485,8 +638,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column - flash_attn_tile_load_tile - (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + if constexpr (type_K == GGML_TYPE_TURBO3_0) { + // Inline-dequant turbo3 K straight into the tile (no full-KV->f16 materialization). + flash_attn_tile_load_tile_turbo3 + (K_quant + int64_t(k_VKQ_0)*K_row_stride_bytes, KV_tmp, K_row_stride_bytes, k_KQ_0, k_VKQ_sup); + } else { + flash_attn_tile_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + } __syncthreads(); #ifdef FAST_FP16_AVAILABLE @@ -543,11 +702,16 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( // Function that performs a single iteration of the main loop over up to nbatch_fa tokens. template + bool use_logit_softcap, bool oob_check, ggml_type type_K = GGML_TYPE_F16, ggml_type type_V = GGML_TYPE_F16, + typename T_vec_dot, typename T_KQ, typename T_acc> static __device__ __forceinline__ void flash_attn_tile_iter( T_vec_dot * const Q_tmp, const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, + const char * const __restrict__ K_quant, + const char * const __restrict__ V_quant, + const int K_row_stride_bytes, + const int V_row_stride_bytes, const half * const __restrict__ mask, const uint3 ne01, const float logit_softcap, @@ -594,13 +758,13 @@ static __device__ __forceinline__ void flash_attn_tile_iter( constexpr int nbatch_K_last = DKQ % nbatch_K; #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { - flash_attn_tile_iter_KQ( - Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, K_quant, K_row_stride_bytes, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); } if (nbatch_K_last > 0) { constexpr int k_KQ_0 = DKQ - nbatch_K_last; - flash_attn_tile_iter_KQ( - Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, K_quant, K_row_stride_bytes, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); } // Apply logit softcap + mask, update KQ_max: @@ -705,8 +869,15 @@ static __device__ __forceinline__ void flash_attn_tile_iter( static_assert(nbatch_V % np == 0, "bad nbatch_V"); #pragma unroll for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { - flash_attn_tile_load_tile - (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + if constexpr (type_V == GGML_TYPE_TURBO3_0) { + // Inline-dequant turbo3 V straight into the tile (no full-KV->f16 materialization). + // V covers element columns 0..DV-1 of the row (col0 == 0). + flash_attn_tile_load_tile_turbo3 + (V_quant + int64_t(k_VKQ_0 + k0)*V_row_stride_bytes, KV_tmp, V_row_stride_bytes, 0, k_VKQ_sup - k0); + } else { + flash_attn_tile_load_tile + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + } __syncthreads(); #ifdef FAST_FP16_AVAILABLE @@ -775,7 +946,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( } } -template // D == head size +template // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( const char * __restrict__ Q, @@ -841,6 +1013,14 @@ static __global__ void flash_attn_tile( const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + // For inline-dequant turbo3 KV the K/V char* params point at the ORIGINAL quantized + // tensor (launch_fattn was called with need_f16_K/V=false, leaving nb11/nb21 as the + // original byte strides). Reproduce the same sequence/head byte offset as K_h2/V_h2. + const char * K_quant = K + nb13*sequence + nb12*(head0 / gqa_ratio); + const char * V_quant = V + nb23*sequence + nb22*(head0 / gqa_ratio); + const int K_row_stride_bytes = nb11; + const int V_row_stride_bytes = nb21; + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; const int stride_K2 = nb11 / sizeof(half2); @@ -939,23 +1119,23 @@ static __global__ void flash_attn_tile( int k_VKQ_0 = blockIdx.y*nbatch_fa; while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { constexpr bool oob_check = false; - flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, K_quant, V_quant, K_row_stride_bytes, V_row_stride_bytes, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); k_VKQ_0 += gridDim.y*nbatch_fa; } if (k_VKQ_0 < k_VKQ_max) { constexpr bool oob_check = true; - flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, K_quant, V_quant, K_row_stride_bytes, V_row_stride_bytes, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } else { // Branch without out-of-bounds checks. for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { constexpr bool oob_check = false; - flash_attn_tile_iter - (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, K_quant, V_quant, K_row_stride_bytes, V_row_stride_bytes, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); } } @@ -1127,7 +1307,8 @@ static __global__ void flash_attn_tile( #endif // FLASH_ATTN_AVAILABLE } -template +template static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; @@ -1137,15 +1318,21 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr size_t nbytes_shared = 0; + // turbo3 KV is inline-dequantized inside the kernel during the tile load, so the + // full-KV->f16 materialization in launch_fattn must be skipped (need_f16 = false). + // Every other type keeps the existing f16 behavior byte-for-byte. + constexpr bool need_f16_K = (type_K != GGML_TYPE_TURBO3_0); + constexpr bool need_f16_V = (type_V != GGML_TYPE_TURBO3_0); + #ifdef GGML_USE_HIP if constexpr (DV <= 128) { if (Q->ne[1] > 32/ncols2) { constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; + fattn_kernel_t fattn_kernel = flash_attn_tile; launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, need_f16_K, need_f16_V, false, warp_size); return; } } @@ -1159,9 +1346,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 32; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; + fattn_kernel_t fattn_kernel = flash_attn_tile; launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, need_f16_K, need_f16_V, false, warp_size); return; } } @@ -1170,9 +1357,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 16; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; + fattn_kernel_t fattn_kernel = flash_attn_tile; launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, need_f16_K, need_f16_V, false, warp_size); return; } @@ -1181,9 +1368,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 8; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; + fattn_kernel_t fattn_kernel = flash_attn_tile; launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, need_f16_K, need_f16_V, false, warp_size); return; } } @@ -1193,9 +1380,9 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 4; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; + fattn_kernel_t fattn_kernel = flash_attn_tile; launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, need_f16_K, need_f16_V, false, warp_size); return; } } @@ -1204,16 +1391,17 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr int cols_per_block = 2; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; + fattn_kernel_t fattn_kernel = flash_attn_tile; launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, need_f16_K, need_f16_V, false, warp_size); return; } GGML_ABORT("fatal error"); } -template +template static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -1234,44 +1422,44 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } if (use_gqa_opt && gqa_ratio % 4 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } } if constexpr (DKQ <= 512) { if (use_gqa_opt && gqa_ratio % 8 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } if (use_gqa_opt && gqa_ratio % 4 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } if constexpr (DV <= 256) { if (use_gqa_opt && gqa_ratio % 2 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } // DV > 256 (e.g. DKQ=DV=512, head_dim=512 models): extend GQA fallback to ncols2=2/1. // Without this, gqa_ratio not divisible by 4 (e.g. ratio=2) reaches GGML_ABORT. if (use_gqa_opt && gqa_ratio % 2 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } - launch_fattn_tile_switch_ncols1(ctx, dst); + launch_fattn_tile_switch_ncols1(ctx, dst); return; } GGML_ABORT("fatal error"); @@ -1280,10 +1468,29 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm template void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); +#ifdef GGML_USE_HIP + // Inline-dequant turbo3 TILE path. STRICT SCOPE: only DKQ==DV==256 with + // type_K==type_V==GGML_TYPE_TURBO3_0. Every other type/dim keeps f16 below. + if constexpr (DKQ == 256 && DV == 256) { + if (K->type == GGML_TYPE_TURBO3_0 && V->type == GGML_TYPE_TURBO3_0) { + if (logit_softcap == 0.0f) { + launch_fattn_tile_switch_ncols2(ctx, dst); + } else { + launch_fattn_tile_switch_ncols2(ctx, dst); + } + return; + } + } +#endif // GGML_USE_HIP + GGML_UNUSED(K); + GGML_UNUSED(V); + if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; launch_fattn_tile_switch_ncols2(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 38d222b7016b..d4b11115c2d0 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -492,6 +492,28 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // Force VEC path which does inline dequant with zero temp buffer overhead. // Trade-off: prefill is slower (sequential query processing). // Limitation: head_dim > 256 cannot use VEC (falls through to TILE). + // + // NOTE (MTP investigation, 2026-06): routing the speculative verify batch + // (Q->ne[1] >= 3) to TILE/MMA was tested and is STRICTLY WORSE for quantized + // KV — TILE/MMA hardcode need_f16_K/V=true and materialize the full KV cache + // to f16 every step (fattn-common.cuh ~1366), a per-step O(KV) dequant tax that + // outweighs the parallel-verify benefit. VEC's inline dequant is correct here. + // Self-spec MTP is capped at ~1.0x on gfx1201 regardless (no KV-load + // amortization across verify columns); see benchmarks/gemma4-mtp/results. + // + // UPDATE (inline-dequant TILE): all multi-row batches (D=256, turbo3/turbo3 KV, + // Q->ne[1] >= 3) route to the TILE kernel, which inline-dequantizes turbo3 K/V + // during the global->shared tile load (no full-KV->f16 materialization). This + // covers BOTH the speculative-verify batch AND prefill. Measured win: turbo3 + // PREFILL 1.4-1.9x faster than the sequential VEC path (pp2048 1038->1929 t/s), + // within ~6-12% of f16 prefill -> makes 3-bit turbo3 KV practical for long-context + // RAG. (MTP self-spec decode still <=1.0x: that wall is GEMM weight-load + // amortization on RDNA4, not attention.) Only Q->ne[1] <= 2 (decode) keeps VEC + // below, where inline dequant with zero temp buffer is optimal. + if (Q->ne[0] == 256 && K->type == GGML_TYPE_TURBO3_0 && V->type == GGML_TYPE_TURBO3_0 && + Q->ne[1] >= 3 && can_use_vector_kernel) { + return BEST_FATTN_KERNEL_TILE; + } if ((ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) && can_use_vector_kernel) { return BEST_FATTN_KERNEL_VEC; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b5b37a7fc0e4..ce94f18348c0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8820,6 +8820,16 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_set_rows_tq4_1s(GGML_TYPE_I32, 128, 256, 64)); + // turbo3 D=256 inline-dequant TILE path on HIP/gfx1201. nb=3/4/6 = speculative-verify + // widths; nb=64/128/256 = prefill-like batches (validate the prefill win is correct). + // Emitted FIRST so they run before any stock case that might abort the suite. + for (int nb : { 3, 4, 6, 64, 128, 256 }) { + for (int kv : { 512, 1024 }) { + test_cases.emplace_back(new test_flash_attn_ext( + 256, 256, 4, {4, 1}, kv, nb, true, false, 0.0f, 0.0f, GGML_PREC_F32, GGML_TYPE_TURBO3_0)); + } + } + for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 320, 512, 576 }) { for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) { if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue;