Skip to content

Commit d6a9608

Browse files
committed
feat: embedding silero_vad for whisper
1 parent 041ce61 commit d6a9608

12 files changed

Lines changed: 7098 additions & 1431 deletions

File tree

Cargo.lock

Lines changed: 6626 additions & 1305 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ rmcp = { version = "0.1.5", features = [
5151
"transport-streamable-http-client",
5252
"reqwest",
5353
], default-features = false, git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "b9d7d61" } # branch = "main"
54+
5455
base64 = "0.22.1"
5556
reqwest-websocket = "0.5.0"
5657
futures-util = "0.3.31"
@@ -59,3 +60,7 @@ tower = { version = "0.5.2", features = ["util"] }
5960
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
6061

6162
chrono = "0.4.41"
63+
64+
# vad
65+
silero_vad_burn = "0.1.1"
66+
burn = { version = "0.20", features = ["ndarray"] }

src/ai/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ pub async fn llm_stable<'p, I: IntoIterator<Item = C>, C: AsRef<llm::Content>>(
473473
serde_json::to_string_pretty(&serde_json::json!(
474474
{
475475
"stream": true,
476-
"messages": messages,
476+
"last_message": messages.last(),
477477
"model": model.to_string(),
478478
"tools": tool_name,
479479
"extra": extra,

src/ai/vad.rs

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use futures_util::{
2-
stream::{SplitSink, SplitStream},
32
SinkExt, StreamExt,
3+
stream::{SplitSink, SplitStream},
44
};
55
use reqwest::multipart::Part;
66
use reqwest_websocket::{RequestBuilderExt, WebSocket};
@@ -101,3 +101,109 @@ impl VadRealtimeRx {
101101
}
102102
}
103103
}
104+
105+
/// Parameter set consumed by the embedded VAD; an alias for
/// [`crate::config::SileroVadconfig`].
pub type VadParams = crate::config::SileroVadconfig;
106+
107+
#[derive(Clone)]
108+
pub struct SileroVADFactory {
109+
device: burn::backend::ndarray::NdArrayDevice,
110+
params: VadParams,
111+
}
112+
113+
impl SileroVADFactory {
114+
pub fn new(params: VadParams) -> anyhow::Result<Self> {
115+
let device = burn::backend::ndarray::NdArrayDevice::default();
116+
117+
Ok(SileroVADFactory { device, params })
118+
}
119+
120+
pub fn create_session(&self) -> anyhow::Result<VadSession> {
121+
let vad = Box::new(silero_vad_burn::SileroVAD6Model::new(&self.device)?);
122+
VadSession::new(&self.params, vad, self.device.clone())
123+
}
124+
}
125+
126+
pub struct VadSession {
127+
vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>,
128+
state: Option<silero_vad_burn::PredictState<burn::backend::NdArray>>,
129+
device: burn::backend::ndarray::NdArrayDevice,
130+
131+
in_speech: bool,
132+
133+
threshold: f32,
134+
neg_threshold: f32,
135+
136+
silence_chunk_count: usize,
137+
max_silence_chunks: usize,
138+
}
139+
140+
impl VadSession {
141+
const SAMPLE_RATE: usize = 16000;
142+
143+
pub fn new(
144+
params: &VadParams,
145+
vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>,
146+
device: burn::backend::ndarray::NdArrayDevice,
147+
) -> anyhow::Result<Self> {
148+
let state = Some(silero_vad_burn::PredictState::default(&device));
149+
150+
let neg_threshold = params
151+
.neg_threshold
152+
.unwrap_or_else(|| params.threshold - 0.15)
153+
.max(0.05);
154+
155+
let threshold = params.threshold.min(0.95);
156+
let max_silence_chunks = params.max_silence_duration_ms * (Self::SAMPLE_RATE / 1000)
157+
/ silero_vad_burn::CHUNK_SIZE;
158+
159+
Ok(VadSession {
160+
vad,
161+
state,
162+
device,
163+
164+
in_speech: false,
165+
threshold,
166+
neg_threshold,
167+
168+
silence_chunk_count: 0,
169+
max_silence_chunks,
170+
})
171+
}
172+
173+
pub fn reset_state(&mut self) {
174+
self.state = Some(silero_vad_burn::PredictState::default(&self.device));
175+
self.in_speech = false;
176+
self.silence_chunk_count = 0;
177+
}
178+
179+
pub fn detect(&mut self, audio16k_chunk_512: &[f32]) -> anyhow::Result<bool> {
180+
debug_assert!(
181+
audio16k_chunk_512.len() <= 512,
182+
"audio16k_chunk_512 length must be less than 512",
183+
);
184+
185+
let audio_tensor =
186+
burn::Tensor::<_, 1>::from_floats(audio16k_chunk_512, &self.device).unsqueeze();
187+
let (state, prob) = self.vad.predict(self.state.take().unwrap(), audio_tensor)?;
188+
self.state = Some(state);
189+
190+
let prob: Vec<f32> = prob.to_data().to_vec()?;
191+
192+
if prob[0] > self.threshold {
193+
self.in_speech = true;
194+
self.silence_chunk_count = 0;
195+
} else if prob[0] < self.neg_threshold {
196+
self.silence_chunk_count += 1;
197+
if self.silence_chunk_count >= self.max_silence_chunks {
198+
self.in_speech = false;
199+
}
200+
} else {
201+
}
202+
203+
Ok(self.in_speech)
204+
}
205+
206+
pub const fn vad_chunk_size() -> usize {
207+
silero_vad_burn::CHUNK_SIZE
208+
}
209+
}

src/config.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,50 @@ pub enum TTSConfig {
242242
Elevenlabs(ElevenlabsTTS),
243243
}
244244

245+
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
246+
pub struct SileroVadconfig {
247+
pub threshold: f32,
248+
pub neg_threshold: Option<f32>,
249+
250+
pub min_speech_duration_ms: usize,
251+
pub max_silence_duration_ms: usize,
252+
pub audio_cache_ms: usize,
253+
}
254+
255+
impl SileroVadconfig {
256+
pub fn default_threshold() -> f32 {
257+
0.5
258+
}
259+
260+
pub fn default_neg_threshold() -> Option<f32> {
261+
None
262+
}
263+
264+
pub fn default_min_speech_duration_ms() -> usize {
265+
150
266+
}
267+
268+
pub fn default_max_silence_duration_ms() -> usize {
269+
400
270+
}
271+
272+
pub fn default_audio_cache_ms() -> usize {
273+
1000
274+
}
275+
}
276+
277+
impl Default for SileroVadconfig {
278+
fn default() -> Self {
279+
SileroVadconfig {
280+
threshold: Self::default_threshold(),
281+
neg_threshold: Self::default_neg_threshold(),
282+
min_speech_duration_ms: Self::default_min_speech_duration_ms(),
283+
max_silence_duration_ms: Self::default_max_silence_duration_ms(),
284+
audio_cache_ms: Self::default_audio_cache_ms(),
285+
}
286+
}
287+
}
288+
245289
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
246290
pub struct WhisperASRConfig {
247291
pub url: String,
@@ -253,8 +297,14 @@ pub struct WhisperASRConfig {
253297
pub model: String,
254298
#[serde(default)]
255299
pub prompt: String,
300+
301+
#[serde(default)]
302+
pub vad: SileroVadconfig,
303+
304+
#[deprecated]
256305
#[serde(default)]
257306
pub vad_url: Option<String>,
307+
#[deprecated]
258308
#[serde(default)]
259309
pub vad_realtime_url: Option<String>,
260310
}

src/main.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,10 @@ async fn routes(
186186
router.route(
187187
"/version",
188188
get(|| async {
189-
serde_json::to_string_pretty(&serde_json::json!(
189+
axum::response::Json(serde_json::json!(
190190
{
191191
"version": env!("CARGO_PKG_VERSION"),
192192
}))
193-
.unwrap()
194193
}),
195194
)
196195
}

src/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ pub enum ServerEvent {
1616
StartVideo,
1717
EndVideo,
1818
EndResponse,
19+
20+
EndVad,
1921
}
2022

2123
#[test]

src/services/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ pub struct ConnectQueryParams {
2121
opus: bool,
2222
#[serde(default)]
2323
vowel: bool,
24+
#[serde(default)]
25+
server_vad: bool,
2426
}
2527

2628
pub async fn v2_mixed_handler(
@@ -43,6 +45,7 @@ pub async fn v2_mixed_handler(
4345
reconnect: params.reconnect,
4446
opus: params.opus,
4547
vowel: params.vowel,
48+
server_vad: params.server_vad,
4649
}),
4750
)
4851
.await

0 commit comments

Comments
 (0)