Skip to content

Commit c91d49f

Browse files
ogad-tether and claude committed
tts-cpp: chatterbox-mtl — drop f32-fallback KV guard (superseded by align-cast)
The align-probe dequant fix makes a quantized KV cache run on the GPU for the MTL variant, so the f32-fallback guard added as the stopgap is no longer needed and would otherwise force f32 and negate the fix. Remove it:

- chatterbox_mtl_guard_kv_type (decl in chatterbox_t3_internal.h, def in main.cpp) and its call in load_model_gguf_mtl.
- its pass-through unit asserts in test_kv_cache_type.

Vulkan quantized K/V is still force-f32'd inside chatterbox_resolve_kv_type (separate coopmat2 issue) — untouched.

Repurpose the test_metal_ops sentinel: it now also asserts that CAST(q8_0 strided -> f32) IS supported on Metal — the op the align-probe fix relies on — so a future ggml regression that breaks the dequant cast fails the test loudly (CONT(q8_0) staying unsupported is now informational, not a hard fail).

Refs QVAC-19557

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 9f270dc commit c91d49f

5 files changed

Lines changed: 50 additions & 83 deletions

File tree

tts-cpp/src/chatterbox_t3_internal.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,6 @@ ggml_type chatterbox_kv_type_from_str(const std::string & s);
119119
ggml_type chatterbox_resolve_kv_type(ggml_backend_t backend, ggml_type requested,
120120
int head_dim, int n_head, int n_kv_head);
121121

