Skip to content

Commit ece8721

Browse files
committed
examples : add parakeet-server example
This commit adds a http server for parkeet similar to whisper-server. The shared functionality has been extracted in to examples/server-common.h to avoid code duplication.
1 parent 43d78af commit ece8721

8 files changed

Lines changed: 992 additions & 344 deletions

File tree

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ else()
109109
add_subdirectory(vad-speech-segments)
110110
add_subdirectory(parakeet-cli)
111111
add_subdirectory(parakeet-quantize)
112+
add_subdirectory(parakeet-server)
112113
if (WHISPER_SDL2)
113114
add_subdirectory(stream)
114115
add_subdirectory(command)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
set(CMAKE_CXX_STANDARD 17)
2+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
3+
4+
set(TARGET parakeet-server)
5+
add_executable(${TARGET} parakeet-server.cpp)
6+
7+
include(DefaultTargetOptions)
8+
9+
target_sources(${TARGET} PRIVATE ../server-common.cpp)
10+
11+
target_link_libraries(${TARGET} PRIVATE common json_cpp parakeet ${CMAKE_THREAD_LIBS_INIT})
12+
13+
if (WIN32)
14+
target_link_libraries(${TARGET} PRIVATE ws2_32)
15+
endif()
16+
17+
install(TARGETS ${TARGET} RUNTIME)

examples/parakeet-server/parakeet-server.cpp

Lines changed: 420 additions & 0 deletions
Large diffs are not rendered by default.

