Skip to content

Commit 50a9bdf

Browse files
committed
feat: Add Voice Activity Detection (VAD) support
1 parent d8f202f commit 50a9bdf

2 files changed

Lines changed: 166 additions & 9 deletions

File tree

pywhispercpp/constants.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,11 @@
171171
'default': 0
172172
},
173173
'initial_prompt': {
174-
'type': str,
175-
'description': "Initial prompt, these are prepended to any existing text context from a previous call",
176-
'options': None,
177-
'default': None
178-
},
174+
'type': str,
175+
'description': "Initial prompt, these are prepended to any existing text context from a previous call",
176+
'options': None,
177+
'default': None
178+
},
179179
'prompt_tokens': {
180180
'type': Tuple,
181181
'description': "tokens to provide to the whisper decoder as initial prompt",
@@ -265,5 +265,17 @@
265265
'description': 'calculate the geometric mean of token probabilities for each segment.',
266266
'options': None,
267267
'default': True
268+
},
269+
'vad': {
270+
'type': bool,
271+
'description': 'Enable VAD',
272+
'options': None,
273+
'default': False
274+
},
275+
'vad_model_path': {
276+
'type': str,
277+
'description': 'Path to VAD model',
278+
'options': None,
279+
'default': None
268280
}
269281
}

src/main.cpp

Lines changed: 149 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ py::function py_logits_filter_callback;
3838
// Thanks to https://github.com/pybind/pybind11/issues/2770
3939
// Thin holder for a raw whisper_context pointer so the opaque C handle can be
// passed through the Python bindings. The struct declares no destructor, so it
// never frees the context on its own.
struct whisper_context_wrapper {
    // Defaults to nullptr so a default-constructed wrapper never carries an
    // indeterminate pointer.
    struct whisper_context* ptr = nullptr;
};
4342