122-
// MTL-variant-only guard (QVAC-19557): the multilingual variant's batched-CFG
123-
// (B=2) decode reads the token-major K/V cache as a 4D strided view, which the
124-
// GPU flash-attn path materialises through a CONT. ggml-metal has no CONT
125-
// kernel for quantized tensors, so a quantized KV cache SIGABRTs at encode time
126-
// on Metal (the MTL path runs a single-backend graph_compute, so the scheduler
127-
// never gets to fall the op back to CPU). This restricts a quantized `kv_type`
128-
// to the CPU backend and returns GGML_TYPE_F32 on any GPU backend; non-quantized
129-
// types and a null/CPU backend pass through unchanged. Pure (no I/O) so the
130-
// caller logs the downgrade and so it stays unit-testable. The Turbo variant
131-
// uses a different eval path that does not hit the CONT and must NOT be routed
132-
// through this guard.
133-
ggml_type chatterbox_mtl_guard_kv_type(ggml_backend_t backend, ggml_type kv_type);
134-
135122
struct gpt2_layer {
136123
ggml_tensor * ln_1_g = nullptr;
137124
ggml_tensor * ln_1_b = nullptr;

tts-cpp/src/main.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -402,20 +402,6 @@ ggml_type chatterbox_resolve_kv_type(ggml_backend_t backend, ggml_type requested
402402
return requested;
403403
}
404404

405-
ggml_type chatterbox_mtl_guard_kv_type(ggml_backend_t backend, ggml_type kv_type) {
406-
// Quantized K/V is only safe on CPU for the MTL variant: the GPU flash-attn
407-
// path CONTs the strided quantized K/V cache, and ggml-metal has no CONT
408-
// kernel for quantized tensors (the resolve probe above validates
409-
// flash_attn_ext but not the downstream CONT, so it can't catch this). Gate
410-
// on "not CPU" by device type rather than a backend name so it stays robust
411-
// across ggml builds whose Metal registry name differs ("Metal" vs "MTL").
412-
if (ggml_is_quantized(kv_type) && backend &&
413-
!::tts_cpp::detail::backend_is_cpu(backend)) {
414-
return GGML_TYPE_F32;
415-
}
416-
return kv_type;
417-
}
418-
419405
bool load_model_gguf(const std::string & path, chatterbox_model & model, int requested_ctx, int n_gpu_layers, ggml_type kv_type) {
420406
{
421407
gguf_init_params peek_params = { /*.no_alloc=*/ true, /*.ctx=*/ nullptr };

tts-cpp/src/t3_mtl.cpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,23 +1839,16 @@ bool load_model_gguf_mtl(const std::string & path,
18391839
// attention with the requested quantized/f16 K/V.
18401840
hp.kv_type = chatterbox_resolve_kv_type(model.backend, kv_type,
18411841
hp.head_dim, hp.n_head, hp.n_kv_head);
1842-
// QVAC-19557: the MTL variant's batched-CFG (B=2) decode CONTs the
1843-
// strided quantized K/V cache, which ggml-metal can't do (no quantized
1844-
// CONT kernel) — so a quantized KV cache SIGABRTs at eval_step_mtl
1845-
// ("unsupported op 'CONT'") on Metal. The resolve probe above only
1846-
// validates flash_attn_ext, not the downstream CONT, so the guard below
1847-
// restricts quantized K/V to the CPU backend. See
1848-
// chatterbox_mtl_guard_kv_type for the full rationale; it is pure so we
1849-
// log the downgrade here.
1850-
{
1851-
const ggml_type guarded = chatterbox_mtl_guard_kv_type(model.backend, hp.kv_type);
1852-
if (guarded != hp.kv_type) {
1853-
fprintf(stderr, "chatterbox(mtl): quantized (%s) KV cache is only supported on the "
1854-
"CPU backend for the multilingual variant (GPU CONT on quantized "
1855-
"K/V is unsupported); using f32 KV cache\n", ggml_type_name(hp.kv_type));
1856-
hp.kv_type = guarded;
1857-
}
1858-
}
1842+
// QVAC-19557: a quantized (q8_0) KV cache used to SIGABRT on Metal
1843+
// ("unsupported op 'CONT'"). The cause was NOT flash-attention (which
1844+
// reads the q8 strided cache fine on Metal) but the per-(layer,head)
1845+
// alignment probe in build_llama_block, which ggml_cont'd a strided view
1846+
// of the quantized K cache to feed a mul_mat — and ggml-metal has no CONT
1847+
// kernel for quantized tensors. That cont is now a dequantizing
1848+
// ggml_cast to f32 (Metal-supported), so quantized K/V runs on the GPU
1849+
// for the MTL variant and no f32 fallback guard is needed here. Vulkan
1850+
// quantized K/V is still force-f32'd inside chatterbox_resolve_kv_type
1851+
// (separate coopmat2 issue).
18591852
ggml_init_params kv_params = { ggml_tensor_overhead() * 4, nullptr, true };
18601853
model.ctx_kv = ggml_init(kv_params);
18611854
const int64_t kv_elements_b2 =

tts-cpp/test/test_kv_cache_type.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,6 @@ int main() {
6666
CHECK(chatterbox_resolve_kv_type(cpu, GGML_TYPE_Q8_0, head_dim, n_head, n_kv_head)
6767
== GGML_TYPE_Q8_0, "cpu retains q8_0 KV");
6868

69-
// ---- MTL guard (QVAC-19557): quantized K/V only on CPU ----
70-
// The multilingual variant's batched-CFG decode CONTs the strided quantized
71-
// K/V cache, which ggml-metal can't do; the guard restricts quantized K/V to
72-
// the CPU backend. Here we cover the pass-through branches that hold on any
73-
// runner; the GPU->f32 downgrade is covered (Metal) in test_metal_ops.cpp.
74-
CHECK(chatterbox_mtl_guard_kv_type(cpu, GGML_TYPE_Q8_0) == GGML_TYPE_Q8_0,
75-
"mtl guard: cpu keeps q8_0 (cpu has the quantized CONT kernel)");
76-
CHECK(chatterbox_mtl_guard_kv_type(cpu, GGML_TYPE_F16) == GGML_TYPE_F16,
77-
"mtl guard: cpu keeps f16");
78-
CHECK(chatterbox_mtl_guard_kv_type(cpu, GGML_TYPE_F32) == GGML_TYPE_F32,
79-
"mtl guard: cpu keeps f32");
80-
// Non-quantized types are never downgraded regardless of backend, and a null
81-
// backend is a no-op (null->f32 is chatterbox_resolve_kv_type's job upstream).
82-
CHECK(chatterbox_mtl_guard_kv_type(nullptr, GGML_TYPE_Q8_0) == GGML_TYPE_Q8_0,
83-
"mtl guard: null backend is a no-op");
84-
8569
ggml_backend_free(cpu);
8670

8771
if (g_failures) {

tts-cpp/test/test_metal_ops.cpp

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -335,47 +335,64 @@ static int test_mul_mm_fused(ggml_backend_t cpu, ggml_backend_t gpu,
335335
return 1;
336336
}
337337

338-
// QVAC-19557: regression sentinel for the MTL Metal q8-KV SIGABRT. The
339-
// multilingual Chatterbox variant's batched-CFG (B=2) decode reads the
340-
// token-major K/V cache as a strided 4D view, which the GPU flash-attn path
341-
// materialises through a CONT. ggml-metal has no CONT kernel for quantized
338+
// QVAC-19557: regression sentinel for the MTL Metal q8-KV SIGABRT. With a
339+
// quantized KV cache, the multilingual Chatterbox variant's per-(layer,head)
340+
// alignment probe (build_llama_block) read a strided view of the q8 K cache and
341+
// CONT'd it to feed a mul_mat. ggml-metal has no CONT kernel for quantized
342342
// tensors, so that op is unsupported on Metal — and because the MTL path runs a
343-
// single-backend graph_compute (no scheduler fallback) it crashes at encode
344-
// time. chatterbox_mtl_guard_kv_type exists precisely for this; here we assert
345-
// the underlying ggml limitation directly so this test TRIPS the day ggml grows
346-
// a quantized CONT kernel, at which point the guard can be relaxed and GPU q8 KV
347-
// revisited. The guard's fallback target (f32 CONT) and the CPU quantized CONT
348-
// must both stay supported.
343+
// single-backend graph_compute (no scheduler fallback) it crashed at encode
344+
// time. The fix replaced that ggml_cont with a dequantizing ggml_cast to f32.
345+
// This test pins the two ggml facts the fix depends on:
346+
// 1. CONT(q8_0 strided) is STILL unsupported on Metal — i.e. the plain cont we
347+
// removed really would crash (if this ever flips, the cast can become a
348+
// cheaper cont again).
349+
// 2. CAST(q8_0 strided -> f32) IS supported on Metal — the op the fix relies
350+
// on. If this ever regresses, the align probe would crash again, so the
351+
// test must fail loudly.
352+
// CPU must support both (the MTL variant also runs on CPU).
349353
static int test_quantized_cont_unsupported(ggml_backend_t cpu, ggml_backend_t gpu) {
350354
fprintf(stderr, "[quantized_cont] ");
351-
auto supports_cont = [](ggml_backend_t b, ggml_type t) {
355+
// Strided 4D view of a quantized src, mirroring the MTL token-major K/V read.
356+
auto make_view = [](ggml_context * ctx, ggml_type t) {
357+
ggml_tensor * src = ggml_new_tensor_4d(ctx, t, 64, 256, 16, 2);
358+
return ggml_view_4d(ctx, src, 64, 256, 16, 2,
359+
src->nb[1], src->nb[2] * 2, src->nb[3], 0);
360+
};
361+
auto supports_cont = [&](ggml_backend_t b, ggml_type t) {
352362
ggml_init_params p = { ggml_tensor_overhead() * 8, nullptr, /*no_alloc=*/true };
353363
ggml_context * ctx = ggml_init(p);
354-
// Strided 4D view of a quantized src -> cont, mirroring the MTL
355-
// batched-CFG (B=2) token-major K/V read in build_llama_block.
356-
ggml_tensor * src = ggml_new_tensor_4d(ctx, t, 64, 256, 16, 2);
357-
ggml_tensor * view = ggml_view_4d(ctx, src, 64, 256, 16, 2,
358-
src->nb[1], src->nb[2] * 2, src->nb[3], 0);
359-
bool sup = ggml_backend_supports_op(b, ggml_cont(ctx, view));
364+
bool sup = ggml_backend_supports_op(b, ggml_cont(ctx, make_view(ctx, t)));
365+
ggml_free(ctx);
366+
return sup;
367+
};
368+
auto supports_cast_f32 = [&](ggml_backend_t b, ggml_type t) {
369+
ggml_init_params p = { ggml_tensor_overhead() * 8, nullptr, /*no_alloc=*/true };
370+
ggml_context * ctx = ggml_init(p);
371+
bool sup = ggml_backend_supports_op(b, ggml_cast(ctx, make_view(ctx, t), GGML_TYPE_F32));
360372
ggml_free(ctx);
361373
return sup;
362374
};
363375
int fails = 0;
364376
if (supports_cont(gpu, GGML_TYPE_Q8_0)) {
365-
fprintf(stderr, "\n FAIL: Metal now advertises CONT(q8_0) — revisit the MTL KV guard "
366-
"(chatterbox_mtl_guard_kv_type); GPU q8 KV may be possible again\n");
377+
fprintf(stderr, "\n NOTE: Metal now advertises CONT(q8_0) — the align-probe cast "
378+
"could be simplified back to a cont (not a failure, but revisit)\n");
379+
// informational only; not a hard failure
380+
}
381+
if (!supports_cast_f32(gpu, GGML_TYPE_Q8_0)) {
382+
fprintf(stderr, "\n FAIL: Metal CAST(q8_0 strided -> f32) unsupported — the align-probe "
383+
"dequant fix (build_llama_block) would SIGABRT again\n");
367384
++fails;
368385
}
369-
if (!supports_cont(gpu, GGML_TYPE_F32)) {
370-
fprintf(stderr, "\n FAIL: Metal CONT(f32) unsupported — the MTL guard's f32 fallback target is broken\n");
386+
if (!supports_cast_f32(cpu, GGML_TYPE_Q8_0)) {
387+
fprintf(stderr, "\n FAIL: CPU CAST(q8_0 strided -> f32) unsupported — MTL on CPU would break\n");
371388
++fails;
372389
}
373390
if (!supports_cont(cpu, GGML_TYPE_Q8_0)) {
374-
fprintf(stderr, "\n FAIL: CPU CONT(q8_0) unsupported — MTL keeps q8 KV on CPU and would break\n");
391+
fprintf(stderr, "\n FAIL: CPU CONT(q8_0) unsupported (unexpected)\n");
375392
++fails;
376393
}
377394
if (!fails) {
378-
fprintf(stderr, "ok (Metal CONT(q8_0) unsupported, as the MTL KV guard assumes)\n");
395+
fprintf(stderr, "ok (Metal CAST(q8_0->f32) supported; the align-probe dequant fix holds)\n");
379396
return 0;
380397
}
381398
return 1;

0 commit comments

Comments
 (0)