examples/server-common.cpp

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
#include "server-common.h"
2+
#include "common-whisper.h"
3+
4+
#include <cstdio>
5+
#include <csignal>
6+
#include <random>
7+
#include <sstream>
8+
#include <memory>
9+
#include <fstream>
10+
#include <algorithm>
11+
#include <cmath>
12+
#include <cstdlib>
13+
#include <filesystem>
14+
15+
#if defined (_WIN32)
16+
#include <windows.h>
17+
#endif
18+
19+
const std::string json_format = "json";
20+
const std::string text_format = "text";
21+
const std::string srt_format = "srt";
22+
const std::string vjson_format = "verbose_json";
23+
const std::string vtt_format = "vtt";
24+
25+
namespace {
26+
std::function<void()> g_shutdown_callback;
27+
std::atomic_flag g_is_terminating = ATOMIC_FLAG_INIT;
28+
29+
void signal_handler(int /*signal*/) {
30+
if (g_is_terminating.test_and_set()) {
31+
fprintf(stderr, "Received second interrupt, terminating immediately.\n");
32+
exit(1);
33+
}
34+
if (g_shutdown_callback) {
35+
g_shutdown_callback();
36+
}
37+
}
38+
}
39+
40+
bool parse_str_to_bool(const std::string & s) {
41+
if (s == "true" || s == "1" || s == "yes" || s == "y") {
42+
return true;
43+
}
44+
return false;
45+
}
46+
47+
bool check_ffmpeg_availability() {
48+
int result = system("ffmpeg -version");
49+
if (result == 0) {
50+
std::cout << "ffmpeg is available." << std::endl;
51+
} else {
52+
std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed "
53+
<< "and that its executable is included in your system's PATH. ";
54+
exit(0);
55+
}
56+
return true;
57+
}
58+
59+
std::string generate_temp_filename(const std::string & path, const std::string & prefix, const std::string & extension) {
60+
auto now = std::chrono::system_clock::now();
61+
auto now_time_t = std::chrono::system_clock::to_time_t(now);
62+
63+
static std::mt19937 rng{std::random_device{}()};
64+
std::uniform_int_distribution<long long> dist(0, 1e9);
65+
66+
std::stringstream ss;
67+
ss << path
68+
<< std::filesystem::path::preferred_separator
69+
<< prefix
70+
<< "-"
71+
<< std::put_time(std::localtime(&now_time_t), "%Y%m%d-%H%M%S")
72+
<< "-"
73+
<< dist(rng)
74+
<< extension;
75+
76+
return ss.str();
77+
}
78+
79+
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp, bool stereo) {
80+
std::ostringstream cmd_stream;
81+
std::string converted_filename_temp = temp_filename + "_temp.wav";
82+
cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac " << (stereo ? 2 : 1) << " -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1";
83+
std::string cmd = cmd_stream.str();
84+
85+
int status = std::system(cmd.c_str());
86+
if (status != 0) {
87+
error_resp = "{\"error\":\"FFmpeg conversion failed.\"}";
88+
return false;
89+
}
90+
91+
if (remove(temp_filename.c_str()) != 0) {
92+
error_resp = "{\"error\":\"Failed to remove the original file.\"}";
93+
return false;
94+
}
95+
96+
if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) {
97+
error_resp = "{\"error\":\"Failed to rename the temporary file.\"}";
98+
return false;
99+
}
100+
return true;
101+
}
102+
103+
void setup_signal_handler(std::function<void()> shutdown_callback) {
104+
g_shutdown_callback = std::move(shutdown_callback);
105+
106+
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
107+
struct sigaction sigint_action;
108+
sigint_action.sa_handler = signal_handler;
109+
sigemptyset(&sigint_action.sa_mask);
110+
sigint_action.sa_flags = 0;
111+
sigaction(SIGINT, &sigint_action, NULL);
112+
sigaction(SIGTERM, &sigint_action, NULL);
113+
#elif defined (_WIN32)
114+
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
115+
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
116+
};
117+
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
118+
#endif
119+
}
120+
121+
static std::string ms_to_timestamp(int64_t t_ms, bool comma = false) {
122+
// to_timestamp expects centiseconds, our adapter uses milliseconds
123+
return ::to_timestamp(t_ms / 10, comma);
124+
}
125+
126+
127+
std::string format_text(const transcription_result & result) {
128+
std::stringstream ss;
129+
const int n_segments = result.n_segments();
130+
for (int i = 0; i < n_segments; ++i) {
131+
auto seg = result.get_segment(i);
132+
auto speaker = result.get_speaker(i);
133+
ss << speaker << seg.text << "\n";
134+
}
135+
return ss.str();
136+
}
137+
138+
std::string format_srt(const transcription_result & result, int offset_n) {
139+
std::stringstream ss;
140+
const int n_segments = result.n_segments();
141+
for (int i = 0; i < n_segments; ++i) {
142+
auto seg = result.get_segment(i);
143+
auto speaker = result.get_speaker(i);
144+
145+
ss << i + 1 + offset_n << "\n";
146+
ss << ms_to_timestamp(seg.t0, true) << " --> " << ms_to_timestamp(seg.t1, true) << "\n";
147+
ss << speaker << seg.text << "\n\n";
148+
}
149+
return ss.str();
150+
}
151+
152+
std::string format_vtt(const transcription_result & result) {
153+
std::stringstream ss;
154+
ss << "WEBVTT\n\n";
155+
156+
const int n_segments = result.n_segments();
157+
for (int i = 0; i < n_segments; ++i) {
158+
auto seg = result.get_segment(i);
159+
std::string speaker_tag;
160+
161+
auto speaker_id = result.get_speaker(i);
162+
if (!speaker_id.empty()) {
163+
speaker_tag = "<v Speaker" + speaker_id + ">";
164+
}
165+
166+
ss << ms_to_timestamp(seg.t0) << " --> " << ms_to_timestamp(seg.t1) << "\n";
167+
ss << speaker_tag << seg.text << "\n\n";
168+
}
169+
return ss.str();
170+
}
171+
172+
std::string format_json(const transcription_result & result) {
173+
std::string text = format_text(result);
174+
json jres = json{{"text", text}};
175+
return jres.dump(-1, ' ', false, json::error_handler_t::replace);
176+
}
177+
178+
std::string format_verbose_json(
179+
const transcription_result & result,
180+
float temperature,
181+
float duration,
182+
bool no_timestamps,
183+
bool token_timestamps) {
184+
std::string text = format_text(result);
185+
std::string task = result.get_task();
186+
std::string language = result.get_language();
187+
188+
json jres = json{
189+
{"task", task},
190+
{"language", language},
191+
{"duration", duration},
192+
{"text", text},
193+
{"segments", json::array()}
194+
};
195+
196+
// Merge language probability data into the top-level response.
197+
// Adapters return a json object whose keys are merged directly, allowing
198+
// model-specific fields (e.g. whisper's detected_language) to appear at
199+
// the top level alongside the standard language_probabilities map.
200+
json lang_data = result.get_language_probabilities();
201+
for (auto & [key, val] : lang_data.items()) {
202+
jres[key] = val;
203+
}
204+
205+
const int n_segments = result.n_segments();
206+
for (int i = 0; i < n_segments; ++i) {
207+
auto seg = result.get_segment(i);
208+
209+
json segment = json{
210+
{"id", i},
211+
{"text", seg.text},
212+
};
213+
214+
if (!no_timestamps) {
215+
segment["start"] = seg.t0 * 0.001f; // ms -> seconds
216+
segment["end"] = seg.t1 * 0.001f;
217+
}
218+
219+
auto speaker_id = result.get_speaker(i);
220+
if (!speaker_id.empty()) {
221+
segment["speaker"] = speaker_id;
222+
}
223+
224+
// Build word-level tokens by merging partial UTF-8 tokens
225+
std::vector<json> words;
226+
int n_tokens = (int)seg.tokens.size();
227+
float total_logprob = 0.0f;
228+
229+
for (int j = 0; j < n_tokens; ++j) {
230+
auto & tok = seg.tokens[j];
231+
232+
// Merge trailing partial UTF-8 bytes into complete words
233+
std::string word_text = tok.text;
234+
int64_t word_t1 = tok.t1;
235+
236+
while (j + 1 < n_tokens) {
237+
int trailing = utf8_trailing_bytes_needed(word_text);
238+
if (trailing <= 0) break;
239+
240+
++j;
241+
auto & next_tok = seg.tokens[j];
242+
word_text += next_tok.text;
243+
if (next_tok.t1 > word_t1) {
244+
word_t1 = next_tok.t1;
245+
}
246+
}
247+
248+
json word = json{{"word", word_text}};
249+
if (!no_timestamps && token_timestamps) {
250+
word["start"] = tok.t0 * 0.001f;
251+
word["end"] = word_t1 * 0.001f;
252+
}
253+
word["probability"] = tok.prob;
254+
255+
// Approximate logprob from probability
256+
float logprob = tok.prob > 0.0f ? std::log(tok.prob + 1e-10f) : -1e10f;
257+
total_logprob += logprob;
258+
259+
words.push_back(word);
260+
}
261+
262+
segment["words"] = words;
263+
segment["tokens"] = json::array();
264+
for (auto & tok : seg.tokens) {
265+
segment["tokens"].push_back(tok.id);
266+
}
267+
268+
segment["temperature"] = temperature;
269+
int n_word_tokens = (int)seg.tokens.size();
270+
segment["avg_logprob"] = n_word_tokens > 0 ? total_logprob / n_word_tokens : 0.0f;
271+
segment["no_speech_prob"] = seg.no_speech_prob;
272+
273+
jres["segments"].push_back(segment);
274+
}
275+
276+
return jres.dump(-1, ' ', false, json::error_handler_t::replace);
277+
}
278+
279+
void setup_server_common(
280+
httplib::Server & svr,
281+
const server_params & sparams,
282+
std::atomic<server_state> & state,
283+
std::function<void(const httplib::Request &, httplib::Response &)> load_handler,
284+
std::function<void(const httplib::Request &, httplib::Response &)> inference_handler,
285+
const std::string & default_content,
286+
const std::string & server_name) {
287+
288+
svr.set_default_headers({
289+
{"Server", server_name},
290+
{"Access-Control-Allow-Origin", "*"},
291+
{"Access-Control-Allow-Headers", "content-type, authorization"}
292+
});
293+
294+
// Default index page
295+
svr.Get(sparams.request_path + "/", [&](const httplib::Request &, httplib::Response & res) {
296+
res.set_content(default_content, "text/html");
297+
return false;
298+
});
299+
300+
// CORS preflight
301+
svr.Options(sparams.request_path + sparams.inference_path,
302+
[&](const httplib::Request &, httplib::Response &) {});
303+
304+
// Inference endpoint
305+
svr.Post(sparams.request_path + sparams.inference_path, inference_handler);
306+
307+
// Model reload endpoint
308+
if (load_handler) {
309+
svr.Post(sparams.request_path + "/load", load_handler);
310+
}
311+
312+
// Health check
313+
svr.Get(sparams.request_path + "/health", [&](const httplib::Request &, httplib::Response & res) {
314+
server_state current_state = state.load();
315+
if (current_state == SERVER_STATE_READY) {
316+
res.set_content("{\"status\":\"ok\"}", "application/json");
317+
} else {
318+
res.set_content("{\"status\":\"loading model\"}", "application/json");
319+
res.status = 503;
320+
}
321+
});
322+
323+
// Exception handler
324+
svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
325+
const char fmt[] = "500 Internal Server Error\n%s";
326+
char buf[BUFSIZ];
327+
try {
328+
std::rethrow_exception(std::move(ep));
329+
} catch (std::exception & e) {
330+
snprintf(buf, sizeof(buf), fmt, e.what());
331+
} catch (...) {
332+
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
333+
}
334+
res.set_content(buf, "text/plain");
335+
res.status = 500;
336+
});
337+
338+
// Error handler
339+
svr.set_error_handler([](const httplib::Request & req, httplib::Response & res) {
340+
if (res.status == 400) {
341+
res.set_content("Invalid request", "text/plain");
342+
} else if (res.status != 500) {
343+
res.set_content("File Not Found (" + req.path + ")", "text/plain");
344+
res.status = 404;
345+
}
346+
});
347+
348+
svr.set_read_timeout(sparams.read_timeout);
349+
svr.set_write_timeout(sparams.write_timeout);
350+
}

0 commit comments

Comments
 (0)