Skip to content

Commit 6f01e93

Browse files
juntao and claude committed
Refactor realtime_ws to use built-in silero VAD
Replace the deprecated HTTP-based vad_url and WebSocket-based vad_realtime_url with the built-in silero_vad_burn library:

- Add VadSession to RealtimeSession for local VAD processing
- Process audio through silero VAD inline when receiving audio chunks
- Detect speech start/end events locally instead of via a remote service
- Use local VAD in handle_audio_buffer_commit instead of an HTTP call
- Remove the VadRealtimeClient and VadRealtimeRx dependencies

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent fe01bbd commit 6f01e93

1 file changed

Lines changed: 108 additions & 104 deletions

File tree

src/services/realtime_ws.rs

Lines changed: 108 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use axum::{
55
response::IntoResponse,
66
};
77
use base64::Engine;
8-
use bytes::{BufMut, Bytes, BytesMut};
8+
use bytes::{BufMut, BytesMut};
99
use futures_util::{
1010
sink::SinkExt,
1111
stream::{SplitStream, StreamExt},
@@ -15,13 +15,7 @@ use tokio::sync::mpsc;
1515
use uuid::Uuid;
1616

1717
use crate::{
18-
ai::{
19-
ChatSession,
20-
bailian::cosyvoice,
21-
elevenlabs,
22-
openai::realtime::*,
23-
vad::{VadRealtimeClient, VadRealtimeEvent},
24-
},
18+
ai::{ChatSession, bailian::cosyvoice, elevenlabs, openai::realtime::*, vad::VadSession},
2519
config::*,
2620
};
2721

@@ -44,11 +38,11 @@ pub struct RealtimeSession {
4438
pub input_audio_buffer: BytesMut,
4539
pub triggered: bool,
4640
pub is_generating: bool,
47-
pub vad_realtime_client: Option<VadRealtimeClient>,
41+
pub vad_session: Option<VadSession>,
4842
}
4943

5044
impl RealtimeSession {
51-
pub fn new(chat_session: ChatSession) -> Self {
45+
pub fn new(chat_session: ChatSession, vad_session: Option<VadSession>) -> Self {
5246
Self {
5347
client: reqwest::Client::new(),
5448
chat_session,
@@ -58,7 +52,7 @@ impl RealtimeSession {
5852
input_audio_buffer: BytesMut::new(),
5953
triggered: false,
6054
is_generating: false,
61-
vad_realtime_client: None,
55+
vad_session,
6256
}
6357
}
6458
}
@@ -72,7 +66,6 @@ pub struct StableRealtimeConfig {
7266

7367
enum RealtimeEvent {
7468
ClientEvent(ClientEvent),
75-
VadEvent(VadRealtimeEvent),
7669
}
7770

7871
pub async fn ws_handler(
@@ -100,24 +93,30 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
10093
chat_session.system_prompts = parts.sys_prompts;
10194
chat_session.messages = parts.dynamic_prompts;
10295

103-
// 创建新的 Realtime 会话
104-
let mut session = RealtimeSession::new(chat_session);
105-
let mut realtime_rx: Option<_> = None;
106-
107-
if let Some(vad_realtime_url) = &config.asr.vad_realtime_url {
108-
match crate::ai::vad::vad_realtime_client(&session.client, vad_realtime_url.clone()).await {
109-
Ok((client, rx)) => {
110-
session.vad_realtime_client = Some(client);
111-
realtime_rx = Some(rx);
112-
log::info!("Connected to VAD realtime service at {}", vad_realtime_url);
113-
}
114-
Err(e) => {
115-
log::error!("Failed to connect to VAD realtime service: {}", e);
116-
}
96+
// Initialize built-in silero VAD session
97+
let vad_session = match crate::ai::vad::VadSession::new(
98+
&config.asr.vad,
99+
Box::new(
100+
silero_vad_burn::SileroVAD6Model::new(&burn::backend::ndarray::NdArrayDevice::default())
101+
.expect("Failed to create silero VAD model"),
102+
),
103+
burn::backend::ndarray::NdArrayDevice::default(),
104+
) {
105+
Ok(session) => {
106+
log::info!("Initialized built-in silero VAD session");
107+
Some(session)
117108
}
118-
}
109+
Err(e) => {
110+
log::error!("Failed to initialize silero VAD session: {}", e);
111+
None
112+
}
113+
};
119114

120-
let turn_detection = if realtime_rx.is_some() {
115+
// 创建新的 Realtime 会话
116+
let has_vad = vad_session.is_some();
117+
let mut session = RealtimeSession::new(chat_session, vad_session);
118+
119+
let turn_detection = if has_vad {
121120
TurnDetection::server_vad()
122121
} else {
123122
TurnDetection::none()
@@ -244,33 +243,10 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
244243
None
245244
}
246245

247-
async fn select_event(
248-
socket: &mut SplitStream<WebSocket>,
249-
realtime_rx: &mut Option<crate::ai::vad::VadRealtimeRx>,
250-
) -> Option<RealtimeEvent> {
251-
if let Some(rx) = realtime_rx {
252-
tokio::select! {
253-
client_event = recv_client_event(socket) => {
254-
client_event.map(RealtimeEvent::ClientEvent)
255-
}
256-
vad_event = rx.next_event() => {
257-
match vad_event {
258-
Ok(event) => Some(RealtimeEvent::VadEvent(event)),
259-
Err(e) => {
260-
log::error!("Failed to receive VAD event: {}", e);
261-
None
262-
}
263-
}
264-
}
265-
}
266-
} else {
267-
recv_client_event(socket)
268-
.await
269-
.map(RealtimeEvent::ClientEvent)
270-
}
271-
}
272-
273-
while let Some(event) = select_event(&mut receiver, &mut realtime_rx).await {
246+
while let Some(event) = recv_client_event(&mut receiver)
247+
.await
248+
.map(RealtimeEvent::ClientEvent)
249+
{
274250
if let Err(e) = handle_client_message(
275251
event,
276252
&mut session,
@@ -360,14 +336,14 @@ async fn handle_client_message(
360336
return Ok(());
361337
}
362338
if turn_detection.turn_type == TurnDetectionType::ServerVad
363-
&& session.vad_realtime_client.is_none()
339+
&& session.vad_session.is_none()
364340
{
365341
let error_event = ServerEvent::Error {
366342
event_id: Uuid::new_v4().to_string(),
367343
error: ErrorDetails {
368344
error_type: "invalid_request_error".to_string(),
369-
code: Some("vad_realtime_not_connected".to_string()),
370-
message: "VAD realtime service is not connected".to_string(),
345+
code: Some("vad_not_available".to_string()),
346+
message: "VAD session is not available".to_string(),
371347
param: Some("turn_detection.type".to_string()),
372348
event_id: None,
373349
},
@@ -433,11 +409,11 @@ async fn handle_client_message(
433409
.as_ref()
434410
.map(|t| t.turn_type == TurnDetectionType::ServerVad)
435411
.unwrap_or_default()
436-
&& session.vad_realtime_client.is_some();
412+
&& session.vad_session.is_some();
437413

438414
log::debug!(
439415
"Server VAD status: {} {:?}",
440-
session.vad_realtime_client.is_some(),
416+
session.vad_session.is_some(),
441417
session.config.turn_detection
442418
);
443419

@@ -473,26 +449,55 @@ async fn handle_client_message(
473449
}
474450
}
475451

452+
// Process audio through built-in silero VAD
476453
if server_vad {
477-
let samples_24k = audio_data
478-
.chunks_exact(2)
479-
.map(|chunk| {
480-
i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
481-
})
482-
.collect::<Vec<f32>>();
483-
let sample_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);
484-
485-
let sample_16k = crate::util::convert_samples_f32_to_i16_bytes(&sample_16k);
486-
log::debug!(
487-
"Sending audio chunk to VAD realtime service, length: {}",
488-
sample_16k.len()
489-
);
490-
session
491-
.vad_realtime_client
492-
.as_mut()
493-
.unwrap()
494-
.push_audio_16k_chunk(Bytes::from(sample_16k))
495-
.await?;
454+
if let Some(vad_session) = session.vad_session.as_mut() {
455+
// Convert 24kHz PCM16 to 16kHz f32 for VAD
456+
let samples_24k: Vec<f32> = audio_data
457+
.chunks_exact(2)
458+
.map(|chunk| {
459+
i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
460+
})
461+
.collect();
462+
let samples_16k =
463+
wav_io::resample::linear(samples_24k, 1, 24000, 16000);
464+
465+
// Process through VAD in chunks
466+
let chunk_size = VadSession::vad_chunk_size();
467+
let mut speech_detected = false;
468+
for chunk in samples_16k.chunks(chunk_size) {
469+
if let Ok(is_speech) = vad_session.detect(chunk) {
470+
if is_speech {
471+
speech_detected = true;
472+
} else if session.triggered && !is_speech {
473+
// Speech ended - trigger commit
474+
log::info!("VAD detected speech end, triggering commit");
475+
if handle_audio_buffer_commit(session, tx, None, asr)
476+
.await?
477+
{
478+
generate_response(session, tx, tts).await?;
479+
}
480+
session.triggered = false;
481+
if let Some(vs) = session.vad_session.as_mut() {
482+
vs.reset_state();
483+
}
484+
break;
485+
}
486+
}
487+
}
488+
489+
if speech_detected && !session.triggered {
490+
log::info!("VAD detected speech start");
491+
session.triggered = true;
492+
// Send speech started event
493+
let event = ServerEvent::InputAudioBufferSpeechStarted {
494+
event_id: Uuid::new_v4().to_string(),
495+
audio_start_ms: 0,
496+
item_id: Uuid::new_v4().to_string(),
497+
};
498+
let _ = tx.send(event).await;
499+
}
500+
}
496501
}
497502
}
498503

@@ -630,28 +635,6 @@ async fn handle_client_message(
630635
}
631636
}
632637
}
633-
RealtimeEvent::VadEvent(vad_realtime_event) => match vad_realtime_event {
634-
VadRealtimeEvent::Event { event } => match event.as_str() {
635-
"speech_start" => {
636-
log::debug!("VAD speech start detected");
637-
session.triggered = true;
638-
}
639-
"speech_end" => {
640-
log::debug!("VAD speech end detected");
641-
session.triggered = false;
642-
if handle_audio_buffer_commit(session, tx, None, asr).await? {
643-
log::debug!("Audio buffer committed, generating response");
644-
generate_response(session, tx, tts).await?;
645-
}
646-
}
647-
_ => {
648-
log::warn!("Unhandled VAD event: {}", event);
649-
}
650-
},
651-
VadRealtimeEvent::Error { message, .. } => {
652-
return Err(anyhow::anyhow!("VAD error: {}", message));
653-
}
654-
},
655638
}
656639

657640
Ok(())
@@ -682,9 +665,30 @@ async fn handle_audio_buffer_commit(
682665
};
683666
let _ = tx.send(committed_event).await;
684667

685-
if let Some(vad_url) = &config.vad_url {
686-
let vad = crate::ai::vad::vad_detect(&session.client, vad_url, wav_audio.clone()).await?;
687-
if vad.timestamps.is_empty() {
668+
// Check for speech using built-in silero VAD
669+
if let Some(vad_session) = session.vad_session.as_mut() {
670+
// Convert 24kHz PCM16 to 16kHz f32 for VAD
671+
let samples_24k: Vec<f32> = audio_data
672+
.chunks_exact(2)
673+
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32)
674+
.collect();
675+
let samples_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);
676+
677+
// Process through VAD to check if there's any speech
678+
let chunk_size = VadSession::vad_chunk_size();
679+
let mut has_speech = false;
680+
vad_session.reset_state();
681+
for chunk in samples_16k.chunks(chunk_size) {
682+
if let Ok(is_speech) = vad_session.detect(chunk) {
683+
if is_speech {
684+
has_speech = true;
685+
break;
686+
}
687+
}
688+
}
689+
690+
if !has_speech {
691+
log::debug!("No speech detected in audio buffer, skipping ASR");
688692
let transcription_completed =
689693
ServerEvent::ConversationItemInputAudioTranscriptionCompleted {
690694
event_id: Uuid::new_v4().to_string(),

0 commit comments

Comments
 (0)