Skip to content

Commit 6420f0c

Browse files
QVAC-18192 parakeet-cpp: resolve Sortformer head backend internally (force-CPU safety)
The Sortformer force-CPU path (the Mali-Vulkan miscompute workaround) allocated and computed on the caller-supplied backend, so a caller passing the active GPU backend (as test_sortformer_parity did) would defeat the workaround on Mali and drive the CPU-resident head weights through the GPU. Both production engine callers passed the correct backend, but the contract was a footgun. Resolve the head backend internally via model_sortformer_backend(model) (CPU on Mali-Vulkan, the active backend otherwise) and drop the caller-supplied backend parameter from sortformer_diarize_ggml and sortformer_aosc_step so the contract cannot be violated. Make model_sortformer_backend take the model by const reference, add a null-backend guard inside sortformer_diarize_ggml, and delete the now-orphaned caller locals.
1 parent 8bbf1f4 commit 6420f0c

6 files changed

Lines changed: 22 additions & 22 deletions

File tree

parakeet-cpp/src/parakeet_ctc.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,7 @@ ggml_backend_t model_active_backend(ParakeetCtcModel & m) {
11341134
return m.impl->backend_active;
11351135
}
11361136

1137-
ggml_backend_t model_sortformer_backend(ParakeetCtcModel & m) {
1137+
ggml_backend_t model_sortformer_backend(const ParakeetCtcModel & m) {
11381138
if (!m.impl) return nullptr;
11391139
return m.impl->sortformer_force_cpu ? m.impl->backend_cpu
11401140
: m.impl->backend_active;

parakeet-cpp/src/parakeet_ctc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ ggml_backend_t model_active_backend(ParakeetCtcModel & m);
342342

343343
// Backend for the Sortformer head: the active backend normally, but CPU on
344344
// Mali-Vulkan (its transformer block 0 miscomputes to NaN; encoder stays on GPU).
345-
ggml_backend_t model_sortformer_backend(ParakeetCtcModel & m);
345+
ggml_backend_t model_sortformer_backend(const ParakeetCtcModel & m);
346346

347347
// True when the head is routed to CPU (Mali-Vulkan); the graph then reads the
348348
// CPU-resident weight copies (model.sortformer_cpu), not the GPU originals.

parakeet-cpp/src/parakeet_engine.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -593,15 +593,10 @@ static DiarizationResult engine_impl_diarize_helper(Engine::Impl & impl,
593593
sopts.threshold = opts.threshold;
594594
SortformerDiarizationResult dres;
595595

596-
ggml_backend_t head_backend = model_sortformer_backend(impl.model);
597-
if (!head_backend) {
598-
throw std::runtime_error("diarize: no ggml backend for the diarization head");
599-
}
600-
601596
int diarize_rc = sortformer_diarize_ggml(impl.model,
602597
enc_out.encoder_out.data(),
603598
enc_out.n_enc_frames, enc_out.d_model,
604-
head_backend, sopts, dres);
599+
sopts, dres);
605600
if (diarize_rc != 0) {
606601
throw std::runtime_error("diarize: sortformer_diarize failed (rc=" +
607602
std::to_string(diarize_rc) + ")");
@@ -726,16 +721,11 @@ static DiarizationResult engine_impl_diarize_streaming_helper(
726721
s_opts.threshold = opts.threshold;
727722
SortformerDiarizationResult dres;
728723

729-
ggml_backend_t head_backend = model_sortformer_backend(impl.model);
730-
if (!head_backend) {
731-
throw std::runtime_error("diarize_streaming: no ggml backend for the diarization head");
732-
}
733-
734724
if (int rc_ = sortformer_aosc_step(impl.model,
735725
pre_encode.data(),
736726
n_pre_encode_frames, D,
737727
lc, rc, chunk_len_eff,
738-
cache, cfg, head_backend, s_opts, dres);
728+
cache, cfg, s_opts, dres);
739729
rc_ != 0) {
740730
throw std::runtime_error("diarize_streaming: sortformer_aosc_step failed (rc=" +
741731
std::to_string(rc_) + ")");

parakeet-cpp/src/parakeet_sortformer.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ ggml_tensor * sf_build_graph(ggml_context * ctx,
158158

159159
// Allocate, upload input, compute, and download output for a Sortformer graph.
160160
// Returns 0 on success, negative on failure. Caller must free ctx afterwards.
161+
// `backend` must be the head backend (model_sortformer_backend): it is used only
162+
// by the force_cpu branch, which requires the CPU backend; the normal path runs
163+
// through `sched`. sortformer_diarize_ggml resolves it so callers cannot mismatch.
161164
int sf_exec_graph(ggml_context * ctx, ggml_backend_t backend,
162165
ggml_backend_sched_t sched, bool force_cpu,
163166
ggml_tensor * x_in, ggml_tensor * x_out,
@@ -630,7 +633,6 @@ void sortformer_cache_reset(SortformerSpeakerCache & cache, int D) {
630633
int sortformer_diarize_ggml(const ParakeetCtcModel & model,
631634
const float * encoder_out,
632635
int T_enc, int D_enc,
633-
ggml_backend_t backend,
634636
const SortformerDiarizationOptions & opts,
635637
SortformerDiarizationResult & out) {
636638
const auto & enc = model.encoder_cfg;
@@ -657,6 +659,16 @@ int sortformer_diarize_ggml(const ParakeetCtcModel & model,
657659
const int head_dim = tf_d / n_heads;
658660
const auto t0 = std::chrono::steady_clock::now();
659661

662+
// Resolve the head backend from the model itself (CPU on Mali-Vulkan, the
663+
// active backend otherwise). Resolving here -- rather than trusting a caller-
664+
// supplied argument -- makes it impossible to drive the force-CPU workaround's
665+
// CPU-resident weights through the GPU.
666+
ggml_backend_t backend = model_sortformer_backend(model);
667+
if (!backend) {
668+
std::fprintf(stderr, "sortformer_diarize_ggml: no ggml backend for the diarization head\n");
669+
return 1;
670+
}
671+
660672
// 1. Context for graph construction (no-alloc)
661673
const size_t graph_slots = 4096;
662674
const size_t overhead = ggml_tensor_overhead() * graph_slots
@@ -708,7 +720,6 @@ int sortformer_aosc_step(ParakeetCtcModel & model,
708720
int lc, int rc, int chunk_len_eff,
709721
SortformerSpeakerCache & cache,
710722
const SortformerStreamingConfig & cfg,
711-
ggml_backend_t backend,
712723
const SortformerDiarizationOptions & opts,
713724
SortformerDiarizationResult & out) {
714725
const auto & enc = model.encoder_cfg;
@@ -782,7 +793,7 @@ int sortformer_aosc_step(ParakeetCtcModel & model,
782793
// 3. Run the diariser over the full cat.
783794
SortformerDiarizationResult diar_cat;
784795
if (int rc_ = sortformer_diarize_ggml(model, enc_cat.encoder_out.data(),
785-
T_cat, D, backend, opts, diar_cat); rc_ != 0) {
796+
T_cat, D, opts, diar_cat); rc_ != 0) {
786797
return rc_;
787798
}
788799
if (diar_cat.num_spks != num_spks) {

parakeet-cpp/src/parakeet_sortformer.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,12 @@ struct SortformerSpeakerCache {
110110
// Reset to a fresh empty state. Allocates mean_sil_emb to D zeros.
111111
void sortformer_cache_reset(SortformerSpeakerCache & cache, int D);
112112

113+
// The diarization head backend is resolved internally via model_sortformer_backend
114+
// (CPU on Mali-Vulkan, the active backend otherwise) so callers cannot accidentally
115+
// drive the CPU-resident force-CPU path through the GPU.
113116
int sortformer_diarize_ggml(const ParakeetCtcModel & model,
114117
const float * encoder_out,
115118
int T_enc, int D_enc,
116-
ggml_backend_t backend,
117119
const SortformerDiarizationOptions & opts,
118120
SortformerDiarizationResult & out);
119121

@@ -142,7 +144,6 @@ int sortformer_aosc_step(ParakeetCtcModel & model,
142144
int lc, int rc, int chunk_len_eff,
143145
SortformerSpeakerCache & cache,
144146
const SortformerStreamingConfig & cfg,
145-
ggml_backend_t backend,
146147
const SortformerDiarizationOptions & opts,
147148
SortformerDiarizationResult & out);
148149

parakeet-cpp/test/test_sortformer_parity.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,12 @@ int main(int argc, char ** argv) {
169169
std::fprintf(stderr, "[sf-parity] enc : max_abs=%.4e rel=%.4e (%s, rel tol=%.1e)\n",
170170
max_abs_enc, rel_enc, enc_pass ? "PASS" : "FAIL", enc_rel_tol);
171171

172-
ggml_backend_t backend = model_active_backend(model);
173-
if (!backend) { std::fprintf(stderr, " error: no active ggml backend\n"); return 8; }
174172
SortformerDiarizationOptions dopts;
175173
SortformerDiarizationResult dres;
176174
if (sortformer_diarize_ggml(model,
177175
enc_out.encoder_out.data(),
178176
enc_out.n_enc_frames, enc_out.d_model,
179-
backend, dopts, dres) != 0) return 9;
177+
dopts, dres) != 0) return 9;
180178

181179
int worst = 0;
182180

0 commit comments

Comments
 (0)