Skip to content

Commit 1e64d43

Browse files
committed
fix: segfault when accessing bindings from C++
Move implementation from C++ to Python. Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
1 parent 735ca92 commit 1e64d43

5 files changed

Lines changed: 198 additions & 259 deletions

File tree

src/whispercpp/__init__.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from .utils import download_model
99

1010
if t.TYPE_CHECKING:
11+
import numpy as np
12+
from numpy.typing import NDArray
13+
1114
from . import api
1215
else:
1316
api = LazyLoader("api", globals(), "whispercpp.api")
@@ -27,20 +30,31 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
2730
context: api.Context
2831
params: api.Params
2932

30-
@classmethod
31-
def from_pretrained(cls, model_name: str):
33+
@staticmethod
34+
def from_pretrained(model_name: str):
3235
if model_name not in MODELS_URL:
3336
raise RuntimeError(
3437
f"'{model_name}' is not a valid preconverted model. Choose one of {list(MODELS_URL)}"
3538
)
36-
_ref = object.__new__(cls)
37-
_cpp_binding = api.WhisperPreTrainedModel(download_model(model_name))
38-
context = _cpp_binding.context
39-
params = _cpp_binding.params
40-
transcribe = _cpp_binding.transcribe
41-
del cls, _cpp_binding
39+
_ref = object.__new__(Whisper)
40+
context = api.Context.from_file(download_model(model_name))
41+
params = api.Params.from_sampling_strategy(
42+
api.SamplingStrategies.from_strategy_type(api.SAMPLING_GREEDY)
43+
)
44+
params.print_progress = False
45+
params.print_realtime = False
46+
context.reset_timings()
4247
_ref.__dict__.update(locals())
4348
return _ref
4449

50+
def transcribe(self, data: NDArray[np.float32], num_proc: int = 1):
51+
self.context.full_parallel(self.params, data, num_proc)
52+
return "".join(
53+
[
54+
self.context.full_get_segment_text(i)
55+
for i in range(self.context.full_n_segments())
56+
]
57+
)
58+
4559

4660
__all__ = ["Whisper", "api"]

src/whispercpp/api.pyi

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ from __future__ import annotations
33
import enum
44
import typing as t
55

6-
import numpy as np
76
from numpy.typing import NDArray
87

98
SAMPLE_RATE: int = ...
@@ -19,14 +18,19 @@ class SamplingBeamSearchStrategy:
1918
beam_size: int
2019
patience: float
2120

21+
SAMPLING_GREEDY: StrategyType = ...
22+
SAMPLING_BEAM_SEARCH: StrategyType = ...
23+
2224
class StrategyType(enum.Enum):
23-
GREEDY = ...
24-
BEAM_SEARCH = ...
25+
SAMPLING_GREEDY = ...
26+
SAMPLING_BEAM_SEARCH = ...
2527

2628
class SamplingStrategies:
2729
type: StrategyType
2830
greedy: SamplingGreedyStrategy
2931
beam_search: SamplingBeamSearchStrategy
32+
@staticmethod
33+
def from_strategy_type(strategy_type: StrategyType) -> SamplingStrategies: ...
3034

3135
# annotate the type of whisper_full_params
3236
_CppFullParams = t.Any
@@ -82,12 +86,4 @@ class Context:
8286
def full_get_segment_text(self, segment: int) -> str: ...
8387
def full_n_segments(self) -> int: ...
8488
def free(self) -> None: ...
85-
86-
class WhisperPreTrainedModel:
87-
context: Context
88-
params: Params
89-
@t.overload
90-
def __init__(self) -> None: ...
91-
@t.overload
92-
def __init__(self, path: str | bytes) -> None: ...
93-
def transcribe(self, arr: NDArray[np.float32], num_proc: int) -> str: ...
89+
def reset_timings(self) -> None: ...

src/whispercpp/api_export.cc

Lines changed: 114 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,66 +5,6 @@
55
#include <sstream>
66

