@@ -57,18 +57,13 @@ impl RealtimeSttAdapter for AssemblyAIAdapter {
5757 query_pairs. append_pair ( "max_turn_silence" , max_silence) ;
5858 }
5959
60- if matches ! ( resolved_model, ResolvedLiveModel :: U3RtPro )
61- && let Some ( custom) = & params. custom_query
62- {
63- if custom
64- . get ( "speaker_labels" )
65- . is_some_and ( |value| value == "true" )
66- {
60+ if matches ! ( resolved_model, ResolvedLiveModel :: U3RtPro ) {
61+ if Self :: streaming_speaker_labels_enabled ( params) {
6762 query_pairs. append_pair ( "speaker_labels" , "true" ) ;
6863 }
6964
70- if let Some ( max_speakers) = custom . get ( "max_speakers" ) {
71- query_pairs. append_pair ( "max_speakers" , max_speakers) ;
65+ if let Some ( max_speakers) = Self :: streaming_max_speakers ( params ) {
66+ query_pairs. append_pair ( "max_speakers" , & max_speakers. to_string ( ) ) ;
7267 }
7368 }
7469
@@ -232,6 +227,27 @@ impl AssemblyAIAdapter {
232227 }
233228 }
234229
230+ fn streaming_speaker_labels_enabled ( params : & ListenParams ) -> bool {
231+ params. num_speakers . is_some ( )
232+ || params. min_speakers . is_some ( )
233+ || params. max_speakers . is_some ( )
234+ || params
235+ . custom_query
236+ . as_ref ( )
237+ . and_then ( |custom| custom. get ( "speaker_labels" ) )
238+ . is_some_and ( |value| value == "true" )
239+ }
240+
241+ fn streaming_max_speakers ( params : & ListenParams ) -> Option < u32 > {
242+ params. max_speakers . or ( params. num_speakers ) . or_else ( || {
243+ params
244+ . custom_query
245+ . as_ref ( )
246+ . and_then ( |custom| custom. get ( "max_speakers" ) )
247+ . and_then ( |value| value. parse ( ) . ok ( ) )
248+ } )
249+ }
250+
235251 fn parse_speaker_label ( label : Option < & str > ) -> Option < i32 > {
236252 let label = label?. trim ( ) ;
237253 if label. is_empty ( ) || label. eq_ignore_ascii_case ( "unknown" ) {
@@ -339,8 +355,6 @@ impl ResolvedLiveModel {
339355
340356#[ cfg( test) ]
341357mod tests {
342- use std:: collections:: HashMap ;
343-
344358 use hypr_language:: ISO639 ;
345359 use owhisper_interface:: ListenParams ;
346360 use owhisper_interface:: stream:: StreamResponse ;
@@ -424,10 +438,7 @@ mod tests {
424438 API_BASE ,
425439 & owhisper_interface:: ListenParams {
426440 model : Some ( "u3-rt-pro" . to_string ( ) ) ,
427- custom_query : Some ( HashMap :: from ( [
428- ( "speaker_labels" . to_string ( ) , "true" . to_string ( ) ) ,
429- ( "max_speakers" . to_string ( ) , "3" . to_string ( ) ) ,
430- ] ) ) ,
441+ num_speakers : Some ( 3 ) ,
431442 ..Default :: default ( )
432443 } ,
433444 1 ,
@@ -439,14 +450,28 @@ mod tests {
439450 }
440451
441452 #[ test]
442- fn test_whisper_fallback_omits_streaming_diarization_hints ( ) {
453+ fn test_streaming_min_speakers_enables_diarization ( ) {
454+ let url = AssemblyAIAdapter . build_ws_url (
455+ API_BASE ,
456+ & owhisper_interface:: ListenParams {
457+ model : Some ( "u3-rt-pro" . to_string ( ) ) ,
458+ min_speakers : Some ( 2 ) ,
459+ ..Default :: default ( )
460+ } ,
461+ 1 ,
462+ ) ;
463+
464+ let query = url. query ( ) . expect ( "query string" ) ;
465+ assert ! ( query. contains( "speaker_labels=true" ) ) ;
466+ assert ! ( !query. contains( "max_speakers" ) ) ;
467+ }
468+
469+ #[ test]
470+ fn test_streaming_diarization_hints_skip_whisper_fallback ( ) {
443471 let url = AssemblyAIAdapter . build_ws_url (
444472 API_BASE ,
445473 & owhisper_interface:: ListenParams {
446- custom_query : Some ( HashMap :: from ( [
447- ( "speaker_labels" . to_string ( ) , "true" . to_string ( ) ) ,
448- ( "max_speakers" . to_string ( ) , "3" . to_string ( ) ) ,
449- ] ) ) ,
474+ num_speakers : Some ( 3 ) ,
450475 languages : vec ! [ ISO639 :: Ko . into( ) ] ,
451476 ..Default :: default ( )
452477 } ,
@@ -455,8 +480,8 @@ mod tests {
455480
456481 let query = url. query ( ) . expect ( "query string" ) ;
457482 assert ! ( query. contains( "speech_model=whisper-rt" ) ) ;
458- assert ! ( !query. contains( "speaker_labels=true " ) ) ;
459- assert ! ( !query. contains( "max_speakers=3 " ) ) ;
483+ assert ! ( !query. contains( "speaker_labels" ) ) ;
484+ assert ! ( !query. contains( "max_speakers" ) ) ;
460485 }
461486
462487 #[ test]
0 commit comments