Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1298,27 +1298,49 @@ void launch_fattn(
const int nsm = ggml_cuda_info().devices[id].nsm;

#ifdef GGML_USE_HIP
// HIP/ROCm: bypass the memory pool for f16 temp buffers.
// The legacy pool (ggml_cuda_pool_leg) retains peak-sized allocations permanently.
// For quantized KV dequant, this means the f16 temp buffer stays allocated,
// consuming more VRAM than the quantized KV compression saves — causing OOM.
// Using raw alloc+free ensures the memory is released after the kernel completes.
// HIP/ROCm: the f16 dequant temp buffers for quantized KV need two strategies
// depending on whether the current launch is captured into a HIP graph:
// * NOT capturing (large prefill, eager): raw cudaMalloc/cudaFree so the
// (potentially multi-GB) buffer is freed immediately. The no-VMM legacy
// pool would otherwise retain it at peak size -> negate KV compression.
// * Capturing (decode/speculative captured into a graph): raw cudaMalloc/
// cudaFree/cudaStreamSynchronize are forbidden ("operation not permitted
// when stream is capturing") -> use the pool instead. ggml re-runs warmup
// eagerly on tensor-size changes, so the pool buffer is sized before capture.
// Small batches (Q->ne[1] <= 8) always route through the pool (decode/spec
// cases that get captured). Patch: gemma4-turboquant-rdna4 #0001 part B.
cudaStreamCaptureStatus fa_capture_status = cudaStreamCaptureStatusNone;
CUDA_CHECK(cudaStreamIsCapturing(main_stream, &fa_capture_status));
const bool fa_use_pool = (fa_capture_status != cudaStreamCaptureStatusNone) || (Q->ne[1] <= 8);

struct hip_f16_alloc {
half * ptr = nullptr;
cudaStream_t stream;
hip_f16_alloc(cudaStream_t s) : stream(s) {}
half * ptr = nullptr;
ggml_cuda_pool * mem_pool = nullptr; // non-null => allocate from the pool (graph-safe)
size_t pool_size = 0;
cudaStream_t stream;
hip_f16_alloc(cudaStream_t s, ggml_cuda_pool * p) : mem_pool(p), stream(s) {}
~hip_f16_alloc() {
if (ptr) {
cudaStreamSynchronize(stream);
cudaFree(ptr);
if (!ptr) {
return;
}
if (mem_pool) {
// Pool free is plain bookkeeping (no CUDA calls) -> safe during capture.
mem_pool->free(ptr, pool_size);
} else {
(void) cudaStreamSynchronize(stream);
(void) cudaFree(ptr);
}
}
void alloc(size_t nelements) {
CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half)));
if (mem_pool) {
ptr = (half *) mem_pool->alloc(nelements * sizeof(half), &pool_size);
} else {
CUDA_CHECK(cudaMalloc(&ptr, nelements * sizeof(half)));
}
}
};
hip_f16_alloc K_f16(main_stream);
hip_f16_alloc V_f16(main_stream);
hip_f16_alloc K_f16(main_stream, fa_use_pool ? &pool : nullptr);
hip_f16_alloc V_f16(main_stream, fa_use_pool ? &pool : nullptr);
#else
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
Expand Down
Loading