|
1 | 1 | use futures_util::{ |
2 | | - stream::{SplitSink, SplitStream}, |
3 | 2 | SinkExt, StreamExt, |
| 3 | + stream::{SplitSink, SplitStream}, |
4 | 4 | }; |
5 | 5 | use reqwest::multipart::Part; |
6 | 6 | use reqwest_websocket::{RequestBuilderExt, WebSocket}; |
@@ -101,3 +101,109 @@ impl VadRealtimeRx { |
101 | 101 | } |
102 | 102 | } |
103 | 103 | } |
| 104 | + |
| 105 | +pub type VadParams = crate::config::SileroVadconfig; |
| 106 | + |
| 107 | +#[derive(Clone)] |
| 108 | +pub struct SileroVADFactory { |
| 109 | + device: burn::backend::ndarray::NdArrayDevice, |
| 110 | + params: VadParams, |
| 111 | +} |
| 112 | + |
| 113 | +impl SileroVADFactory { |
| 114 | + pub fn new(params: VadParams) -> anyhow::Result<Self> { |
| 115 | + let device = burn::backend::ndarray::NdArrayDevice::default(); |
| 116 | + |
| 117 | + Ok(SileroVADFactory { device, params }) |
| 118 | + } |
| 119 | + |
| 120 | + pub fn create_session(&self) -> anyhow::Result<VadSession> { |
| 121 | + let vad = Box::new(silero_vad_burn::SileroVAD6Model::new(&self.device)?); |
| 122 | + VadSession::new(&self.params, vad, self.device.clone()) |
| 123 | + } |
| 124 | +} |
| 125 | + |
| 126 | +pub struct VadSession { |
| 127 | + vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>, |
| 128 | + state: Option<silero_vad_burn::PredictState<burn::backend::NdArray>>, |
| 129 | + device: burn::backend::ndarray::NdArrayDevice, |
| 130 | + |
| 131 | + in_speech: bool, |
| 132 | + |
| 133 | + threshold: f32, |
| 134 | + neg_threshold: f32, |
| 135 | + |
| 136 | + silence_chunk_count: usize, |
| 137 | + max_silence_chunks: usize, |
| 138 | +} |
| 139 | + |
| 140 | +impl VadSession { |
| 141 | + const SAMPLE_RATE: usize = 16000; |
| 142 | + |
| 143 | + pub fn new( |
| 144 | + params: &VadParams, |
| 145 | + vad: Box<silero_vad_burn::SileroVAD6Model<burn::backend::NdArray>>, |
| 146 | + device: burn::backend::ndarray::NdArrayDevice, |
| 147 | + ) -> anyhow::Result<Self> { |
| 148 | + let state = Some(silero_vad_burn::PredictState::default(&device)); |
| 149 | + |
| 150 | + let neg_threshold = params |
| 151 | + .neg_threshold |
| 152 | + .unwrap_or_else(|| params.threshold - 0.15) |
| 153 | + .max(0.05); |
| 154 | + |
| 155 | + let threshold = params.threshold.min(0.95); |
| 156 | + let max_silence_chunks = params.max_silence_duration_ms * (Self::SAMPLE_RATE / 1000) |
| 157 | + / silero_vad_burn::CHUNK_SIZE; |
| 158 | + |
| 159 | + Ok(VadSession { |
| 160 | + vad, |
| 161 | + state, |
| 162 | + device, |
| 163 | + |
| 164 | + in_speech: false, |
| 165 | + threshold, |
| 166 | + neg_threshold, |
| 167 | + |
| 168 | + silence_chunk_count: 0, |
| 169 | + max_silence_chunks, |
| 170 | + }) |
| 171 | + } |
| 172 | + |
| 173 | + pub fn reset_state(&mut self) { |
| 174 | + self.state = Some(silero_vad_burn::PredictState::default(&self.device)); |
| 175 | + self.in_speech = false; |
| 176 | + self.silence_chunk_count = 0; |
| 177 | + } |
| 178 | + |
| 179 | + pub fn detect(&mut self, audio16k_chunk_512: &[f32]) -> anyhow::Result<bool> { |
| 180 | + debug_assert!( |
| 181 | + audio16k_chunk_512.len() <= 512, |
| 182 | + "audio16k_chunk_512 length must be less than 512", |
| 183 | + ); |
| 184 | + |
| 185 | + let audio_tensor = |
| 186 | + burn::Tensor::<_, 1>::from_floats(audio16k_chunk_512, &self.device).unsqueeze(); |
| 187 | + let (state, prob) = self.vad.predict(self.state.take().unwrap(), audio_tensor)?; |
| 188 | + self.state = Some(state); |
| 189 | + |
| 190 | + let prob: Vec<f32> = prob.to_data().to_vec()?; |
| 191 | + |
| 192 | + if prob[0] > self.threshold { |
| 193 | + self.in_speech = true; |
| 194 | + self.silence_chunk_count = 0; |
| 195 | + } else if prob[0] < self.neg_threshold { |
| 196 | + self.silence_chunk_count += 1; |
| 197 | + if self.silence_chunk_count >= self.max_silence_chunks { |
| 198 | + self.in_speech = false; |
| 199 | + } |
| 200 | + } else { |
| 201 | + } |
| 202 | + |
| 203 | + Ok(self.in_speech) |
| 204 | + } |
| 205 | + |
| 206 | + pub const fn vad_chunk_size() -> usize { |
| 207 | + silero_vad_burn::CHUNK_SIZE |
| 208 | + } |
| 209 | +} |
0 commit comments