77
namespace whisper {
8-
9-
struct new_segment_callback_data {
10-
std::vector<std::string> *results;
11-
};
12-
13-
class Whisper {
14-
public:
15-
~Whisper() = default;
16-
Whisper(const char *model_path) {
17-
Context context = Context::from_file(model_path);
18-
this->context = context;
19-
20-
// Set default params to recommended.
21-
FullParams params = FullParams::from_sampling_strategy(
22-
SamplingStrategies::from_strategy_type(SamplingStrategies::GREEDY));
23-
// disable printing progress
24-
params.set_print_progress(false);
25-
// disable realtime print, using callback
26-
params.set_print_realtime(false);
27-
// invoke new_segment_callback for faster transcription.
28-
params.set_new_segment_callback([](struct whisper_context *ctx, int n_new,
29-
void *user_data) {
30-
const auto &results = ((new_segment_callback_data *)user_data)->results;
31-
32-
const int n_segments = whisper_full_n_segments(ctx);
33-
34-
// print the last n_new segments
35-
const int s0 = n_segments - n_new;
36-
37-
for (int i = s0; i < n_segments; i++) {
38-
const char *text = whisper_full_get_segment_text(ctx, i);
39-
results->push_back(std::move(text));
40-
};
41-
});
42-
this->params = params;
43-
}
44-
45-
std::string transcribe(std::vector<float> data, int num_proc) {
46-
std::vector<std::string> results;
47-
results.reserve(data.size());
48-
new_segment_callback_data user_data = {&results};
49-
params.set_new_segment_callback_user_data(&user_data);
50-
if (context.full_parallel(params, data, num_proc) != 0) {
51-
throw std::runtime_error("transcribe failed");
52-
}
53-
54-
const char *const delim = "";
55-
// We are allocating a new string for every element in the vector.
56-
// This is not efficient, for larger files.
57-
return std::accumulate(results.begin(), results.end(), std::string(delim));
58-
};
59-
60-
Context get_context() { return context; };
61-
FullParams get_params() { return params; };
62-
63-
private:
64-
Context context;
65-
FullParams params;
66-
};
67-
688
PYBIND11_MODULE(api, m) {
699
m.doc() = "Python interface for whisper.cpp";
7010

@@ -79,12 +19,119 @@ PYBIND11_MODULE(api, m) {
7919
ExportContextApi(m);
8020

8121
// NOTE: export Params API
82-
ExportParamsApi(m);
83-
84-
py::class_<Whisper>(m, "WhisperPreTrainedModel")
85-
.def(py::init<const char *>())
86-
.def_property_readonly("context", &Whisper::get_context)
87-
.def_property_readonly("params", &Whisper::get_params)
88-
.def("transcribe", &Whisper::transcribe, "data"_a, "num_proc"_a = 1);
22+
py::enum_<SamplingStrategies::StrategyType>(m, "StrategyType")
23+
.value("SAMPLING_GREEDY", SamplingStrategies::GREEDY)
24+
.value("SAMPLING_BEAM_SEARCH", SamplingStrategies::BEAM_SEARCH)
25+
.export_values();
26+
27+
py::class_<SamplingGreedy>(m, "SamplingGreedyStrategy")
28+
.def(py::init<>())
29+
.def_property(
30+
"best_of", [](SamplingGreedy &self) { return self.best_of; },
31+
[](SamplingGreedy &self, int best_of) { self.best_of = best_of; });
32+
33+
py::class_<SamplingBeamSearch>(m, "SamplingBeamSearchStrategy")
34+
.def(py::init<>())
35+
.def_property(
36+
"beam_size", [](SamplingBeamSearch &self) { return self.beam_size; },
37+
[](SamplingBeamSearch &self, int beam_size) {
38+
self.beam_size = beam_size;
39+
})
40+
.def_property(
41+
"patience", [](SamplingBeamSearch &self) { return self.patience; },
42+
[](SamplingBeamSearch &self, float patience) {
43+
self.patience = patience;
44+
});
45+
46+
py::class_<SamplingStrategies>(m, "SamplingStrategies",
47+
"Available sampling strategy for whisper")
48+
.def_static("from_strategy_type", &SamplingStrategies::from_strategy_type,
49+
"strategy"_a)
50+
.def_property(
51+
"type", [](SamplingStrategies &self) { return self.type; },
52+
[](SamplingStrategies &self, SamplingStrategies::StrategyType type) {
53+
self.type = type;
54+
})
55+
.def_property(
56+
"greedy", [](SamplingStrategies &self) { return self.greedy; },
57+
[](SamplingStrategies &self, SamplingGreedy greedy) {
58+
self.greedy = greedy;
59+
})
60+
.def_property(
61+
"beam_search",
62+
[](SamplingStrategies &self) { return self.beam_search; },
63+
[](SamplingStrategies &self, SamplingBeamSearch beam_search) {
64+
self.beam_search = beam_search;
65+
});
66+
67+
py::class_<FullParams>(m, "Params", "Whisper parameters container")
68+
.def_static("from_sampling_strategy", &FullParams::from_sampling_strategy,
69+
"sampling_strategy"_a)
70+
.def_property("num_threads", &FullParams::get_n_threads,
71+
&FullParams::set_n_threads)
72+
.def_property("num_max_text_ctx", &FullParams::get_n_max_text_ctx,
73+
&FullParams::set_n_max_text_ctx)
74+
.def_property("offset_ms", &FullParams::get_offset_ms,
75+
&FullParams::set_offset_ms)
76+
.def_property("duration_ms", &FullParams::get_duration_ms,
77+
&FullParams::set_duration_ms)
78+
.def_property("translate", &FullParams::get_translate,
79+
&FullParams::set_translate)
80+
.def_property("no_context", &FullParams::get_no_context,
81+
&FullParams::set_no_context)
82+
.def_property("single_segment", &FullParams::get_single_segment,
83+
&FullParams::set_single_segment)
84+
.def_property("print_special", &FullParams::get_print_special,
85+
&FullParams::set_print_special)
86+
.def_property("print_progress", &FullParams::get_print_progress,
87+
&FullParams::set_print_progress)
88+
.def_property("print_realtime", &FullParams::get_print_realtime,
89+
&FullParams::set_print_realtime)
90+
.def_property("print_timestamps", &FullParams::get_print_timestamps,
91+
&FullParams::set_print_timestamps)
92+
.def_property("token_timestamps", &FullParams::get_token_timestamps,
93+
&FullParams::set_token_timestamps)
94+
.def_property("timestamp_token_probability_threshold",
95+
&FullParams::get_thold_pt, &FullParams::set_thold_pt)
96+
.def_property("timestamp_token_sum_probability_threshold",
97+
&FullParams::get_thold_ptsum, &FullParams::set_thold_ptsum)
98+
.def_property("max_segment_length", &FullParams::get_max_len,
99+
&FullParams::set_max_len)
100+
.def_property("split_on_word", &FullParams::get_split_on_word,
101+
&FullParams::set_split_on_word)
102+
.def_property("max_tokens", &FullParams::get_max_tokens,
103+
&FullParams::set_max_tokens)
104+
.def_property("speed_up", &FullParams::get_speed_up,
105+
&FullParams::set_speed_up)
106+
.def_property("audio_ctx", &FullParams::get_audio_ctx,
107+
&FullParams::set_audio_ctx)
108+
.def("set_tokens", &FullParams::set_tokens, "tokens"_a)
109+
.def_property_readonly("prompt_tokens", &FullParams::get_prompt_tokens)
110+
.def_property_readonly("prompt_num_tokens",
111+
&FullParams::get_prompt_n_tokens)
112+
.def_property("language", &FullParams::get_language,
113+
&FullParams::set_language)
114+
.def_property("suppress_blank", &FullParams::get_suppress_blank,
115+
&FullParams::set_suppress_blank)
116+
.def_property("suppress_none_speech_tokens",
117+
&FullParams::get_suppress_none_speech_tokens,
118+
&FullParams::set_suppress_none_speech_tokens)
119+
.def_property("temperature", &FullParams::get_temperature,
120+
&FullParams::set_temperature)
121+
.def_property("max_intial_timestamps", &FullParams::get_max_intial_ts,
122+
&FullParams::set_max_intial_ts)
123+
.def_property("length_penalty", &FullParams::get_length_penalty,
124+
&FullParams::set_length_penalty)
125+
.def_property("temperature_inc", &FullParams::get_temperature_inc,
126+
&FullParams::set_temperature_inc)
127+
.def_property("entropy_threshold", &FullParams::get_entropy_thold,
128+
&FullParams::set_entropy_thold)
129+
.def_property("logprob_threshold", &FullParams::get_logprob_thold,
130+
&FullParams::set_logprob_thold)
131+
.def_property("no_speech_threshold", &FullParams::get_no_speech_thold,
132+
&FullParams::set_no_speech_thold);
133+
// TODO: idk what to do with setting all the callbacks for FullParams. APIs are
134+
// there, but need more time investigating conversion from Python callback to
135+
// C++ callback
89136
}
90137
}; // namespace whisper

0 commit comments

Comments
 (0)