Skip to content

Commit ec76e2a

Browse files
committed
Fixed upstream merge breaking TurboQuant
1 parent ea26227 commit ec76e2a

1 file changed

Lines changed: 125 additions & 0 deletions

File tree

fix.patch

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
2+
index f96739657..e6a919e1e 100644
3+
--- a/ggml/src/ggml-cuda/fattn.cu
4+
+++ b/ggml/src/ggml-cuda/fattn.cu
5+
@@ -236,6 +236,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
6+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
7+
}
8+
} break;
9+
+ case 640: {
10+
+ // Padded turbo KV cache for GLM-4.7 Flash (K head_dim=576 zero-padded to 640).
11+
+ // D=640 shared memory (Q storage = ncols*(DKQ/2+4)*4) exceeds hardware limit at ncols1>=4.
12+
+ // Cap at ncols1=2 (ncols=32): Q=32*324*4=41KB + KV≈37KB = ~78KB total.
13+
+ GGML_ASSERT(V->ne[0] == 512);
14+
+ if (Q->ne[1] <= 1) {
15+
+ ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 1, 16>(ctx, dst);
16+
+ } else {
17+
+ ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 2, 16>(ctx, dst);
18+
+ }
19+
+ } break;
20+
default:
21+
GGML_ABORT("fatal error");
22+
break;
23+
@@ -325,6 +336,51 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
24+
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
25+
#endif // GGML_CUDA_FA_ALL_QUANTS
26+
27+
+ // TurboQuant3 KV cache types (always enabled)
28+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0)
29+
+
30+
+ // Mixed turbo3/q8_0 KV cache types
31+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0)
32+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0)
33+
+
34+
+ // Mixed f16/turbo3 KV cache types
35+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO3_0)
36+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_F16)
37+
+
38+
+ // TurboQuant2 KV cache types (always enabled)
39+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0)
40+
+
41+
+ // Mixed turbo2/q8_0 KV cache types
42+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0)
43+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0)
44+
+
45+
+ // Mixed f16/turbo2 KV cache types
46+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO2_0)
47+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_F16)
48+
+
49+
+ // Mixed turbo3/turbo2 KV cache types
50+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0)
51+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0)
52+
+
53+
+ // TurboQuant4 KV cache types (always enabled)
54+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0)
55+
+
56+
+ // Mixed turbo4/q8_0 KV cache types
57+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0)
58+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0)
59+
+
60+
+ // Mixed f16/turbo4 KV cache types
61+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TURBO4_0)
62+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_F16)
63+
+
64+
+ // Mixed turbo4/turbo3 KV cache types
65+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0)
66+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0)
67+
+
68+
+ // Mixed turbo4/turbo2 KV cache types
69+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0)
70+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0)
71+
+
72+
GGML_ABORT("fatal error");
73+
}
74+
75+
@@ -410,6 +466,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
76+
}
77+
break;
78+
case 576:
79+
+ case 640:
80+
if (V->ne[0] != 512) {
81+
return BEST_FATTN_KERNEL_NONE;
82+
}
83+
@@ -423,7 +480,16 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
84+
85+
#ifndef GGML_CUDA_FA_ALL_QUANTS
86+
if (K->type != V->type) {
87+
- return BEST_FATTN_KERNEL_NONE;
88+
+ // Allow mixed KV types when K and V are each (checked per type, not per pair) a type used by the FA template instances compiled in, e.g.:
89+
+ // - turbo2/3/4 + q8_0 (turbo cache work)
90+
+ // - f16/bf16 + q8_0 (common K=f16, V=q8_0 setup)
91+
+ auto is_kv_compat = [](ggml_type t) {
92+
+ return t == GGML_TYPE_TURBO2_0 || t == GGML_TYPE_TURBO3_0 || t == GGML_TYPE_TURBO4_0
93+
+ || t == GGML_TYPE_Q8_0 || t == GGML_TYPE_F16 || t == GGML_TYPE_BF16;
94+
+ };
95+
+ if (!is_kv_compat(K->type) || !is_kv_compat(V->type)) {
96+
+ return BEST_FATTN_KERNEL_NONE;
97+
+ }
98+
}
99+
#endif // GGML_CUDA_FA_ALL_QUANTS
100+
101+
@@ -441,6 +507,24 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
102+
case GGML_TYPE_Q8_0:
103+
case GGML_TYPE_BF16:
104+
break;
105+
+ case GGML_TYPE_TURBO3_0:
106+
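+ // turbo3 VEC kernel is instantiated for D in {64, 128, 256}; this guard only rejects D not divisible by 64 (NOTE(review): confirm other multiples of 64 are filtered elsewhere).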
107+
+ if (K->ne[0] % 64 != 0) {
108+
+ return BEST_FATTN_KERNEL_NONE;
109+
+ }
110+
+ break;
111+
+ case GGML_TYPE_TURBO2_0:
112+
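+ // turbo2 VEC kernel is instantiated for D in {64, 128, 256}; this guard only rejects D not divisible by 64 (NOTE(review): confirm other multiples of 64 are filtered elsewhere).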
113+
+ if (K->ne[0] % 64 != 0) {
114+
+ return BEST_FATTN_KERNEL_NONE;
115+
+ }
116+
+ break;
117+
+ case GGML_TYPE_TURBO4_0:
118+
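+ // turbo4 VEC kernel is instantiated for D in {64, 128, 256}; this guard only rejects D not divisible by 64 (NOTE(review): confirm other multiples of 64 are filtered elsewhere).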
119+
+ if (K->ne[0] % 64 != 0) {
120+
+ return BEST_FATTN_KERNEL_NONE;
121+
+ }
122+
+ break;
123+
default:
124+
return BEST_FATTN_KERNEL_NONE;
125+
}

0 commit comments

Comments
 (0)