Skip to content

Commit 6fa65bd

Browse files
committed
Fix q8 Turbo V CUDA FA decode routing
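Route Turbo-V decode cases with q8_0 K at D>=256 away from the unsafe vector (vec) FlashAttention path, matching the existing guard for classic non-q8 K types. The classic non-q8 helper itself is not broadened to include q8_0; instead, a separate Turbo-V unsafe-K helper covers q8_0 plus the classic non-q8 types. Add route-policy test coverage asserting that q8_0 is included only in the Turbo-V unsafe-K policy.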
1 parent 3975b51 commit 6fa65bd

2 files changed

Lines changed: 30 additions & 4 deletions

File tree

ggml/src/ggml-cuda/fattn.cu

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,14 +1471,19 @@ static inline bool ggml_cuda_fattn_prefill_mma_can_materialize_turbo_k_classic_v
14711471
ggml_cuda_fattn_is_classic_non_q8_type(V->type);
14721472
}
14731473

1474+
static inline bool ggml_cuda_fattn_is_turbo_v_decode_unsafe_k_type(const ggml_type type) {
1475+
return type == GGML_TYPE_Q8_0 ||
1476+
ggml_cuda_fattn_is_classic_non_q8_type(type);
1477+
}
1478+
14741479
// Shape guard for the effective K/V pair after Turbo V decode-dequant.
1475-
// Gemma-like D>=256 with classic_K/f16 (non-q8) is unsafe on the vec path.
1480+
// D>=256 with classic-or-q8 K/f16 V is unsafe on the vec path.
14761481
// Only applied when V was actually decoded from Turbo — explicit q5_0/f16
14771482
// at D>=256 is unaffected. D=128 is safe on vec and not gated.
14781483
static inline bool ggml_cuda_fattn_effective_vec_shape_unsafe(
14791484
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V) {
14801485
return Q->ne[0] >= 256 &&
1481-
ggml_cuda_fattn_is_classic_non_q8_type(K->type) &&
1486+
ggml_cuda_fattn_is_turbo_v_decode_unsafe_k_type(K->type) &&
14821487
V->type == GGML_TYPE_F16;
14831488
}
14841489

@@ -2502,8 +2507,8 @@ static ggml_cuda_fattn_route_plan ggml_cuda_fattn_make_route_plan(const int devi
25022507
}
25032508

25042509
// If V was decoded from Turbo to f16 and the effective pair is
2505-
// classic_K/f16 at D>=256, the vec path is unsafe. Only gate vec for
2506-
// Turbo-originated f16 V — explicit q5_0/f16 at D>=256 is unaffected.
2510+
// classic-or-q8 K/f16 at D>=256, the vec path is unsafe. Only gate vec
2511+
// for Turbo-originated f16 V — explicit q5_0/f16 at D>=256 is unaffected.
25072512
// Disable vec so the existing kernel selector picks MMA_F16 or tile with
25082513
// generic f16 K conversion. D=128 is fine on vec and is not affected.
25092514
plan.unsafe_vec_after_turbo_v_decode =

tests/test-cuda-fattn-route-policy.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ int main(int argc, char ** argv) {
5757
const std::string prefill_policy = slice_between(fattn,
5858
"static inline bool ggml_cuda_fattn_prefill_mma_can_materialize_turbo_k_classic_v",
5959
"// Shape guard for the effective K/V pair after Turbo V decode-dequant.");
60+
const std::string classic_non_q8 = slice_between(fattn,
61+
"static inline bool ggml_cuda_fattn_is_classic_non_q8_type",
62+
"static void ggml_cuda_fattn_materialize_to_f16");
63+
const std::string unsafe_k_helper = slice_between(fattn,
64+
"static inline bool ggml_cuda_fattn_is_turbo_v_decode_unsafe_k_type",
65+
"static inline bool ggml_cuda_fattn_effective_vec_shape_unsafe");
66+
const std::string unsafe_shape = slice_between(fattn,
67+
"static inline bool ggml_cuda_fattn_effective_vec_shape_unsafe",
68+
"static void ggml_cuda_flash_attn_ext_vec");
6069
const std::string exec = slice_between(fattn,
6170
"void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst)",
6271
"bool ggml_cuda_flash_attn_ext_support");
@@ -93,5 +102,17 @@ int main(int argc, char ** argv) {
93102
prefill_policy.find("ggml_cuda_fattn_is_classic_non_q8_type(V->type)") != std::string::npos,
94103
"Turbo K + classic V prefill eligibility must not broaden classic-K/Turbo-V routing");
95104

105+
ok &= expect(!classic_non_q8.empty() &&
106+
classic_non_q8.find("GGML_TYPE_Q8_0") == std::string::npos,
107+
"classic non-q8 helper must not be broadened to include q8_0");
108+
ok &= expect(!unsafe_k_helper.empty() &&
109+
unsafe_k_helper.find("GGML_TYPE_Q8_0") != std::string::npos &&
110+
unsafe_k_helper.find("ggml_cuda_fattn_is_classic_non_q8_type(type)") != std::string::npos,
111+
"Turbo V decode unsafe-K policy must cover q8_0 plus classic non-q8 K types");
112+
ok &= expect(!unsafe_shape.empty() &&
113+
unsafe_shape.find("ggml_cuda_fattn_is_turbo_v_decode_unsafe_k_type(K->type)") != std::string::npos &&
114+
unsafe_shape.find("V->type == GGML_TYPE_F16") != std::string::npos,
115+
"Turbo V decode shape guard must use the unsafe-K policy for effective f16 V");
116+
96117
return ok ? 0 : 1;
97118
}

0 commit comments

Comments
 (0)