@@ -296,48 +296,65 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c
296296 return whisper_ctx_init_openvino_encoder (ctx->ptr , model_path, device, cache_dir);
297297}
298298
299- class WhisperFullParamsWrapper : public whisper_full_params {
300- std::string initial_prompt_str;
301- std::string suppress_regex_str;
299+ struct WhisperFullParamsWrapper : public whisper_full_params {
300+ std::string initial_prompt_str;
301+ std::string suppress_regex_str;
302302public:
303- py::function py_progress_callback;
304- WhisperFullParamsWrapper (const whisper_full_params& params = whisper_full_params())
305- : whisper_full_params(params),
306- initial_prompt_str (params.initial_prompt ? params.initial_prompt : " " ),
307- suppress_regex_str(params.suppress_regex ? params.suppress_regex : " " ) {
308- initial_prompt = initial_prompt_str.empty () ? nullptr : initial_prompt_str.c_str ();
309- suppress_regex = suppress_regex_str.empty () ? nullptr : suppress_regex_str.c_str ();
310- // progress callback
311- progress_callback_user_data = this ;
312- progress_callback = [](struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data) {
313- auto * self = static_cast <WhisperFullParamsWrapper*>(user_data);
314- if (self && self->print_progress ){
315- if (self->py_progress_callback ) {
316- // call the python callback
317- py::gil_scoped_acquire gil;
318- self->py_progress_callback (progress); // Call Python callback
319- }
320- else {
321- fprintf (stderr, " Progress: %3d%%\n " , progress);
322- } // Default message
323- }
324- } ;
325- }
326-
327- WhisperFullParamsWrapper (const WhisperFullParamsWrapper& other)
328- : WhisperFullParamsWrapper(static_cast <const whisper_full_params&>(other)) {}
329-
330- void set_initial_prompt (const std::string& prompt) {
331- initial_prompt_str = prompt;
332- initial_prompt = initial_prompt_str.c_str ();
333- }
334-
335- void set_suppress_regex (const std::string& regex) {
336- suppress_regex_str = regex;
337- suppress_regex = suppress_regex_str.c_str ();
338- }
303+ py::function py_progress_callback;
304+ WhisperFullParamsWrapper (const whisper_full_params& params = whisper_full_params())
305+ : whisper_full_params(params),
306+ initial_prompt_str (params.initial_prompt ? params.initial_prompt : " " ),
307+ suppress_regex_str(params.suppress_regex ? params.suppress_regex : " " ) {
308+ initial_prompt = initial_prompt_str.empty () ? nullptr : initial_prompt_str.c_str ();
309+ suppress_regex = suppress_regex_str.empty () ? nullptr : suppress_regex_str.c_str ();
310+ // progress callback
311+ progress_callback_user_data = this ;
312+ progress_callback = [](struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data) {
313+ auto * self = static_cast <WhisperFullParamsWrapper*>(user_data);
314+ if (self && self->print_progress ){
315+ if (self->py_progress_callback ) {
316+ // call the python callback
317+ py::gil_scoped_acquire gil;
318+ self->py_progress_callback (progress); // Call Python callback
319+ }
320+ else {
321+ fprintf (stderr, " Progress: %3d%%\n " , progress);
322+ } // Default message
323+ }
324+ } ;
325+ }
326+ WhisperFullParamsWrapper (const WhisperFullParamsWrapper& other)
327+ : whisper_full_params(static_cast <whisper_full_params>(other)), // Copy base struct
328+ initial_prompt_str(other.initial_prompt_str),
329+ suppress_regex_str(other.suppress_regex_str),
330+ py_progress_callback(other.py_progress_callback) {
331+ // Reset pointers to new string copies
332+ initial_prompt = initial_prompt_str.empty () ? nullptr : initial_prompt_str.c_str ();
333+ suppress_regex = suppress_regex_str.empty () ? nullptr : suppress_regex_str.c_str ();
334+ progress_callback_user_data = this ;
335+ progress_callback = [](struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data) {
336+ auto * self = static_cast <WhisperFullParamsWrapper*>(user_data);
337+ if (self && self->print_progress ){
338+ if (self->py_progress_callback ) {
339+ // call the python callback
340+ py::gil_scoped_acquire gil;
341+ self->py_progress_callback (progress); // Call Python callback
342+ }
343+ else {
344+ fprintf (stderr, " Progress: %3d%%\n " , progress);
345+ } // Default message
346+ }
347+ };
348+ }
349+ void set_initial_prompt (const std::string& prompt) {
350+ initial_prompt_str = prompt;
351+ initial_prompt = initial_prompt_str.c_str ();
352+ }
353+ void set_suppress_regex (const std::string& regex) {
354+ suppress_regex_str = regex;
355+ suppress_regex = suppress_regex_str.c_str ();
356+ }
339357};
340-
341358WhisperFullParamsWrapper whisper_full_default_params_wrapper (enum whisper_sampling_strategy strategy) {
342359 return WhisperFullParamsWrapper (whisper_full_default_params (strategy));
343360}
0 commit comments