Skip to content

Commit 2c72928

Browse files
fix(stt): pass speaker count hints (#5241)
Infer participant speaker counts for live and batch STT requests, pass them through Hyprnote proxy URLs, and map them to supported provider diarization parameters.
1 parent 1a46ce0 commit 2c72928

12 files changed

Lines changed: 332 additions & 46 deletions

File tree

apps/desktop/src/stt/useRunBatch.test.ts

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { describe, expect, test } from "vitest";
22

3-
import { getBatchProvider } from "./useRunBatch";
3+
import { getBatchProvider, getSessionSpeakerCount } from "./useRunBatch";
44

55
describe("getBatchProvider", () => {
66
test("maps pyannote to the batch transcription provider", () => {
@@ -19,3 +19,40 @@ describe("getBatchProvider", () => {
1919
);
2020
});
2121
});
22+
23+
describe("getSessionSpeakerCount", () => {
24+
test("counts distinct session participants plus the current user", () => {
25+
const rows = new Map([
26+
["mapping-1", { session_id: "session-1", human_id: "human-a" }],
27+
["mapping-2", { session_id: "session-1", human_id: "human-a" }],
28+
["mapping-3", { session_id: "session-1", human_id: "human-b" }],
29+
["mapping-4", { session_id: "other-session", human_id: "human-c" }],
30+
]);
31+
const store = {
32+
forEachRow: (_table: string, callback: (rowId: string) => void) => {
33+
for (const rowId of rows.keys()) callback(rowId);
34+
},
35+
getCell: (_table: string, rowId: string, cellId: string) =>
36+
rows.get(rowId)?.[cellId as "session_id" | "human_id"],
37+
};
38+
39+
expect(getSessionSpeakerCount(store as any, "session-1", "self")).toBe(3);
40+
});
41+
42+
test("returns undefined until at least two speakers are known", () => {
43+
const rows = new Map([
44+
["mapping-1", { session_id: "session-1", human_id: "human-a" }],
45+
]);
46+
const store = {
47+
forEachRow: (_table: string, callback: (rowId: string) => void) => {
48+
for (const rowId of rows.keys()) callback(rowId);
49+
},
50+
getCell: (_table: string, rowId: string, cellId: string) =>
51+
rows.get(rowId)?.[cellId as "session_id" | "human_id"],
52+
};
53+
54+
expect(getSessionSpeakerCount(store as any, "session-1", null)).toBe(
55+
undefined,
56+
);
57+
});
58+
});

apps/desktop/src/stt/useRunBatch.ts

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ type RunOptions = {
3232
maxSpeakers?: number;
3333
};
3434

35+
type Store = NonNullable<ReturnType<typeof main.UI.useStore>>;
36+
3537
const DIRECT_BATCH_PROVIDERS: Set<TranscriptionParams["provider"]> = new Set([
3638
"deepgram",
3739
"soniox",
@@ -81,6 +83,38 @@ export function isStoppedTranscriptionError(error: unknown) {
8183
);
8284
}
8385

86+
export function getSessionSpeakerCount(
87+
store: Store,
88+
sessionId: string,
89+
selfHumanId?: string | null,
90+
): number | undefined {
91+
const humanIds = new Set<string>();
92+
93+
store.forEachRow("mapping_session_participant", (mappingId, _forEachCell) => {
94+
const sid = store.getCell(
95+
"mapping_session_participant",
96+
mappingId,
97+
"session_id",
98+
);
99+
if (sid !== sessionId) return;
100+
101+
const humanId = store.getCell(
102+
"mapping_session_participant",
103+
mappingId,
104+
"human_id",
105+
);
106+
if (typeof humanId === "string" && humanId) {
107+
humanIds.add(humanId);
108+
}
109+
});
110+
111+
if (typeof selfHumanId === "string" && selfHumanId) {
112+
humanIds.add(selfHumanId);
113+
}
114+
115+
return humanIds.size > 1 ? humanIds.size : undefined;
116+
}
117+
84118
export const useRunBatch = (sessionId: string) => {
85119
const store = main.UI.useStore(main.STORE_ID);
86120
const indexes = main.UI.useIndexes(main.STORE_ID);
@@ -114,6 +148,12 @@ export const useRunBatch = (sessionId: string) => {
114148
const createdAt = new Date().toISOString();
115149
const memoMd = store.getCell("sessions", sessionId, "raw_md");
116150
let transcriptId: string | null = null;
151+
const inferredNumSpeakers =
152+
options?.numSpeakers === undefined &&
153+
options?.minSpeakers === undefined &&
154+
options?.maxSpeakers === undefined
155+
? getSessionSpeakerCount(store, sessionId, user_id)
156+
: undefined;
117157

118158
const handlePersist: BatchPersistCallback | undefined =
119159
options?.handlePersist;
@@ -232,7 +272,7 @@ export const useRunBatch = (sessionId: string) => {
232272
languages:
233273
options?.languages ??
234274
getTranscriptionLanguages(aiLanguage, spokenLanguages),
235-
num_speakers: options?.numSpeakers,
275+
num_speakers: options?.numSpeakers ?? inferredNumSpeakers,
236276
min_speakers: options?.minSpeakers,
237277
max_speakers: options?.maxSpeakers,
238278
};

crates/listener-core/src/actors/listener/adapters.rs

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -410,32 +410,25 @@ fn i16_bytes_to_f32(bytes: &Bytes) -> Vec<f32> {
410410
}
411411

412412
fn build_listen_params(args: &ListenerArgs) -> owhisper_interface::ListenParams {
413-
let adapter_kind =
414-
AdapterKind::from_url_and_languages(&args.base_url, &args.languages, Some(&args.model));
415413
let redemption_time_ms = if args.onboarding { "60" } else { "400" };
416-
let mut custom_query = std::collections::HashMap::from([(
414+
let custom_query = std::collections::HashMap::from([(
417415
"redemption_time_ms".to_string(),
418416
redemption_time_ms.to_string(),
419417
)]);
420-
421-
if adapter_kind == AdapterKind::AssemblyAI
422-
&& let Some(expected_speakers) = assemblyai_expected_speakers(args)
423-
{
424-
custom_query.insert("speaker_labels".to_string(), "true".to_string());
425-
custom_query.insert("max_speakers".to_string(), expected_speakers.to_string());
426-
}
418+
let num_speakers = expected_speakers(args);
427419

428420
owhisper_interface::ListenParams {
429421
model: Some(args.model.clone()),
430422
languages: args.languages.clone(),
431423
sample_rate: super::super::SAMPLE_RATE,
432424
keywords: args.keywords.clone(),
425+
num_speakers,
433426
custom_query: Some(custom_query),
434427
..Default::default()
435428
}
436429
}
437430

438-
fn assemblyai_expected_speakers(args: &ListenerArgs) -> Option<u32> {
431+
fn expected_speakers(args: &ListenerArgs) -> Option<u32> {
439432
let mut participants = args.participant_human_ids.clone();
440433

441434
if let Some(self_human_id) = &args.self_human_id
@@ -655,31 +648,26 @@ mod tests {
655648
}
656649

657650
#[test]
658-
fn assemblyai_expected_speakers_counts_distinct_participants() {
651+
fn expected_speakers_counts_distinct_participants() {
659652
let mut args = listener_args("https://api.assemblyai.com", "u3-rt-pro");
660653
args.participant_human_ids = vec!["remote".to_string(), "self".to_string()];
661654
args.self_human_id = Some("self".to_string());
662655

663-
assert_eq!(assemblyai_expected_speakers(&args), Some(2));
656+
assert_eq!(expected_speakers(&args), Some(2));
664657
}
665658

666659
#[test]
667-
fn build_listen_params_adds_assemblyai_diarization_hints() {
660+
fn build_listen_params_sets_num_speakers_without_assemblyai_custom_query() {
668661
let mut args = listener_args("https://api.assemblyai.com", "u3-rt-pro");
669662
args.participant_human_ids = vec!["remote".to_string()];
670663
args.self_human_id = Some("self".to_string());
671664

672665
let params = build_listen_params(&args);
673666
let custom_query = params.custom_query.expect("custom query");
674667

675-
assert_eq!(
676-
custom_query.get("speaker_labels").map(String::as_str),
677-
Some("true")
678-
);
679-
assert_eq!(
680-
custom_query.get("max_speakers").map(String::as_str),
681-
Some("2")
682-
);
668+
assert_eq!(params.num_speakers, Some(2));
669+
assert!(!custom_query.contains_key("speaker_labels"));
670+
assert!(!custom_query.contains_key("max_speakers"));
683671
}
684672

685673
#[test]
@@ -691,6 +679,7 @@ mod tests {
691679
let params = build_listen_params(&args);
692680
let custom_query = params.custom_query.expect("custom query");
693681

682+
assert_eq!(params.num_speakers, Some(2));
694683
assert!(!custom_query.contains_key("speaker_labels"));
695684
assert!(!custom_query.contains_key("max_speakers"));
696685
}

crates/owhisper-client/src/adapter/assemblyai/live.rs

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,13 @@ impl RealtimeSttAdapter for AssemblyAIAdapter {
5757
query_pairs.append_pair("max_turn_silence", max_silence);
5858
}
5959

60-
if matches!(resolved_model, ResolvedLiveModel::U3RtPro)
61-
&& let Some(custom) = &params.custom_query
62-
{
63-
if custom
64-
.get("speaker_labels")
65-
.is_some_and(|value| value == "true")
66-
{
60+
if matches!(resolved_model, ResolvedLiveModel::U3RtPro) {
61+
if Self::streaming_speaker_labels_enabled(params) {
6762
query_pairs.append_pair("speaker_labels", "true");
6863
}
6964

70-
if let Some(max_speakers) = custom.get("max_speakers") {
71-
query_pairs.append_pair("max_speakers", max_speakers);
65+
if let Some(max_speakers) = Self::streaming_max_speakers(params) {
66+
query_pairs.append_pair("max_speakers", &max_speakers.to_string());
7267
}
7368
}
7469

@@ -232,6 +227,27 @@ impl AssemblyAIAdapter {
232227
}
233228
}
234229

230+
fn streaming_speaker_labels_enabled(params: &ListenParams) -> bool {
231+
params.num_speakers.is_some()
232+
|| params.min_speakers.is_some()
233+
|| params.max_speakers.is_some()
234+
|| params
235+
.custom_query
236+
.as_ref()
237+
.and_then(|custom| custom.get("speaker_labels"))
238+
.is_some_and(|value| value == "true")
239+
}
240+
241+
fn streaming_max_speakers(params: &ListenParams) -> Option<u32> {
242+
params.max_speakers.or(params.num_speakers).or_else(|| {
243+
params
244+
.custom_query
245+
.as_ref()
246+
.and_then(|custom| custom.get("max_speakers"))
247+
.and_then(|value| value.parse().ok())
248+
})
249+
}
250+
235251
fn parse_speaker_label(label: Option<&str>) -> Option<i32> {
236252
let label = label?.trim();
237253
if label.is_empty() || label.eq_ignore_ascii_case("unknown") {
@@ -339,8 +355,6 @@ impl ResolvedLiveModel {
339355

340356
#[cfg(test)]
341357
mod tests {
342-
use std::collections::HashMap;
343-
344358
use hypr_language::ISO639;
345359
use owhisper_interface::ListenParams;
346360
use owhisper_interface::stream::StreamResponse;
@@ -424,10 +438,7 @@ mod tests {
424438
API_BASE,
425439
&owhisper_interface::ListenParams {
426440
model: Some("u3-rt-pro".to_string()),
427-
custom_query: Some(HashMap::from([
428-
("speaker_labels".to_string(), "true".to_string()),
429-
("max_speakers".to_string(), "3".to_string()),
430-
])),
441+
num_speakers: Some(3),
431442
..Default::default()
432443
},
433444
1,
@@ -439,14 +450,28 @@ mod tests {
439450
}
440451

441452
#[test]
442-
fn test_whisper_fallback_omits_streaming_diarization_hints() {
453+
fn test_streaming_min_speakers_enables_diarization() {
454+
let url = AssemblyAIAdapter.build_ws_url(
455+
API_BASE,
456+
&owhisper_interface::ListenParams {
457+
model: Some("u3-rt-pro".to_string()),
458+
min_speakers: Some(2),
459+
..Default::default()
460+
},
461+
1,
462+
);
463+
464+
let query = url.query().expect("query string");
465+
assert!(query.contains("speaker_labels=true"));
466+
assert!(!query.contains("max_speakers"));
467+
}
468+
469+
#[test]
470+
fn test_streaming_diarization_hints_skip_whisper_fallback() {
443471
let url = AssemblyAIAdapter.build_ws_url(
444472
API_BASE,
445473
&owhisper_interface::ListenParams {
446-
custom_query: Some(HashMap::from([
447-
("speaker_labels".to_string(), "true".to_string()),
448-
("max_speakers".to_string(), "3".to_string()),
449-
])),
474+
num_speakers: Some(3),
450475
languages: vec![ISO639::Ko.into()],
451476
..Default::default()
452477
},
@@ -455,8 +480,8 @@ mod tests {
455480

456481
let query = url.query().expect("query string");
457482
assert!(query.contains("speech_model=whisper-rt"));
458-
assert!(!query.contains("speaker_labels=true"));
459-
assert!(!query.contains("max_speakers=3"));
483+
assert!(!query.contains("speaker_labels"));
484+
assert!(!query.contains("max_speakers"));
460485
}
461486

462487
#[test]

crates/owhisper-client/src/adapter/elevenlabs/batch.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ impl ElevenLabsAdapter {
8686
.text("diarize", "true")
8787
.text("timestamps_granularity", "word");
8888

89+
if let Some(num_speakers) = Self::num_speakers_hint(params) {
90+
form = form.text("num_speakers", num_speakers.to_string());
91+
}
92+
8993
if let Some(lang) = params.languages.first() {
9094
form = form.text("language_code", lang.iso639().code().to_string());
9195
}
@@ -116,6 +120,10 @@ impl ElevenLabsAdapter {
116120
Ok(Self::convert_to_batch_response(transcript))
117121
}
118122

123+
fn num_speakers_hint(params: &ListenParams) -> Option<u32> {
124+
params.num_speakers.or(params.max_speakers)
125+
}
126+
119127
fn convert_to_batch_response(response: TranscriptResponse) -> BatchResponse {
120128
let words: Vec<BatchWord> = response
121129
.words
@@ -164,6 +172,26 @@ mod tests {
164172
use super::*;
165173
use crate::http_client::create_client;
166174

175+
#[test]
176+
fn num_speakers_hint_prefers_exact_count_then_max() {
177+
let exact = ListenParams {
178+
num_speakers: Some(3),
179+
max_speakers: Some(5),
180+
..Default::default()
181+
};
182+
let ranged = ListenParams {
183+
max_speakers: Some(5),
184+
..Default::default()
185+
};
186+
187+
assert_eq!(ElevenLabsAdapter::num_speakers_hint(&exact), Some(3));
188+
assert_eq!(ElevenLabsAdapter::num_speakers_hint(&ranged), Some(5));
189+
assert_eq!(
190+
ElevenLabsAdapter::num_speakers_hint(&ListenParams::default()),
191+
None
192+
);
193+
}
194+
167195
#[test]
168196
fn speaker_labeled_words_use_mixed_capture_channel() {
169197
let response = TranscriptResponse {

0 commit comments

Comments
 (0)