Commit a596a5c

feat(cuda-fa): add HALF_QUANTS build mode for FA vec K/V pairs

Add GGML_CUDA_FA_HALF_QUANTS as a mutually-exclusive alternative to ALL_QUANTS. The three build modes now compile:

- DEFAULT: 27 curated pairs (same as v0.3.2 default, exact list)
- HALF_QUANTS: 91 pairs (K>=V triangle, 46% fewer instances than ALL)
- ALL_QUANTS: 169 pairs (full 13x13 K/V matrix)

Key changes:

- CMake: ggml_add_fattn_vec_pair macro per backend (CUDA/HIP/MUSA), mutual-exclusivity FATAL_ERROR, HALF triangular loop, DEFAULT exact list with file-existence checks instead of broad globs
- fattn.cu: rank helpers, ggml_cuda_fattn_canonical_kv_type (F32->F16 normalization), ggml_cuda_fattn_pair_compiled, tri-mode dispatch, runtime guard for uncompiled pairs, safer Turbo prefill routing
- generate_cu_files.py: unified 13-type TYPES_KV, empty extras
- 82 new template instance .cu files for the expanded pair matrix
- gen-fattn-vec-dispatch.py: dispatch table code generator
- Docs: accurate pair-count claims (91 vs 169 instances)

1 parent 9f1bcc9 commit a596a5c
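
The pair counts quoted in the commit message follow from simple arithmetic on a 13-type universe. A minimal sketch in Python that reproduces them; the function names are hypothetical and not part of the commit, and only the type count and the three totals come from the message itself:

```python
# Hypothetical helpers: only N_TYPES = 13 and the totals 169 / 91 / ~46%
# are taken from the commit message.
N_TYPES = 13


def all_pairs(n: int) -> int:
    """Full n x n K/V matrix (ALL_QUANTS)."""
    return n * n


def half_pairs(n: int) -> int:
    """Triangle including the diagonal (HALF_QUANTS): n * (n + 1) / 2."""
    return n * (n + 1) // 2


all_count = all_pairs(N_TYPES)
half_count = half_pairs(N_TYPES)
reduction = 1 - half_count / all_count  # fraction of instances saved vs ALL

print(f"ALL={all_count} HALF={half_count} reduction={reduction:.1%}")
```

The triangle keeps the diagonal (K and V of the same type), which is why the saving is a little under half rather than exactly 50%.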

96 files changed

Lines changed: 1396 additions & 158 deletions

AGENTS.md

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ cmake -B build -DGGML_METAL=ON -DCMAKE_BUILD_TYPE=Release
 cmake --build build -j
 ```
 
-`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. Add `-DCMAKE_CUDA_ARCHITECTURES=86` for RTX 3090, or `-DCMAKE_CUDA_ARCHITECTURES=89` for RTX 4090, if cross-compiling or building in CI without a GPU.
+`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. `GGML_CUDA_FA_HALF_QUANTS=ON` is an alternative that compiles only the useful K>=V half of the K/V pair matrix (compiling 91 FA vec K/V pairs instead of 169, reducing FA vec pair instances by ~46% vs ALL_QUANTS). These two flags are mutually exclusive. Add `-DCMAKE_CUDA_ARCHITECTURES=86` for RTX 3090, or `-DCMAKE_CUDA_ARCHITECTURES=89` for RTX 4090, if cross-compiling or building in CI without a GPU.
 
 Key binaries: `build/bin/llama-server`, `build/bin/llama-cli`, `build/bin/llama-bench`, `build/bin/llama-perplexity`.

CLAUDE.md

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ cmake -B build -DGGML_METAL=ON -DCMAKE_BUILD_TYPE=Release
 cmake --build build -j
 ```
 
-`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. Add `-DCMAKE_CUDA_ARCHITECTURES=86` for RTX 3090, or `-DCMAKE_CUDA_ARCHITECTURES=89` for RTX 4090, if cross-compiling or building in CI without a GPU.
+`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. `GGML_CUDA_FA_HALF_QUANTS=ON` is an alternative that compiles only the useful K>=V half of the K/V pair matrix (compiling 91 FA vec K/V pairs instead of 169, reducing FA vec pair instances by ~46% vs ALL_QUANTS). These two flags are mutually exclusive. Add `-DCMAKE_CUDA_ARCHITECTURES=86` for RTX 3090, or `-DCMAKE_CUDA_ARCHITECTURES=89` for RTX 4090, if cross-compiling or building in CI without a GPU.
 
 Key binaries: `build/bin/llama-server`, `build/bin/llama-cli`, `build/bin/llama-bench`, `build/bin/llama-perplexity`.

README.md

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ cmake -B build -DGGML_METAL=ON -DCMAKE_BUILD_TYPE=Release
 cmake --build build -j
 ```
 
-`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. Add `-DCMAKE_CUDA_ARCHITECTURES=86` for RTX 3090, or `-DCMAKE_CUDA_ARCHITECTURES=89` for RTX 4090, if cross-compiling or building in CI without a GPU.
+`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. `GGML_CUDA_FA_HALF_QUANTS=ON` is an alternative that compiles only the useful K>=V half of the K/V pair matrix (compiling 91 FA vec K/V pairs instead of 169, reducing FA vec pair instances by ~46% vs ALL_QUANTS). These two flags are mutually exclusive.
 
 ### Other Backends

docs/build.md

Lines changed: 2 additions & 1 deletion
@@ -297,7 +297,8 @@ The following compilation options are also available to tweak performance:
 | GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
 | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models. There may be issues with numerical overflows (except for V100, CDNA and RDNA4 which use FP32 compute type by default) and memory use will be higher. Prompt processing may become faster on recent datacenter GPUs (the custom kernels were tuned primarily for RTX 3000/4000). |
 | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
-| GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
+| GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile CUDA FlashAttention vec kernels for the full supported K/V cache type matrix (169 pairs for the 13-type universe). Mutually exclusive with GGML_CUDA_FA_HALF_QUANTS. |
+| GGML_CUDA_FA_HALF_QUANTS | Boolean | false | Compile only the useful K>=V half of the K/V cache type matrix for FlashAttention vec kernels (91 pairs), where types are ranked from higher precision to lower: f16 > bf16 > q8_0 > q6_0 > q5_1 > q5_0 > turbo4 > q4_1 > q4_0 > turbo3_tcq > turbo3 > turbo2_tcq > turbo2. This avoids wasteful reversed asymmetric pairs. Mutually exclusive with GGML_CUDA_FA_ALL_QUANTS. |
 
 ## MUSA
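
The "K>=V" rule in the GGML_CUDA_FA_HALF_QUANTS row can be read as a comparison of positions in the precision ranking. The following Python sketch illustrates that reading under stated assumptions: the type names are copied from the table row, while `canonical` and `in_half_triangle` are hypothetical names that only approximate what the commit message describes for `ggml_cuda_fattn_canonical_kv_type` and `ggml_cuda_fattn_pair_compiled` (whose actual C++ implementation is not shown in this view).

```python
# Precision order copied from the docs/build.md table row, highest first.
RANKED_TYPES = [
    "f16", "bf16", "q8_0", "q6_0", "q5_1", "q5_0", "turbo4",
    "q4_1", "q4_0", "turbo3_tcq", "turbo3", "turbo2_tcq", "turbo2",
]
RANK = {name: index for index, name in enumerate(RANKED_TYPES)}


def canonical(kv_type: str) -> str:
    # Assumption: models the F32->F16 normalization named in the commit message.
    return "f16" if kv_type == "f32" else kv_type


def in_half_triangle(k_type: str, v_type: str) -> bool:
    """True when K is at least as precise as V (lower index = higher precision)."""
    return RANK[canonical(k_type)] <= RANK[canonical(v_type)]


half = [(k, v) for k in RANKED_TYPES for v in RANKED_TYPES if in_half_triangle(k, v)]

print(len(half))
print(in_half_triangle("q8_0", "turbo3"), in_half_triangle("turbo3", "q8_0"))
```

Under this reading, a higher-precision K with a lower-precision V (for example q8_0 keys with turbo3 values) is kept, and the reversed combination is the one the table calls wasteful.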

docs/quickstart-gemma-4-31b-dflash.md

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON \
 cmake --build build -j
 ```
 
-`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. Add `-DCMAKE_CUDA_ARCHITECTURES=86` for RTX 3090, or `-DCMAKE_CUDA_ARCHITECTURES=89` for RTX 4090, if cross-compiling or building in CI without a GPU.
+`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. `GGML_CUDA_FA_HALF_QUANTS=ON` is an alternative that compiles only the useful K>=V half of the K/V pair matrix (compiling 91 FA vec K/V pairs instead of 169, reducing FA vec pair instances by ~46% vs ALL_QUANTS). These two flags are mutually exclusive.
 
 **macOS (Metal).**

docs/quickstart-qwen36-dflash.md

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ cmake -B build -DGGML_CUDA=ON -DGGML_NATIVE=ON \
 cmake --build build -j
 ```
 
-`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. Add `-DCMAKE_CUDA_ARCHITECTURES=86` for RTX 3090, or `-DCMAKE_CUDA_ARCHITECTURES=89` for RTX 4090, if cross-compiling or building in CI without a GPU.
+`GGML_CUDA_FA_ALL_QUANTS=ON` is required for TurboQuant and TCQ cache types. `GGML_CUDA_FA_HALF_QUANTS=ON` is an alternative that compiles only the useful K>=V half of the K/V pair matrix (compiling 91 FA vec K/V pairs instead of 169, reducing FA vec pair instances by ~46% vs ALL_QUANTS). These two flags are mutually exclusive.
 
 **macOS (Metal).**

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -206,6 +206,7 @@ option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copie
 option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
 option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
 option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
+option(GGML_CUDA_FA_HALF_QUANTS "ggml: compile only K>=V half of KV cache quant pairs for FlashAttention" OFF)
 option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
 option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON)
 set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 83 additions & 17 deletions
@@ -112,27 +112,93 @@ if (CUDAToolkit_FOUND)
     file(GLOB SRCS "template-instances/mmf*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
 
+    macro(ggml_add_fattn_vec_pair OUT_VAR PREFIX K_TYPE V_TYPE)
+        set(_FILE "${PREFIX}/fattn-vec-instance-${K_TYPE}-${V_TYPE}.cu")
+        if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${_FILE}")
+            message(FATAL_ERROR "Missing FA vec template instance: ${_FILE}")
+        endif()
+        list(APPEND ${OUT_VAR} "${_FILE}")
+    endmacro()
+
+    if (GGML_CUDA_FA_ALL_QUANTS AND GGML_CUDA_FA_HALF_QUANTS)
+        message(FATAL_ERROR
+            "GGML_CUDA_FA_ALL_QUANTS and GGML_CUDA_FA_HALF_QUANTS are mutually exclusive. "
+            "Use GGML_CUDA_FA_ALL_QUANTS for the full K/V pair matrix, or GGML_CUDA_FA_HALF_QUANTS for only K>=V pairs.")
+    endif()
+
+    set(FA_VEC_PREFIX "template-instances")
+
     if (GGML_CUDA_FA_ALL_QUANTS)
-        file(GLOB SRCS "template-instances/fattn-vec*.cu")
+        file(GLOB SRCS "${FA_VEC_PREFIX}/fattn-vec-instance-*.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
         add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+    elseif (GGML_CUDA_FA_HALF_QUANTS)
+        set(GGML_CUDA_FA_KV_TYPES_ORDERED
+            f16
+            bf16
+            q8_0
+            q6_0
+            q5_1
+            q5_0
+            turbo4_0
+            q4_1
+            q4_0
+            turbo3_tcq
+            turbo3_0
+            turbo2_tcq
+            turbo2_0
+        )
+
+        list(LENGTH GGML_CUDA_FA_KV_TYPES_ORDERED GGML_CUDA_FA_KV_TYPES_LEN)
+        math(EXPR GGML_CUDA_FA_KV_TYPES_LAST "${GGML_CUDA_FA_KV_TYPES_LEN} - 1")
+
+        foreach(KI RANGE 0 ${GGML_CUDA_FA_KV_TYPES_LAST})
+            list(GET GGML_CUDA_FA_KV_TYPES_ORDERED ${KI} K_TYPE)
+
+            foreach(VI RANGE ${KI} ${GGML_CUDA_FA_KV_TYPES_LAST})
+                list(GET GGML_CUDA_FA_KV_TYPES_ORDERED ${VI} V_TYPE)
+                ggml_add_fattn_vec_pair(GGML_SOURCES_CUDA "${FA_VEC_PREFIX}" "${K_TYPE}" "${V_TYPE}")
+            endforeach()
+        endforeach()
+
+        add_compile_definitions(GGML_CUDA_FA_HALF_QUANTS)
     else()
-        file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*bf16-bf16.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*turbo2_0*.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*turbo3_0*.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*turbo*_tcq*.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*turbo4_0*.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        set(GGML_CUDA_FA_DEFAULT_KV_PAIRS
+            f16:f16
+            q4_0:q4_0
+            q8_0:q8_0
+            bf16:bf16
+            turbo2_0:turbo2_0
+            turbo3_0:turbo3_0
+            turbo4_0:turbo4_0
+            turbo3_tcq:turbo3_tcq
+            turbo2_tcq:turbo2_tcq
+            turbo2_0:q8_0
+            turbo3_0:q8_0
+            turbo4_0:q8_0
+            q8_0:turbo2_0
+            q8_0:turbo3_0
+            q8_0:turbo4_0
+            turbo4_0:turbo3_0
+            turbo3_0:turbo4_0
+            turbo2_0:turbo3_0
+            turbo3_0:turbo2_0
+            turbo3_tcq:q8_0
+            turbo2_tcq:q8_0
+            q8_0:turbo3_tcq
+            q8_0:turbo2_tcq
+            turbo3_tcq:turbo2_tcq
+            turbo2_tcq:turbo3_tcq
+            turbo4_0:turbo3_tcq
+            turbo3_0:turbo3_tcq
+        )
+
+        foreach(PAIR ${GGML_CUDA_FA_DEFAULT_KV_PAIRS})
+            string(REPLACE ":" ";" PARTS "${PAIR}")
+            list(GET PARTS 0 K_TYPE)
+            list(GET PARTS 1 V_TYPE)
+            ggml_add_fattn_vec_pair(GGML_SOURCES_CUDA "${FA_VEC_PREFIX}" "${K_TYPE}" "${V_TYPE}")
+        endforeach()
     endif()
 
     ggml_add_backend_library(ggml-cuda
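
The three-way source selection in this hunk can be mirrored outside CMake to sanity-check the pair counts. The Python sketch below is an illustration only: `select_pairs` and `instance_file` are hypothetical names, the type order and DEFAULT pair list are copied from the diff, and the ALL_QUANTS branch assumes that the glob matches exactly one file per cell of the 13x13 matrix (the CMake code globs whatever files exist rather than enumerating pairs).

```python
# Order copied from GGML_CUDA_FA_KV_TYPES_ORDERED in the diff.
KV_TYPES_ORDERED = [
    "f16", "bf16", "q8_0", "q6_0", "q5_1", "q5_0", "turbo4_0",
    "q4_1", "q4_0", "turbo3_tcq", "turbo3_0", "turbo2_tcq", "turbo2_0",
]

# Entries copied from GGML_CUDA_FA_DEFAULT_KV_PAIRS in the diff ("K:V").
DEFAULT_KV_PAIRS = """
f16:f16 q4_0:q4_0 q8_0:q8_0 bf16:bf16
turbo2_0:turbo2_0 turbo3_0:turbo3_0 turbo4_0:turbo4_0
turbo3_tcq:turbo3_tcq turbo2_tcq:turbo2_tcq
turbo2_0:q8_0 turbo3_0:q8_0 turbo4_0:q8_0
q8_0:turbo2_0 q8_0:turbo3_0 q8_0:turbo4_0
turbo4_0:turbo3_0 turbo3_0:turbo4_0 turbo2_0:turbo3_0 turbo3_0:turbo2_0
turbo3_tcq:q8_0 turbo2_tcq:q8_0 q8_0:turbo3_tcq q8_0:turbo2_tcq
turbo3_tcq:turbo2_tcq turbo2_tcq:turbo3_tcq
turbo4_0:turbo3_tcq turbo3_0:turbo3_tcq
""".split()


def instance_file(k_type, v_type, prefix="template-instances"):
    """File name pattern used by the ggml_add_fattn_vec_pair macro."""
    return f"{prefix}/fattn-vec-instance-{k_type}-{v_type}.cu"


def select_pairs(all_quants=False, half_quants=False):
    """Hypothetical mirror of the if / elseif / else chain in the hunk."""
    if all_quants and half_quants:
        raise ValueError("ALL_QUANTS and HALF_QUANTS are mutually exclusive")
    types = KV_TYPES_ORDERED
    if all_quants:
        # Assumption: one instance file exists per cell of the full matrix.
        return [(k, v) for k in types for v in types]
    if half_quants:
        # Same triangular loop as the nested foreach: VI runs from KI to last.
        return [(k, v) for i, k in enumerate(types) for v in types[i:]]
    return [tuple(pair.split(":")) for pair in DEFAULT_KV_PAIRS]


counts = {
    "DEFAULT": len(select_pairs()),
    "HALF_QUANTS": len(select_pairs(half_quants=True)),
    "ALL_QUANTS": len(select_pairs(all_quants=True)),
}
default_outside_half = sorted(set(select_pairs()) - set(select_pairs(half_quants=True)))

print(counts)
print(instance_file("q8_0", "turbo3_0"))
print(len(default_outside_half), "DEFAULT pairs fall outside the HALF triangle")
```

One thing this sketch makes visible: with the ordering above, the DEFAULT list is not a subset of the HALF triangle, since entries such as `turbo3_0:q8_0` have K ranked below V. Whether that matters for a given cache configuration depends on the runtime dispatch in fattn.cu, which is not shown here.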
