Skip to content

Commit 073c7b7

Browse files
ogad-tetherclaude
andcommitted
tts-cpp: chatterbox-mtl — probe the align-cast op per-backend (close guard-removal gap)
An adversarial audit of PR #71 flagged that fully removing chatterbox_mtl_guard_kv_type deleted the blanket "force f32 on any non-CPU backend" net, so a quantized KV request now reaches ALL GPU backends for the MTL variant. The shared chatterbox_resolve_kv_type only probes flash_attn_ext — NOT the dequantizing ggml_cast(q8_0 strided -> f32) the alignment probe emits every decode step. A GPU backend with thin op coverage (e.g. some OpenCL/Adreno or Mali-Vulkan builds) can advertise q8 flash-attn yet be unable to encode that cast, and because the MTL path runs a single-backend graph_compute (no scheduler fallback) it would SIGABRT at compute — i.e. removing the guard could trade the Metal crash for a crash on another backend. Fix: chatterbox_mtl_resolve_kv_type wraps the shared resolve and additionally probes the strided q8->f32 cast via ggml_backend_supports_op, falling back to f32 only when the backend can't encode it. This is per-backend-correct: Metal (which supports the cast — verified) keeps q8 on the GPU, and any backend lacking the kernel safely degrades to f32 instead of crashing. Replaces the blunt "non-CPU -> f32" guard, which also blocked Metal (the original bug). Validated (stock ggml Metal, M2): q8 MTL on Metal still retains q8 (no fallback, no crash, byte-identical sample count). test_kv_cache_type extended for the new resolve (cpu retains q8 / null -> f32 / f32 stays f32). Refs QVAC-19557 Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 55ee8d0 commit 073c7b7

4 files changed

Lines changed: 65 additions & 6 deletions

File tree

tts-cpp/src/chatterbox_t3_internal.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ ggml_type chatterbox_kv_type_from_str(const std::string & s);
119119
ggml_type chatterbox_resolve_kv_type(ggml_backend_t backend, ggml_type requested,
120120
int head_dim, int n_head, int n_kv_head);
121121

