55#include < sstream>
66
77namespace whisper {
8-
/// Payload handed to the new-segment callback through its `void *user_data`
/// argument.
///
/// `results` is a non-owning pointer: the struct never allocates or frees the
/// vector, so whoever fills this in must keep the vector alive for as long as
/// the callback may run.
struct new_segment_callback_data {
  /// Destination that receives the text of every newly decoded segment.
  std::vector<std::string> *results;
};
12-
13- class Whisper {
14- public:
15- ~Whisper () = default ;
16- Whisper (const char *model_path) {
17- Context context = Context::from_file (model_path);
18- this ->context = context;
19-
20- // Set default params to recommended.
21- FullParams params = FullParams::from_sampling_strategy (
22- SamplingStrategies::from_strategy_type (SamplingStrategies::GREEDY));
23- // disable printing progress
24- params.set_print_progress (false );
25- // disable realtime print, using callback
26- params.set_print_realtime (false );
27- // invoke new_segment_callback for faster transcription.
28- params.set_new_segment_callback ([](struct whisper_context *ctx, int n_new,
29- void *user_data) {
30- const auto &results = ((new_segment_callback_data *)user_data)->results ;
31-
32- const int n_segments = whisper_full_n_segments (ctx);
33-
34- // print the last n_new segments
35- const int s0 = n_segments - n_new;
36-
37- for (int i = s0; i < n_segments; i++) {
38- const char *text = whisper_full_get_segment_text (ctx, i);
39- results->push_back (std::move (text));
40- };
41- });
42- this ->params = params;
43- }
44-
45- std::string transcribe (std::vector<float > data, int num_proc) {
46- std::vector<std::string> results;
47- results.reserve (data.size ());
48- new_segment_callback_data user_data = {&results};
49- params.set_new_segment_callback_user_data (&user_data);
50- if (context.full_parallel (params, data, num_proc) != 0 ) {
51- throw std::runtime_error (" transcribe failed" );
52- }
53-
54- const char *const delim = " " ;
55- // We are allocating a new string for every element in the vector.
56- // This is not efficient, for larger files.
57- return std::accumulate (results.begin (), results.end (), std::string (delim));
58- };
59-
60- Context get_context () { return context; };
61- FullParams get_params () { return params; };
62-
63- private:
64- Context context;
65- FullParams params;
66- };
67-
688PYBIND11_MODULE (api, m) {
699 m.doc () = " Python interface for whisper.cpp" ;
7010
@@ -79,12 +19,119 @@ PYBIND11_MODULE(api, m) {
7919 ExportContextApi (m);
8020
8121 // NOTE: export Params API
82- ExportParamsApi (m);
83-
84- py::class_<Whisper>(m, " WhisperPreTrainedModel" )
85- .def (py::init<const char *>())
86- .def_property_readonly (" context" , &Whisper::get_context)
87- .def_property_readonly (" params" , &Whisper::get_params)
88- .def (" transcribe" , &Whisper::transcribe, " data" _a, " num_proc" _a = 1 );
22+ py::enum_<SamplingStrategies::StrategyType>(m, " StrategyType" )
23+ .value (" SAMPLING_GREEDY" , SamplingStrategies::GREEDY)
24+ .value (" SAMPLING_BEAM_SEARCH" , SamplingStrategies::BEAM_SEARCH)
25+ .export_values ();
26+
27+ py::class_<SamplingGreedy>(m, " SamplingGreedyStrategy" )
28+ .def (py::init<>())
29+ .def_property (
30+ " best_of" , [](SamplingGreedy &self) { return self.best_of ; },
31+ [](SamplingGreedy &self, int best_of) { self.best_of = best_of; });
32+
33+ py::class_<SamplingBeamSearch>(m, " SamplingBeamSearchStrategy" )
34+ .def (py::init<>())
35+ .def_property (
36+ " beam_size" , [](SamplingBeamSearch &self) { return self.beam_size ; },
37+ [](SamplingBeamSearch &self, int beam_size) {
38+ self.beam_size = beam_size;
39+ })
40+ .def_property (
41+ " patience" , [](SamplingBeamSearch &self) { return self.patience ; },
42+ [](SamplingBeamSearch &self, float patience) {
43+ self.patience = patience;
44+ });
45+
46+ py::class_<SamplingStrategies>(m, " SamplingStrategies" ,
47+ " Available sampling strategy for whisper" )
48+ .def_static (" from_strategy_type" , &SamplingStrategies::from_strategy_type,
49+ " strategy" _a)
50+ .def_property (
51+ " type" , [](SamplingStrategies &self) { return self.type ; },
52+ [](SamplingStrategies &self, SamplingStrategies::StrategyType type) {
53+ self.type = type;
54+ })
55+ .def_property (
56+ " greedy" , [](SamplingStrategies &self) { return self.greedy ; },
57+ [](SamplingStrategies &self, SamplingGreedy greedy) {
58+ self.greedy = greedy;
59+ })
60+ .def_property (
61+ " beam_search" ,
62+ [](SamplingStrategies &self) { return self.beam_search ; },
63+ [](SamplingStrategies &self, SamplingBeamSearch beam_search) {
64+ self.beam_search = beam_search;
65+ });
66+
67+ py::class_<FullParams>(m, " Params" , " Whisper parameters container" )
68+ .def_static (" from_sampling_strategy" , &FullParams::from_sampling_strategy,
69+ " sampling_strategy" _a)
70+ .def_property (" num_threads" , &FullParams::get_n_threads,
71+ &FullParams::set_n_threads)
72+ .def_property (" num_max_text_ctx" , &FullParams::get_n_max_text_ctx,
73+ &FullParams::set_n_max_text_ctx)
74+ .def_property (" offset_ms" , &FullParams::get_offset_ms,
75+ &FullParams::set_offset_ms)
76+ .def_property (" duration_ms" , &FullParams::get_duration_ms,
77+ &FullParams::set_duration_ms)
78+ .def_property (" translate" , &FullParams::get_translate,
79+ &FullParams::set_translate)
80+ .def_property (" no_context" , &FullParams::get_no_context,
81+ &FullParams::set_no_context)
82+ .def_property (" single_segment" , &FullParams::get_single_segment,
83+ &FullParams::set_single_segment)
84+ .def_property (" print_special" , &FullParams::get_print_special,
85+ &FullParams::set_print_special)
86+ .def_property (" print_progress" , &FullParams::get_print_progress,
87+ &FullParams::set_print_progress)
88+ .def_property (" print_realtime" , &FullParams::get_print_realtime,
89+ &FullParams::set_print_realtime)
90+ .def_property (" print_timestamps" , &FullParams::get_print_timestamps,
91+ &FullParams::set_print_timestamps)
92+ .def_property (" token_timestamps" , &FullParams::get_token_timestamps,
93+ &FullParams::set_token_timestamps)
94+ .def_property (" timestamp_token_probability_threshold" ,
95+ &FullParams::get_thold_pt, &FullParams::set_thold_pt)
96+ .def_property (" timestamp_token_sum_probability_threshold" ,
97+ &FullParams::get_thold_ptsum, &FullParams::set_thold_ptsum)
98+ .def_property (" max_segment_length" , &FullParams::get_max_len,
99+ &FullParams::set_max_len)
100+ .def_property (" split_on_word" , &FullParams::get_split_on_word,
101+ &FullParams::set_split_on_word)
102+ .def_property (" max_tokens" , &FullParams::get_max_tokens,
103+ &FullParams::set_max_tokens)
104+ .def_property (" speed_up" , &FullParams::get_speed_up,
105+ &FullParams::set_speed_up)
106+ .def_property (" audio_ctx" , &FullParams::get_audio_ctx,
107+ &FullParams::set_audio_ctx)
108+ .def (" set_tokens" , &FullParams::set_tokens, " tokens" _a)
109+ .def_property_readonly (" prompt_tokens" , &FullParams::get_prompt_tokens)
110+ .def_property_readonly (" prompt_num_tokens" ,
111+ &FullParams::get_prompt_n_tokens)
112+ .def_property (" language" , &FullParams::get_language,
113+ &FullParams::set_language)
114+ .def_property (" suppress_blank" , &FullParams::get_suppress_blank,
115+ &FullParams::set_suppress_blank)
116+ .def_property (" suppress_none_speech_tokens" ,
117+ &FullParams::get_suppress_none_speech_tokens,
118+ &FullParams::set_suppress_none_speech_tokens)
119+ .def_property (" temperature" , &FullParams::get_temperature,
120+ &FullParams::set_temperature)
121+ .def_property (" max_intial_timestamps" , &FullParams::get_max_intial_ts,
122+ &FullParams::set_max_intial_ts)
123+ .def_property (" length_penalty" , &FullParams::get_length_penalty,
124+ &FullParams::set_length_penalty)
125+ .def_property (" temperature_inc" , &FullParams::get_temperature_inc,
126+ &FullParams::set_temperature_inc)
127+ .def_property (" entropy_threshold" , &FullParams::get_entropy_thold,
128+ &FullParams::set_entropy_thold)
129+ .def_property (" logprob_threshold" , &FullParams::get_logprob_thold,
130+ &FullParams::set_logprob_thold)
131+ .def_property (" no_speech_threshold" , &FullParams::get_no_speech_thold,
132+ &FullParams::set_no_speech_thold);
133+ // TODO: decide how to expose the FullParams callback setters to Python. The
134+ // C++ APIs exist, but converting a Python callable into a C++ callback still
135+ // needs investigation.
89136}
90137}; // namespace whisper
0 commit comments