44-
4543
// struct inside params
4644
struct greedy{
4745
int best_of;
@@ -299,14 +297,18 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c
299297
struct WhisperFullParamsWrapper : public whisper_full_params {
300298
std::string initial_prompt_str;
301299
std::string suppress_regex_str;
300+
std::string vad_model_path_str;
302301
public:
303302
py::function py_progress_callback;
304303
WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
305304
: whisper_full_params(params),
306305
initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
307-
suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
306+
suppress_regex_str(params.suppress_regex ? params.suppress_regex : ""),
307+
vad_model_path_str(params.vad_model_path ? params.vad_model_path : "")
308+
{
308309
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
309310
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
311+
vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str();
310312
// progress callback
311313
progress_callback_user_data = this;
312314
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
@@ -327,10 +329,12 @@ struct WhisperFullParamsWrapper : public whisper_full_params {
327329
: whisper_full_params(static_cast<whisper_full_params>(other)), // Copy base struct
328330
initial_prompt_str(other.initial_prompt_str),
329331
suppress_regex_str(other.suppress_regex_str),
332+
vad_model_path_str(other.vad_model_path_str),
330333
py_progress_callback(other.py_progress_callback) {
331334
// Reset pointers to new string copies
332335
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
333336
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
337+
vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str();
334338
progress_callback_user_data = this;
335339
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
336340
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
@@ -354,6 +358,10 @@ struct WhisperFullParamsWrapper : public whisper_full_params {
354358
suppress_regex_str = regex;
355359
suppress_regex = suppress_regex_str.c_str();
356360
}
361+
// Stores a private copy of `model_path` and re-points the inherited
// `vad_model_path` C string at that copy, so the pointer refers to storage
// owned by this wrapper rather than to the caller's string.
void set_vad_model_path(const std::string& model_path) {
    vad_model_path_str.assign(model_path);
    const char * const owned_path = vad_model_path_str.c_str();
    vad_model_path = owned_path;
}
357365
};
358366
WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) {
359367
return WhisperFullParamsWrapper(whisper_full_default_params(strategy));
@@ -411,6 +419,99 @@ py::dict get_greedy(whisper_full_params * params){
411419
return d;
412420
}
413421

422+
423+
// Voice Activity Detection (VAD)
424+
// Thin holder for a raw whisper_vad_context pointer so the opaque C handle can
// be passed through the Python bindings. The struct declares no destructor, so
// it never frees the context on its own.
struct whisper_vad_context_wrapper {
    // Defaults to nullptr so a default-constructed wrapper never carries an
    // indeterminate pointer.
    struct whisper_vad_context* ptr = nullptr;
};
427+
428+
struct whisper_vad_context_wrapper whisper_vad_init_from_file_with_params_wrapper(const char * path_model, struct whisper_vad_context_params params){
429+
struct whisper_vad_context * ctx = whisper_vad_init_from_file_with_params(path_model, params);
430+
struct whisper_vad_context_wrapper ctw_w;
431+
ctw_w.ptr = ctx;
432+
return ctw_w;
433+
}
434+
435+
bool whisper_vad_detect_speech_wrapper(
436+
struct whisper_vad_context_wrapper * ctx,
437+
py::array_t<float> samples,
438+
int n_samples){
439+
py::buffer_info buf = samples.request();
440+
float *samples_ptr = static_cast<float *>(buf.ptr);
441+
442+
py::gil_scoped_release release;
443+
return whisper_vad_detect_speech(ctx->ptr, samples_ptr, n_samples);
444+
}
445+
446+
int whisper_vad_n_probs_wrapper(struct whisper_vad_context_wrapper * ctx){
447+
return whisper_vad_n_probs(ctx->ptr);
448+
}
449+
450+
// Exposes the values reported by whisper_vad_probs() as a 1-D NumPy float
// array of length whisper_vad_n_probs(). Returns a zero-length array when the
// C API yields a null pointer or a non-positive count.
// NOTE(review): no base handle is passed to the array constructor, so pybind11
// is expected to copy the data rather than alias it — confirm.
py::array_t<float> whisper_vad_probs_wrapper(struct whisper_vad_context_wrapper * ctx) {
    float * const data = whisper_vad_probs(ctx->ptr);
    const int count = whisper_vad_n_probs(ctx->ptr);

    const bool has_probs = (data != nullptr) && (count > 0);
    if (has_probs) {
        // shape = {count}, stride = one float between consecutive elements
        return py::array_t<float>({count}, {sizeof(float)}, data);
    }
    return py::array_t<float>(0);
}
463+
464+
// Thin holder for a raw whisper_vad_segments pointer so the opaque C handle can
// be passed through the Python bindings. The struct declares no destructor, so
// it never frees the segments on its own.
struct whisper_vad_segments_wrapper {
    // Defaults to nullptr so a default-constructed wrapper never carries an
    // indeterminate pointer.
    struct whisper_vad_segments * ptr = nullptr;
};
467+
468+
struct whisper_vad_segments_wrapper whisper_vad_segments_from_probs_wrapper(
469+
struct whisper_vad_context_wrapper * vctx_w,
470+
struct whisper_vad_params params
471+
){
472+
struct whisper_vad_segments * wvs = whisper_vad_segments_from_probs(vctx_w->ptr, params);
473+
struct whisper_vad_segments_wrapper wvs_w;
474+
wvs_w.ptr = wvs;
475+
return wvs_w;
476+
}
477+
478+
struct whisper_vad_segments_wrapper whisper_vad_segments_from_samples_wrapper(
479+
struct whisper_vad_context_wrapper * vctx_w,
480+
struct whisper_vad_params params,
481+
py::array_t<float> samples,
482+
int n_samples){
483+
484+
py::buffer_info buf = samples.request();
485+
float *samples_ptr = static_cast<float *>(buf.ptr);
486+
487+
struct whisper_vad_segments * wvs = whisper_vad_segments_from_samples(vctx_w->ptr, params, samples_ptr, n_samples);
488+
struct whisper_vad_segments_wrapper wvs_w;
489+
wvs_w.ptr = wvs;
490+
return wvs_w;
491+
}
492+
493+
int whisper_vad_segments_n_segments_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper){
494+
return whisper_vad_segments_n_segments(segments_wrapper->ptr);
495+
}
496+
497+
float whisper_vad_segments_get_segment_t0_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper, int i_segment) {
498+
return whisper_vad_segments_get_segment_t0(segments_wrapper->ptr, i_segment);
499+
}
500+
501+
float whisper_vad_segments_get_segment_t1_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper, int i_segment) {
502+
return whisper_vad_segments_get_segment_t1(segments_wrapper->ptr, i_segment);
503+
}
504+
505+
void whisper_vad_free_segments_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper){
506+
return whisper_vad_free_segments(segments_wrapper->ptr);
507+
}
508+
509+
void whisper_vad_free_wrapper(struct whisper_vad_context_wrapper * ctx_w){
510+
return whisper_vad_free(ctx_w->ptr);
511+
}
512+
513+
////////////
514+
414515
PYBIND11_MODULE(_pywhispercpp, m) {
415516
m.doc() = R"pbdoc(
416517
Pywhispercpp: Python binding to whisper.cpp
@@ -665,7 +766,17 @@ PYBIND11_MODULE(_pywhispercpp, m) {
665766
[](WhisperFullParamsWrapper &self, py::dict dict) {self.beam_search.beam_size = dict["beam_size"].cast<int>(); self.beam_search.patience = dict["patience"].cast<float>();})
666767
.def_readwrite("new_segment_callback_user_data", &WhisperFullParamsWrapper::new_segment_callback_user_data)
667768
.def_readwrite("encoder_begin_callback_user_data", &WhisperFullParamsWrapper::encoder_begin_callback_user_data)
668-
.def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data);
769+
.def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data)
770+
.def_readwrite("vad", &WhisperFullParamsWrapper::vad)
771+
.def_property("vad_model_path",
772+
[](WhisperFullParamsWrapper &self) {
773+
return py::str(self.vad_model_path ? self.vad_model_path : "");
774+
},
775+
[](WhisperFullParamsWrapper &self, const std::string &vad_model_path) {
776+
self.set_vad_model_path(vad_model_path);
777+
}
778+
)
779+
.def_readwrite("vad_params", &WhisperFullParamsWrapper::vad_params);
669780

