|
| 1 | +diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu |
| 2 | +index f96739657..e6a919e1e 100644 |
| 3 | +--- a/ggml/src/ggml-cuda/fattn.cu |
| 4 | ++++ b/ggml/src/ggml-cuda/fattn.cu |
| 5 | +@@ -236,6 +236,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg |
| 6 | + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); |
| 7 | + } |
| 8 | + } break; |
| 9 | ++ case 640: { |
| 10 | ++ // Padded turbo KV cache for GLM-4.7 Flash (K head_dim=576 zero-padded to 640). |
| 11 | ++ // D=640 shared memory (Q storage = ncols*(DKQ/2+4)*4) exceeds hardware limit at ncols1>=4. |
| 12 | ++ // Cap at ncols1=2 (ncols=32): Q=32*324*4=41KB + KV≈37KB = ~78KB total. |
| 13 | ++ GGML_ASSERT(V->ne[0] == 512); |
| 14 | ++ if (Q->ne[1] <= 1) { |
| 15 | ++ ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 1, 16>(ctx, dst); |
| 16 | ++ } else { |
| 17 | ++ ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 2, 16>(ctx, dst); |
| 18 | ++ } |
| 19 | ++ } break; |
| 20 | + default: |
| 21 | + GGML_ABORT("fatal error"); |
| 22 | + break; |
| 23 | +@@ -325,6 +336,51 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t |
| 24 | + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) |
| 25 | + #endif // GGML_CUDA_FA_ALL_QUANTS |
| 26 | + |
| 27 | ++ // TurboQuant3 KV cache types (always enabled) |
| 28 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0) |
| 29 | ++ |
| 30 | ++ // Mixed turbo3/q8_0 KV cache types |
| 31 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0) |
| 32 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0) |
| 33 | ++ |
| 34 | ++ // Mixed f16/turbo3 KV cache types |
| 35 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO3_0) |
| 36 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_F16) |
| 37 | ++ |
| 38 | ++ // TurboQuant2 KV cache types (always enabled) |
| 39 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0) |
| 40 | ++ |
| 41 | ++ // Mixed turbo2/q8_0 KV cache types |
| 42 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0) |
| 43 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0) |
| 44 | ++ |
| 45 | ++ // Mixed f16/turbo2 KV cache types |
| 46 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO2_0) |
| 47 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_F16) |
| 48 | ++ |
| 49 | ++ // Mixed turbo3/turbo2 KV cache types |
| 50 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0) |
| 51 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0) |
| 52 | ++ |
| 53 | ++ // TurboQuant4 KV cache types (always enabled) |
| 54 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0) |
| 55 | ++ |
| 56 | ++ // Mixed turbo4/q8_0 KV cache types |
| 57 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0) |
| 58 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0) |
| 59 | ++ |
| 60 | ++ // Mixed f16/turbo4 KV cache types |
| 61 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO4_0) |
| 62 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_F16) |
| 63 | ++ |
| 64 | ++ // Mixed turbo4/turbo3 KV cache types |
| 65 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0) |
| 66 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0) |
| 67 | ++ |
| 68 | ++ // Mixed turbo4/turbo2 KV cache types |
| 69 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0) |
| 70 | ++ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0) |
| 71 | ++ |
| 72 | + GGML_ABORT("fatal error"); |
| 73 | + } |
| 74 | + |
| 75 | +@@ -410,6 +466,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const |
| 76 | + } |
| 77 | + break; |
| 78 | + case 576: |
| 79 | ++ case 640: |
| 80 | + if (V->ne[0] != 512) { |
| 81 | + return BEST_FATTN_KERNEL_NONE; |
| 82 | + } |
| 83 | +@@ -423,7 +480,16 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const |
| 84 | + |
| 85 | + #ifndef GGML_CUDA_FA_ALL_QUANTS |
| 86 | + if (K->type != V->type) { |
| 87 | +- return BEST_FATTN_KERNEL_NONE; |
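| 88 | + // Allow mixed K/V types only when BOTH types are in the set below; this accepts every pairing |
| 89 | + // of turbo2/3/4, q8_0, f16 and bf16 (e.g. turbo + q8_0, turbo + f16, or K=f16 with V=q8_0). |
| 90 | + // NOTE(review): each accepted pair needs a compiled FA template instance - confirm for bf16 mixes. |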
| 91 | ++ auto is_kv_compat = [](ggml_type t) { |
| 92 | ++ return t == GGML_TYPE_TURBO2_0 || t == GGML_TYPE_TURBO3_0 || t == GGML_TYPE_TURBO4_0 |
| 93 | ++ || t == GGML_TYPE_Q8_0 || t == GGML_TYPE_F16 || t == GGML_TYPE_BF16; |
| 94 | ++ }; |
| 95 | ++ if (!is_kv_compat(K->type) || !is_kv_compat(V->type)) { |
| 96 | ++ return BEST_FATTN_KERNEL_NONE; |
| 97 | ++ } |
| 98 | + } |
| 99 | + #endif // GGML_CUDA_FA_ALL_QUANTS |
| 100 | + |
| 101 | +@@ -441,6 +507,24 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const |
| 102 | + case GGML_TYPE_Q8_0: |
| 103 | + case GGML_TYPE_BF16: |
| 104 | + break; |
| 105 | ++ case GGML_TYPE_TURBO3_0: |
| 106 | + // turbo3 VEC kernel is instantiated for D in {64, 128, 256}; this guard only rejects K head sizes that are not a multiple of 64. |
| 107 | ++ if (K->ne[0] % 64 != 0) { |
| 108 | ++ return BEST_FATTN_KERNEL_NONE; |
| 109 | ++ } |
| 110 | ++ break; |
| 111 | ++ case GGML_TYPE_TURBO2_0: |
| 112 | + // turbo2 VEC kernel is instantiated for D in {64, 128, 256}; this guard only rejects K head sizes that are not a multiple of 64. |
| 113 | ++ if (K->ne[0] % 64 != 0) { |
| 114 | ++ return BEST_FATTN_KERNEL_NONE; |
| 115 | ++ } |
| 116 | ++ break; |
| 117 | ++ case GGML_TYPE_TURBO4_0: |
| 118 | + // turbo4 VEC kernel is instantiated for D in {64, 128, 256}; this guard only rejects K head sizes that are not a multiple of 64. |
| 119 | ++ if (K->ne[0] % 64 != 0) { |
| 120 | ++ return BEST_FATTN_KERNEL_NONE; |
| 121 | ++ } |
| 122 | ++ break; |
| 123 | + default: |
| 124 | + return BEST_FATTN_KERNEL_NONE; |
| 125 | + } |
0 commit comments