Skip to content

Commit f4e6714

Browse files
authored
Merge pull request #151 from dlparker/initial_prompt_mod_1
Possible fix for issues with initial_prompt
2 parents 1223b5b + ed01e78 commit f4e6714

1 file changed

Lines changed: 57 additions & 40 deletions

File tree

src/main.cpp

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -296,48 +296,65 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c
296296
return whisper_ctx_init_openvino_encoder(ctx->ptr, model_path, device, cache_dir);
297297
}
298298

299-
class WhisperFullParamsWrapper : public whisper_full_params {
300-
std::string initial_prompt_str;
301-
std::string suppress_regex_str;
299+
struct WhisperFullParamsWrapper : public whisper_full_params {
300+
std::string initial_prompt_str;
301+
std::string suppress_regex_str;
302302
public:
303-
py::function py_progress_callback;
304-
WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
305-
: whisper_full_params(params),
306-
initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
307-
suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
308-
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
309-
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
310-
// progress callback
311-
progress_callback_user_data = this;
312-
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
313-
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
314-
if(self && self->print_progress){
315-
if (self->py_progress_callback) {
316-
// call the python callback
317-
py::gil_scoped_acquire gil;
318-
self->py_progress_callback(progress); // Call Python callback
319-
}
320-
else {
321-
fprintf(stderr, "Progress: %3d%%\n", progress);
322-
} // Default message
323-
}
324-
} ;
325-
}
326-
327-
WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
328-
: WhisperFullParamsWrapper(static_cast<const whisper_full_params&>(other)) {}
329-
330-
void set_initial_prompt(const std::string& prompt) {
331-
initial_prompt_str = prompt;
332-
initial_prompt = initial_prompt_str.c_str();
333-
}
334-
335-
void set_suppress_regex(const std::string& regex) {
336-
suppress_regex_str = regex;
337-
suppress_regex = suppress_regex_str.c_str();
338-
}
303+
py::function py_progress_callback;
304+
WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
305+
: whisper_full_params(params),
306+
initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
307+
suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
308+
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
309+
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
310+
// progress callback
311+
progress_callback_user_data = this;
312+
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
313+
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
314+
if(self && self->print_progress){
315+
if (self->py_progress_callback) {
316+
// call the python callback
317+
py::gil_scoped_acquire gil;
318+
self->py_progress_callback(progress); // Call Python callback
319+
}
320+
else {
321+
fprintf(stderr, "Progress: %3d%%\n", progress);
322+
} // Default message
323+
}
324+
} ;
325+
}
326+
WhisperFullParamsWrapper(const WhisperFullParamsWrapper& other)
327+
: whisper_full_params(static_cast<whisper_full_params>(other)), // Copy base struct
328+
initial_prompt_str(other.initial_prompt_str),
329+
suppress_regex_str(other.suppress_regex_str),
330+
py_progress_callback(other.py_progress_callback) {
331+
// Reset pointers to new string copies
332+
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
333+
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
334+
progress_callback_user_data = this;
335+
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
336+
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
337+
if(self && self->print_progress){
338+
if (self->py_progress_callback) {
339+
// call the python callback
340+
py::gil_scoped_acquire gil;
341+
self->py_progress_callback(progress); // Call Python callback
342+
}
343+
else {
344+
fprintf(stderr, "Progress: %3d%%\n", progress);
345+
} // Default message
346+
}
347+
};
348+
}
349+
void set_initial_prompt(const std::string& prompt) {
350+
initial_prompt_str = prompt;
351+
initial_prompt = initial_prompt_str.c_str();
352+
}
353+
void set_suppress_regex(const std::string& regex) {
354+
suppress_regex_str = regex;
355+
suppress_regex = suppress_regex_str.c_str();
356+
}
339357
};
340-
341358
WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) {
342359
return WhisperFullParamsWrapper(whisper_full_default_params(strategy));
343360
}

0 commit comments

Comments
 (0)