670781

671782
py::implicitly_convertible<whisper_full_params, WhisperFullParamsWrapper>();
@@ -718,6 +829,40 @@ PYBIND11_MODULE(_pywhispercpp, m) {
718829
m.def("assign_logits_filter_callback", &assign_logits_filter_callback, "Assigns a logits_filter_callback, takes <whisper_full_params> instance and a callable function with the same parameters which are defined in the interface",
719830
py::arg("params"), py::arg("callback"));
720831

832+
// VAD
833+
py::class_<whisper_vad_params>(m,"whisper_vad_params")
834+
.def(py::init<>())
835+
.def_readwrite("threshold", &whisper_vad_params::threshold)
836+
.def_readwrite("min_speech_duration_ms", &whisper_vad_params::min_speech_duration_ms)
837+
.def_readwrite("min_silence_duration_ms", &whisper_vad_params::min_silence_duration_ms)
838+
.def_readwrite("max_speech_duration_s", &whisper_vad_params::max_speech_duration_s)
839+
.def_readwrite("speech_pad_ms", &whisper_vad_params::speech_pad_ms)
840+
.def_readwrite("samples_overlap", &whisper_vad_params::samples_overlap);
841+
842+
m.def("whisper_vad_default_params", &whisper_vad_default_params);
843+
844+
py::class_<whisper_vad_context_params>(m,"whisper_vad_context_params")
845+
.def(py::init<>())
846+
.def_readwrite("n_threads", &whisper_vad_context_params::n_threads)
847+
.def_readwrite("use_gpu", &whisper_vad_context_params::use_gpu)
848+
.def_readwrite("gpu_device", &whisper_vad_context_params::gpu_device);
849+
850+
m.def("whisper_vad_default_context_params", &whisper_vad_default_context_params);
851+
m.def("whisper_vad_init_from_file_with_params", &whisper_vad_init_from_file_with_params_wrapper);
852+
m.def("whisper_vad_detect_speech", &whisper_vad_detect_speech_wrapper);
853+
m.def("whisper_vad_n_probs", &whisper_vad_n_probs_wrapper);
854+
m.def("whisper_vad_probs", &whisper_vad_probs_wrapper);
855+
py::class_<whisper_vad_segments_wrapper>(m, "whisper_vad_segments");
856+
m.def("whisper_vad_segments_from_probs", &whisper_vad_segments_from_probs_wrapper);
857+
m.def("whisper_vad_segments_from_samples", &whisper_vad_segments_from_samples_wrapper);
858+
m.def("whisper_vad_segments_n_segments", &whisper_vad_segments_n_segments_wrapper);
859+
m.def("whisper_vad_segments_get_segment_t0", &whisper_vad_segments_get_segment_t0_wrapper);
860+
m.def("whisper_vad_segments_get_segment_t1", &whisper_vad_segments_get_segment_t1_wrapper);
861+
m.def("whisper_vad_free_segments", &whisper_vad_free_segments_wrapper);
862+
m.def("whisper_vad_free", &whisper_vad_free_wrapper);
863+
864+
865+
721866

722867
#ifdef VERSION_INFO
723868
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);

0 commit comments

Comments
 (0)