
Commit 30c3c23

hip: bypass memory pool for flash attention f16 temp buffers
The legacy memory pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently. For quantized KV flash attention, the f16 dequant temp buffers (K_f16, V_f16) stay allocated in the pool after use, consuming more VRAM than the KV compression saves. This causes quantized KV (q8_0, q4_0) to OOM before f16 at equivalent context lengths on HIP/ROCm, where VMM is unavailable.

Root cause: ggml_cuda_pool_leg::free() stores buffers in buffer_pool[] for reuse and never calls cudaFree. On CUDA with VMM, the OS can reclaim unused virtual memory; on HIP without VMM (all consumer RDNA 3/4 GPUs), the pool permanently consumes peak VRAM.

Fix: on HIP, allocate the f16 temp buffers with cudaMalloc and free them with cudaFree (via an RAII wrapper) instead of the pool. Memory is released after the FA kernel completes via cudaStreamSynchronize.

Trade-off: one cudaStreamSynchronize per FA call (~5% overhead at 32K context).

Impact: CUDA/Metal unaffected (#ifdef GGML_USE_HIP only).

Confirmed on: gfx1100 (RX 7900 XT), gfx1201 (RX 9070 XT)

Fixes: ggml-org#22107
1 parent 23b8cc4 commit 30c3c23

1 file changed

Lines changed: 29 additions & 0 deletions

File tree

ggml/src/ggml-cuda/fattn-common.cuh

@@ -946,8 +946,37 @@ void launch_fattn(
     const int cc  = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;

+#ifdef GGML_USE_HIP
+    // HIP/ROCm: bypass the memory pool for f16 temp buffers.
+    // The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently
+    // because free() stores buffers for reuse rather than releasing them.
+    // On HIP without VMM support (RDNA 3/4), this means the f16 dequant temp buffers
+    // for quantized KV stay allocated after use, consuming more VRAM than the KV
+    // compression saves — causing OOM before f16 at equivalent context lengths.
+    // Using raw cudaMalloc/cudaFree ensures memory is released after the kernel completes.
+    // Ref: https://github.com/ggml-org/llama.cpp/issues/22107
+    struct hip_f16_alloc {
+        half * ptr = nullptr;
+        cudaStream_t stream;
+        hip_f16_alloc(cudaStream_t s) : stream(s) {}
+        hip_f16_alloc(const hip_f16_alloc &) = delete;
+        hip_f16_alloc & operator=(const hip_f16_alloc &) = delete;
+        ~hip_f16_alloc() {
+            if (ptr) {
+                cudaStreamSynchronize(stream);
+                cudaFree(ptr);
+            }
+        }
+        void alloc(size_t nelements) {
+            CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half)));
+        }
+    };
+    hip_f16_alloc K_f16(main_stream);
+    hip_f16_alloc V_f16(main_stream);
+#else
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
+#endif
     ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