122+
// MTL-variant resolve: chatterbox_resolve_kv_type plus a probe of the extra
123+
// quantized-cache op the multilingual decode graph emits — the alignment
124+
// probe's dequantizing cast of a strided q8 K-cache view to f32
125+
// (build_llama_block). Returns f32 when the backend can't encode that cast, so
126+
// q8 KV stays enabled on backends that support it (Metal) and safely degrades on
127+
// those that don't, without the single-backend MTL graph SIGABRT'ing at compute.
128+
ggml_type chatterbox_mtl_resolve_kv_type(ggml_backend_t backend, ggml_type requested,
129+
int head_dim, int n_head, int n_kv_head);
130+
122131
struct gpt2_layer {
123132
ggml_tensor * ln_1_g = nullptr;
124133
ggml_tensor * ln_1_b = nullptr;

tts-cpp/src/main.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,44 @@ ggml_type chatterbox_resolve_kv_type(ggml_backend_t backend, ggml_type requested
402402
return requested;
403403
}
404404

405+
ggml_type chatterbox_mtl_resolve_kv_type(ggml_backend_t backend, ggml_type requested,
406+
int head_dim, int n_head, int n_kv_head) {
407+
// Start from the shared resolve (flash_attn_ext probe + Vulkan coopmat2
408+
// force-f32). The MTL decode graph emits one MORE quantized-cache op the
409+
// shared probe doesn't cover: the per-(layer,head) alignment probe
410+
// dequantizes a STRIDED view of the quantized K cache via ggml_cast(...->f32)
411+
// (build_llama_block). ggml-metal supports that cast (which is why q8 KV now
412+
// runs on Metal), but a GPU backend with thinner op coverage
413+
// (e.g. some OpenCL/Adreno or Mali-Vulkan builds) can advertise q8 flash-attn
414+
// yet be unable to encode the strided q8->f32 cast — and the MTL path runs a
415+
// single-backend graph_compute with no scheduler fallback, so that would
416+
// SIGABRT at compute. Probe the cast op directly and fall back to f32 when
417+
// the backend can't encode it, instead of the old blanket "force f32 on any
418+
// non-CPU backend" guard (which also blocked Metal, the whole bug).
419+
ggml_type t = chatterbox_resolve_kv_type(backend, requested, head_dim, n_head, n_kv_head);
420+
if (!ggml_is_quantized(t) || !backend) return t;
421+
422+
bool cast_ok = false;
423+
ggml_init_params pp = { ggml_tensor_overhead() * 8, nullptr, /*no_alloc=*/true };
424+
if (ggml_context * pc = ggml_init(pp)) {
425+
// Mirror the align probe: a strided [head_dim, k] view of the token-major
426+
// q8 cache, cast to f32. Strides come from ggml_row_size so the view is
427+
// block-aligned exactly as build_llama_block builds it.
428+
const size_t tok_row = ggml_row_size(t, (size_t) head_dim * n_kv_head);
429+
ggml_tensor * cache = ggml_new_tensor_1d(pc, t, (int64_t) head_dim * n_kv_head * 8);
430+
ggml_tensor * view = ggml_view_2d(pc, cache, head_dim, 4, tok_row, 0);
431+
ggml_tensor * cast = ggml_cast(pc, view, GGML_TYPE_F32);
432+
cast_ok = (cast != nullptr) && ggml_backend_supports_op(backend, cast);
433+
ggml_free(pc);
434+
}
435+
if (!cast_ok) {
436+
fprintf(stderr, "chatterbox(mtl): backend cannot encode the quantized-KV alignment "
437+
"cast (%s strided -> f32); using f32 KV cache\n", ggml_type_name(t));
438+
return GGML_TYPE_F32;
439+
}
440+
return t;
441+
}
442+
405443
bool load_model_gguf(const std::string & path, chatterbox_model & model, int requested_ctx, int n_gpu_layers, ggml_type kv_type) {
406444
{
407445
gguf_init_params peek_params = { /*.no_alloc=*/ true, /*.ctx=*/ nullptr };

tts-cpp/src/t3_mtl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,18 +1837,18 @@ bool load_model_gguf_mtl(const std::string & path,
18371837
// kv_layer_elems * sizeof(float).
18381838
// Fall back to F32 KV if the resolved backend can't run flash
18391839
// attention with the requested quantized/f16 K/V.
1840-
hp.kv_type = chatterbox_resolve_kv_type(model.backend, kv_type,
1841-
hp.head_dim, hp.n_head, hp.n_kv_head);
18421840
// QVAC-19557: a quantized (q8_0) KV cache used to SIGABRT on Metal
18431841
// ("unsupported op 'CONT'"). The cause was NOT flash-attention (which
18441842
// reads the q8 strided cache fine on Metal) but the per-(layer,head)
18451843
// alignment probe in build_llama_block, which ggml_cont'd a strided view
18461844
// of the quantized K cache to feed a mul_mat — and ggml-metal has no CONT
18471845
// kernel for quantized tensors. That cont is now a dequantizing
1848-
// ggml_cast to f32 (Metal-supported), so quantized K/V runs on the GPU
1849-
// for the MTL variant and no f32 fallback guard is needed here. Vulkan
1850-
// quantized K/V is still force-f32'd inside chatterbox_resolve_kv_type
1851-
// (separate coopmat2 issue).
1846+
// ggml_cast to f32 (Metal-supported), so quantized K/V runs on the GPU.
1847+
// chatterbox_mtl_resolve_kv_type probes that cast per-backend and falls
1848+
// back to f32 on any GPU backend that can't encode it (Vulkan coopmat2 is
1849+
// separately force-f32'd inside the shared resolve).
1850+
hp.kv_type = chatterbox_mtl_resolve_kv_type(model.backend, kv_type,
1851+
hp.head_dim, hp.n_head, hp.n_kv_head);
18521852
ggml_init_params kv_params = { ggml_tensor_overhead() * 4, nullptr, true };
18531853
model.ctx_kv = ggml_init(kv_params);
18541854
const int64_t kv_elements_b2 =

tts-cpp/test/test_kv_cache_type.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ int main() {
6666
CHECK(chatterbox_resolve_kv_type(cpu, GGML_TYPE_Q8_0, head_dim, n_head, n_kv_head)
6767
== GGML_TYPE_Q8_0, "cpu retains q8_0 KV");
6868

69+
// ---- MTL resolve (QVAC-19557): also probes the align-probe cast(q8->f32) ----
70+
// The CPU backend supports the strided q8->f32 cast, so q8 is retained; a
71+
// backend lacking that cast kernel would be downgraded to f32 (the branch
72+
// that stops the single-backend MTL graph SIGABRT'ing at compute). f32
73+
// requests are unaffected.
74+
CHECK(chatterbox_mtl_resolve_kv_type(cpu, GGML_TYPE_F32, head_dim, n_head, n_kv_head)
75+
== GGML_TYPE_F32, "mtl resolve: f32 stays f32 on cpu");
76+
CHECK(chatterbox_mtl_resolve_kv_type(cpu, GGML_TYPE_Q8_0, head_dim, n_head, n_kv_head)
77+
== GGML_TYPE_Q8_0, "mtl resolve: cpu retains q8_0 (supports the cast)");
78+
CHECK(chatterbox_mtl_resolve_kv_type(nullptr, GGML_TYPE_Q8_0, head_dim, n_head, n_kv_head)
79+
== GGML_TYPE_F32, "mtl resolve: null backend -> f32");
80+
6981
ggml_backend_free(cpu);
7082

7183
if (g_failures) {

0 commit comments

Comments
 (0)