
Commit 0b05974

hip: bypass memory pool for flash attention f16 temp buffers
The legacy memory pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently. For quantized KV flash attention, the f16 dequant temp buffers (K_f16, V_f16) stay allocated in the pool after use, consuming more VRAM than the KV compression saves. This causes quantized KV (q8_0, q4_0) to OOM before f16 at equivalent context lengths on HIP/ROCm, where VMM is unavailable.

Root cause: ggml_cuda_pool_leg::free() stores buffers in buffer_pool[] for reuse and never calls cudaFree. On CUDA with VMM, the OS can reclaim unused virtual memory. On HIP without VMM (all consumer RDNA 3/4 GPUs), the pool permanently consumes peak VRAM.

Fix: on HIP, allocate the f16 temp buffers with cudaMalloc and free them with cudaFree (via an RAII wrapper) instead of the pool. Memory is released after the FA kernel completes via cudaStreamSynchronize.

Trade-off: one cudaStreamSynchronize per FA call (~5% overhead at 32K context).

Impact: CUDA/Metal unaffected (#ifdef GGML_USE_HIP only).

Confirmed: gfx1100 (RX 7900 XT), gfx1201 (RX 9070 XT)

Fixes: ggml-org#22107
1 parent d3271ac commit 0b05974

1 file changed: ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 9 additions & 4 deletions
@@ -1299,14 +1299,19 @@ void launch_fattn(
 
 #ifdef GGML_USE_HIP
     // HIP/ROCm: bypass the memory pool for f16 temp buffers.
-    // The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently.
-    // For quantized KV dequant, this means the f16 temp buffer stays allocated,
-    // consuming more VRAM than the quantized KV compression saves — causing OOM.
-    // Using raw alloc+free ensures the memory is released after the kernel completes.
+    // The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently
+    // because free() stores buffers for reuse rather than releasing them.
+    // On HIP without VMM support (RDNA 3/4), this means the f16 dequant temp buffers
+    // for quantized KV stay allocated after use, consuming more VRAM than the KV
+    // compression saves — causing OOM before f16 at equivalent context lengths.
+    // Using raw cudaMalloc/cudaFree ensures memory is released after the kernel completes.
+    // Ref: https://github.com/ggml-org/llama.cpp/issues/22107
     struct hip_f16_alloc {
         half * ptr = nullptr;
         cudaStream_t stream;
         hip_f16_alloc(cudaStream_t s) : stream(s) {}
+        hip_f16_alloc(const hip_f16_alloc &) = delete;
+        hip_f16_alloc & operator=(const hip_f16_alloc &) = delete;
         ~hip_f16_alloc() {
             if (ptr) {
                 cudaStreamSynchronize(stream);
