@@ -295,7 +295,7 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c
295295 const char * cache_dir){
296296 return whisper_ctx_init_openvino_encoder (ctx->ptr , model_path, device, cache_dir);
297297}
298-
298+ /*
299299class WhisperFullParamsWrapper : public whisper_full_params {
300300 std::string initial_prompt_str;
301301 std::string suppress_regex_str;
@@ -338,6 +338,66 @@ class WhisperFullParamsWrapper : public whisper_full_params {
338338 }
339339};
340340
341+ */
342+ struct WhisperFullParamsWrapper : public whisper_full_params {
343+ std::string initial_prompt_str;
344+ std::string suppress_regex_str;
345+ public:
346+ py::function py_progress_callback;
347+ WhisperFullParamsWrapper (const whisper_full_params& params = whisper_full_params())
348+ : whisper_full_params(params),
349+ initial_prompt_str (params.initial_prompt ? params.initial_prompt : " " ),
350+ suppress_regex_str(params.suppress_regex ? params.suppress_regex : " " ) {
351+ initial_prompt = initial_prompt_str.empty () ? nullptr : initial_prompt_str.c_str ();
352+ suppress_regex = suppress_regex_str.empty () ? nullptr : suppress_regex_str.c_str ();
353+ // progress callback
354+ progress_callback_user_data = this ;
355+ progress_callback = [](struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data) {
356+ auto * self = static_cast <WhisperFullParamsWrapper*>(user_data);
357+ if (self && self->print_progress ){
358+ if (self->py_progress_callback ) {
359+ // call the python callback
360+ py::gil_scoped_acquire gil;
361+ self->py_progress_callback (progress); // Call Python callback
362+ }
363+ else {
364+ fprintf (stderr, " Progress: %3d%%\n " , progress);
365+ } // Default message
366+ }
367+ } ;
368+ }
369+ WhisperFullParamsWrapper (const WhisperFullParamsWrapper& other)
370+ : whisper_full_params(static_cast <whisper_full_params>(other)), // Copy base struct
371+ initial_prompt_str(other.initial_prompt_str),
372+ suppress_regex_str(other.suppress_regex_str),
373+ py_progress_callback(other.py_progress_callback) {
374+ // Reset pointers to new string copies
375+ initial_prompt = initial_prompt_str.empty () ? nullptr : initial_prompt_str.c_str ();
376+ suppress_regex = suppress_regex_str.empty () ? nullptr : suppress_regex_str.c_str ();
377+ progress_callback_user_data = this ;
378+ progress_callback = [](struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data) {
379+ auto * self = static_cast <WhisperFullParamsWrapper*>(user_data);
380+ if (self && self->print_progress ){
381+ if (self->py_progress_callback ) {
382+ // call the python callback
383+ py::gil_scoped_acquire gil;
384+ self->py_progress_callback (progress); // Call Python callback
385+ }
386+ else {
387+ fprintf (stderr, " Progress: %3d%%\n " , progress);
388+ } // Default message
389+ }
390+ };
391+ }
392+ void set_initial_prompt (const std::string& prompt) {
393+ initial_prompt_str = prompt;
394+ initial_prompt = initial_prompt_str.c_str ();
395+ }
396+ void set_suppress_regex (const std::string& regex) {
397+ suppress_regex_str = regex;
398+ suppress_regex = suppress_regex_str.c_str ();
399+ }
400+ };
341401WhisperFullParamsWrapper whisper_full_default_params_wrapper (enum whisper_sampling_strategy strategy) {
342402 return WhisperFullParamsWrapper (whisper_full_default_params (strategy));
343403}
0 commit comments