@@ -5,7 +5,7 @@ use axum::{
55 response:: IntoResponse ,
66} ;
77use base64:: Engine ;
8- use bytes:: { BufMut , Bytes , BytesMut } ;
8+ use bytes:: { BufMut , BytesMut } ;
99use futures_util:: {
1010 sink:: SinkExt ,
1111 stream:: { SplitStream , StreamExt } ,
@@ -15,13 +15,7 @@ use tokio::sync::mpsc;
1515use uuid:: Uuid ;
1616
1717use crate :: {
18- ai:: {
19- ChatSession ,
20- bailian:: cosyvoice,
21- elevenlabs,
22- openai:: realtime:: * ,
23- vad:: { VadRealtimeClient , VadRealtimeEvent } ,
24- } ,
18+ ai:: { ChatSession , bailian:: cosyvoice, elevenlabs, openai:: realtime:: * , vad:: VadSession } ,
2519 config:: * ,
2620} ;
2721
@@ -44,11 +38,11 @@ pub struct RealtimeSession {
4438 pub input_audio_buffer : BytesMut ,
4539 pub triggered : bool ,
4640 pub is_generating : bool ,
47- pub vad_realtime_client : Option < VadRealtimeClient > ,
41+ pub vad_session : Option < VadSession > ,
4842}
4943
5044impl RealtimeSession {
51- pub fn new ( chat_session : ChatSession ) -> Self {
45+ pub fn new ( chat_session : ChatSession , vad_session : Option < VadSession > ) -> Self {
5246 Self {
5347 client : reqwest:: Client :: new ( ) ,
5448 chat_session,
@@ -58,7 +52,7 @@ impl RealtimeSession {
5852 input_audio_buffer : BytesMut :: new ( ) ,
5953 triggered : false ,
6054 is_generating : false ,
61- vad_realtime_client : None ,
55+ vad_session ,
6256 }
6357 }
6458}
@@ -72,7 +66,6 @@ pub struct StableRealtimeConfig {
7266
7367enum RealtimeEvent {
7468 ClientEvent ( ClientEvent ) ,
75- VadEvent ( VadRealtimeEvent ) ,
7669}
7770
7871pub async fn ws_handler (
@@ -100,24 +93,30 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
10093 chat_session. system_prompts = parts. sys_prompts ;
10194 chat_session. messages = parts. dynamic_prompts ;
10295
103- // 创建新的 Realtime 会话
104- let mut session = RealtimeSession :: new ( chat_session) ;
105- let mut realtime_rx: Option < _ > = None ;
106-
107- if let Some ( vad_realtime_url) = & config. asr . vad_realtime_url {
108- match crate :: ai:: vad:: vad_realtime_client ( & session. client , vad_realtime_url. clone ( ) ) . await {
109- Ok ( ( client, rx) ) => {
110- session. vad_realtime_client = Some ( client) ;
111- realtime_rx = Some ( rx) ;
112- log:: info!( "Connected to VAD realtime service at {}" , vad_realtime_url) ;
113- }
114- Err ( e) => {
115- log:: error!( "Failed to connect to VAD realtime service: {}" , e) ;
116- }
96+ // Initialize built-in silero VAD session
97+ let vad_session = match crate :: ai:: vad:: VadSession :: new (
98+ & config. asr . vad ,
99+ Box :: new (
100+ silero_vad_burn:: SileroVAD6Model :: new ( & burn:: backend:: ndarray:: NdArrayDevice :: default ( ) )
101+ . expect ( "Failed to create silero VAD model" ) ,
102+ ) ,
103+ burn:: backend:: ndarray:: NdArrayDevice :: default ( ) ,
104+ ) {
105+ Ok ( session) => {
106+ log:: info!( "Initialized built-in silero VAD session" ) ;
107+ Some ( session)
117108 }
118- }
109+ Err ( e) => {
110+ log:: error!( "Failed to initialize silero VAD session: {}" , e) ;
111+ None
112+ }
113+ } ;
119114
120- let turn_detection = if realtime_rx. is_some ( ) {
115+ // 创建新的 Realtime 会话
116+ let has_vad = vad_session. is_some ( ) ;
117+ let mut session = RealtimeSession :: new ( chat_session, vad_session) ;
118+
119+ let turn_detection = if has_vad {
121120 TurnDetection :: server_vad ( )
122121 } else {
123122 TurnDetection :: none ( )
@@ -244,33 +243,10 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
244243 None
245244 }
246245
247- async fn select_event (
248- socket : & mut SplitStream < WebSocket > ,
249- realtime_rx : & mut Option < crate :: ai:: vad:: VadRealtimeRx > ,
250- ) -> Option < RealtimeEvent > {
251- if let Some ( rx) = realtime_rx {
252- tokio:: select! {
253- client_event = recv_client_event( socket) => {
254- client_event. map( RealtimeEvent :: ClientEvent )
255- }
256- vad_event = rx. next_event( ) => {
257- match vad_event {
258- Ok ( event) => Some ( RealtimeEvent :: VadEvent ( event) ) ,
259- Err ( e) => {
260- log:: error!( "Failed to receive VAD event: {}" , e) ;
261- None
262- }
263- }
264- }
265- }
266- } else {
267- recv_client_event ( socket)
268- . await
269- . map ( RealtimeEvent :: ClientEvent )
270- }
271- }
272-
273- while let Some ( event) = select_event ( & mut receiver, & mut realtime_rx) . await {
246+ while let Some ( event) = recv_client_event ( & mut receiver)
247+ . await
248+ . map ( RealtimeEvent :: ClientEvent )
249+ {
274250 if let Err ( e) = handle_client_message (
275251 event,
276252 & mut session,
@@ -360,14 +336,14 @@ async fn handle_client_message(
360336 return Ok ( ( ) ) ;
361337 }
362338 if turn_detection. turn_type == TurnDetectionType :: ServerVad
363- && session. vad_realtime_client . is_none ( )
339+ && session. vad_session . is_none ( )
364340 {
365341 let error_event = ServerEvent :: Error {
366342 event_id : Uuid :: new_v4 ( ) . to_string ( ) ,
367343 error : ErrorDetails {
368344 error_type : "invalid_request_error" . to_string ( ) ,
369- code : Some ( "vad_realtime_not_connected " . to_string ( ) ) ,
370- message : "VAD realtime service is not connected " . to_string ( ) ,
345+ code : Some ( "vad_not_available " . to_string ( ) ) ,
346+ message : "VAD session is not available " . to_string ( ) ,
371347 param : Some ( "turn_detection.type" . to_string ( ) ) ,
372348 event_id : None ,
373349 } ,
@@ -433,11 +409,11 @@ async fn handle_client_message(
433409 . as_ref ( )
434410 . map ( |t| t. turn_type == TurnDetectionType :: ServerVad )
435411 . unwrap_or_default ( )
436- && session. vad_realtime_client . is_some ( ) ;
412+ && session. vad_session . is_some ( ) ;
437413
438414 log:: debug!(
439415 "Server VAD status: {} {:?}" ,
440- session. vad_realtime_client . is_some( ) ,
416+ session. vad_session . is_some( ) ,
441417 session. config. turn_detection
442418 ) ;
443419
@@ -473,26 +449,55 @@ async fn handle_client_message(
473449 }
474450 }
475451
452+ // Process audio through built-in silero VAD
476453 if server_vad {
477- let samples_24k = audio_data
478- . chunks_exact ( 2 )
479- . map ( |chunk| {
480- i16:: from_le_bytes ( [ chunk[ 0 ] , chunk[ 1 ] ] ) as f32 / i16:: MAX as f32
481- } )
482- . collect :: < Vec < f32 > > ( ) ;
483- let sample_16k = wav_io:: resample:: linear ( samples_24k, 1 , 24000 , 16000 ) ;
484-
485- let sample_16k = crate :: util:: convert_samples_f32_to_i16_bytes ( & sample_16k) ;
486- log:: debug!(
487- "Sending audio chunk to VAD realtime service, length: {}" ,
488- sample_16k. len( )
489- ) ;
490- session
491- . vad_realtime_client
492- . as_mut ( )
493- . unwrap ( )
494- . push_audio_16k_chunk ( Bytes :: from ( sample_16k) )
495- . await ?;
454+ if let Some ( vad_session) = session. vad_session . as_mut ( ) {
455+ // Convert 24kHz PCM16 to 16kHz f32 for VAD
456+ let samples_24k: Vec < f32 > = audio_data
457+ . chunks_exact ( 2 )
458+ . map ( |chunk| {
459+ i16:: from_le_bytes ( [ chunk[ 0 ] , chunk[ 1 ] ] ) as f32 / i16:: MAX as f32
460+ } )
461+ . collect ( ) ;
462+ let samples_16k =
463+ wav_io:: resample:: linear ( samples_24k, 1 , 24000 , 16000 ) ;
464+
465+ // Process through VAD in chunks
466+ let chunk_size = VadSession :: vad_chunk_size ( ) ;
467+ let mut speech_detected = false ;
468+ for chunk in samples_16k. chunks ( chunk_size) {
469+ if let Ok ( is_speech) = vad_session. detect ( chunk) {
470+ if is_speech {
471+ speech_detected = true ;
472+ } else if session. triggered && !is_speech {
473+ // Speech ended - trigger commit
474+ log:: info!( "VAD detected speech end, triggering commit" ) ;
475+ if handle_audio_buffer_commit ( session, tx, None , asr)
476+ . await ?
477+ {
478+ generate_response ( session, tx, tts) . await ?;
479+ }
480+ session. triggered = false ;
481+ if let Some ( vs) = session. vad_session . as_mut ( ) {
482+ vs. reset_state ( ) ;
483+ }
484+ break ;
485+ }
486+ }
487+ }
488+
489+ if speech_detected && !session. triggered {
490+ log:: info!( "VAD detected speech start" ) ;
491+ session. triggered = true ;
492+ // Send speech started event
493+ let event = ServerEvent :: InputAudioBufferSpeechStarted {
494+ event_id : Uuid :: new_v4 ( ) . to_string ( ) ,
495+ audio_start_ms : 0 ,
496+ item_id : Uuid :: new_v4 ( ) . to_string ( ) ,
497+ } ;
498+ let _ = tx. send ( event) . await ;
499+ }
500+ }
496501 }
497502 }
498503
@@ -630,28 +635,6 @@ async fn handle_client_message(
630635 }
631636 }
632637 }
633- RealtimeEvent :: VadEvent ( vad_realtime_event) => match vad_realtime_event {
634- VadRealtimeEvent :: Event { event } => match event. as_str ( ) {
635- "speech_start" => {
636- log:: debug!( "VAD speech start detected" ) ;
637- session. triggered = true ;
638- }
639- "speech_end" => {
640- log:: debug!( "VAD speech end detected" ) ;
641- session. triggered = false ;
642- if handle_audio_buffer_commit ( session, tx, None , asr) . await ? {
643- log:: debug!( "Audio buffer committed, generating response" ) ;
644- generate_response ( session, tx, tts) . await ?;
645- }
646- }
647- _ => {
648- log:: warn!( "Unhandled VAD event: {}" , event) ;
649- }
650- } ,
651- VadRealtimeEvent :: Error { message, .. } => {
652- return Err ( anyhow:: anyhow!( "VAD error: {}" , message) ) ;
653- }
654- } ,
655638 }
656639
657640 Ok ( ( ) )
@@ -682,9 +665,30 @@ async fn handle_audio_buffer_commit(
682665 } ;
683666 let _ = tx. send ( committed_event) . await ;
684667
685- if let Some ( vad_url) = & config. vad_url {
686- let vad = crate :: ai:: vad:: vad_detect ( & session. client , vad_url, wav_audio. clone ( ) ) . await ?;
687- if vad. timestamps . is_empty ( ) {
668+ // Check for speech using built-in silero VAD
669+ if let Some ( vad_session) = session. vad_session . as_mut ( ) {
670+ // Convert 24kHz PCM16 to 16kHz f32 for VAD
671+ let samples_24k: Vec < f32 > = audio_data
672+ . chunks_exact ( 2 )
673+ . map ( |chunk| i16:: from_le_bytes ( [ chunk[ 0 ] , chunk[ 1 ] ] ) as f32 / i16:: MAX as f32 )
674+ . collect ( ) ;
675+ let samples_16k = wav_io:: resample:: linear ( samples_24k, 1 , 24000 , 16000 ) ;
676+
677+ // Process through VAD to check if there's any speech
678+ let chunk_size = VadSession :: vad_chunk_size ( ) ;
679+ let mut has_speech = false ;
680+ vad_session. reset_state ( ) ;
681+ for chunk in samples_16k. chunks ( chunk_size) {
682+ if let Ok ( is_speech) = vad_session. detect ( chunk) {
683+ if is_speech {
684+ has_speech = true ;
685+ break ;
686+ }
687+ }
688+ }
689+
690+ if !has_speech {
691+ log:: debug!( "No speech detected in audio buffer, skipping ASR" ) ;
688692 let transcription_completed =
689693 ServerEvent :: ConversationItemInputAudioTranscriptionCompleted {
690694 event_id : Uuid :: new_v4 ( ) . to_string ( ) ,
0 commit comments