@@ -236,6 +236,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
236236 ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576 , 512 , 4 >(ctx, dst);
237237 }
238238 } break ;
239+ case 640 : {
240+ // Padded turbo KV cache for GLM-4.7 Flash (K head_dim=576 zero-padded to 640).
241+ // With D=640 the required shared memory (Q storage = ncols*(DKQ/2+4)*4 bytes) exceeds the hardware limit at ncols1>=4.
242+ // Cap at ncols1=2 (ncols=2*16=32): Q = 32*324*4 B = 41472 B (~41KB) + KV≈37KB = ~78KB total.
243+ GGML_ASSERT (V->ne [0 ] == 512 );
244+ if (Q->ne [1 ] <= 1 ) {
245+ ggml_cuda_flash_attn_ext_mma_f16_case<640 , 512 , 1 , 16 >(ctx, dst);
246+ } else {
247+ ggml_cuda_flash_attn_ext_mma_f16_case<640 , 512 , 2 , 16 >(ctx, dst);
248+ }
249+ } break ;
239250 default :
240251 GGML_ABORT (" fatal error" );
241252 break ;
@@ -325,6 +336,51 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
325336 FATTN_VEC_CASES_ALL_D (GGML_TYPE_BF16, GGML_TYPE_BF16)
326337#endif // GGML_CUDA_FA_ALL_QUANTS
327338
339+ // TurboQuant3 KV cache types (always enabled: listed after the GGML_CUDA_FA_ALL_QUANTS #endif, so not gated by that flag)
340+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0)
341+
342+ // Mixed turbo3/q8_0 KV cache types
343+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0)
344+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0)
345+
346+ // Mixed f16/turbo3 KV cache types
347+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_TURBO3_0)
348+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO3_0, GGML_TYPE_F16)
349+
350+ // TurboQuant2 KV cache types (always enabled)
351+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0)
352+
353+ // Mixed turbo2/q8_0 KV cache types
354+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0)
355+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0)
356+
357+ // Mixed f16/turbo2 KV cache types
358+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_TURBO2_0)
359+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO2_0, GGML_TYPE_F16)
360+
361+ // Mixed turbo3/turbo2 KV cache types
362+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0)
363+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0)
364+
365+ // TurboQuant4 KV cache types (always enabled)
366+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0)
367+
368+ // Mixed turbo4/q8_0 KV cache types
369+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0)
370+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0)
371+
372+ // Mixed f16/turbo4 KV cache types
373+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_F16, GGML_TYPE_TURBO4_0)
374+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO4_0, GGML_TYPE_F16)
375+
376+ // Mixed turbo4/turbo3 KV cache types
377+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0)
378+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0)
379+
380+ // Mixed turbo4/turbo2 KV cache types
381+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0)
382+ FATTN_VEC_CASES_ALL_D (GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0)
383+
328384 GGML_ABORT (" fatal error" );
329385}
330386
@@ -410,6 +466,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
410466 }
411467 break ;
412468 case 576 :
469+ case 640 :
413470 if (V->ne [0 ] != 512 ) {
414471 return BEST_FATTN_KERNEL_NONE;
415472 }
@@ -423,7 +480,16 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
423480
424481#ifndef GGML_CUDA_FA_ALL_QUANTS
425482 if (K->type != V->type ) {
426- return BEST_FATTN_KERNEL_NONE;
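483+ // Allow mixed KV types when K and V are each one of the types accepted below; intended for combinations that have FA template instances compiled in:
484+ // - turbo2/3/4 + q8_0 (turbo cache work)
485+ // - f16/bf16 + q8_0 (common K=f16, V=q8_0 setup). NOTE(review): K and V are tested independently, so pairs not listed here (e.g. bf16 + turbo, f16 + bf16) also pass - confirm each has a compiled instance or tighten the check.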
486+ auto is_kv_compat = [](ggml_type t) {
487+ return t == GGML_TYPE_TURBO2_0 || t == GGML_TYPE_TURBO3_0 || t == GGML_TYPE_TURBO4_0
488+ || t == GGML_TYPE_Q8_0 || t == GGML_TYPE_F16 || t == GGML_TYPE_BF16;
489+ };
490+ if (!is_kv_compat (K->type ) || !is_kv_compat (V->type )) {
491+ return BEST_FATTN_KERNEL_NONE;
492+ }
427493 }
428494#endif // GGML_CUDA_FA_ALL_QUANTS
429495
@@ -441,6 +507,24 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
441507 case GGML_TYPE_Q8_0:
442508 case GGML_TYPE_BF16:
443509 break ;
510+ case GGML_TYPE_TURBO3_0:
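511+ // turbo3 VEC kernel instantiated for D in {64, 128, 256}. NOTE(review): the check below only rejects head sizes that are not a multiple of 64, so other multiples (e.g. 576, 640) are not rejected here - confirm this is intended.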
512+ if (K->ne [0 ] % 64 != 0 ) {
513+ return BEST_FATTN_KERNEL_NONE;
514+ }
515+ break ;
516+ case GGML_TYPE_TURBO2_0:
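517+ // turbo2 VEC kernel instantiated for D in {64, 128, 256}. NOTE(review): the check below only rejects head sizes that are not a multiple of 64, so other multiples (e.g. 576, 640) are not rejected here - confirm this is intended.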
518+ if (K->ne [0 ] % 64 != 0 ) {
519+ return BEST_FATTN_KERNEL_NONE;
520+ }
521+ break ;
522+ case GGML_TYPE_TURBO4_0:
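523+ // turbo4 VEC kernel instantiated for D in {64, 128, 256}. NOTE(review): the check below only rejects head sizes that are not a multiple of 64, so other multiples (e.g. 576, 640) are not rejected here - confirm this is intended.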
524+ if (K->ne [0 ] % 64 != 0 ) {
525+ return BEST_FATTN_KERNEL_NONE;
526+ }
527+ break ;
444528 default :
445529 return BEST_FATTN_KERNEL_NONE;
446530 }
0 commit comments