diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 7aedb9df683..5f795129d6c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -109,6 +109,7 @@ else() add_subdirectory(vad-speech-segments) add_subdirectory(parakeet-cli) add_subdirectory(parakeet-quantize) + add_subdirectory(parakeet-server) if (WHISPER_SDL2) add_subdirectory(stream) add_subdirectory(command) diff --git a/examples/server/httplib.h b/examples/httplib.h similarity index 100% rename from examples/server/httplib.h rename to examples/httplib.h diff --git a/examples/parakeet-server/CMakeLists.txt b/examples/parakeet-server/CMakeLists.txt new file mode 100644 index 00000000000..a2c69c052cc --- /dev/null +++ b/examples/parakeet-server/CMakeLists.txt @@ -0,0 +1,17 @@ +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(TARGET parakeet-server) +add_executable(${TARGET} parakeet-server.cpp) + +include(DefaultTargetOptions) + +target_sources(${TARGET} PRIVATE ../server-common.cpp) + +target_link_libraries(${TARGET} PRIVATE common json_cpp parakeet ${CMAKE_THREAD_LIBS_INIT}) + +if (WIN32) + target_link_libraries(${TARGET} PRIVATE ws2_32) +endif() + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-server/parakeet-server.cpp b/examples/parakeet-server/parakeet-server.cpp new file mode 100644 index 00000000000..52bdd1cfbe9 --- /dev/null +++ b/examples/parakeet-server/parakeet-server.cpp @@ -0,0 +1,403 @@ +#include "parakeet.h" +#include "common-whisper.h" +#include "server-common.h" + +#include "httplib.h" +#include "json.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace httplib; +using json = nlohmann::ordered_json; + +struct parakeet_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + bool use_gpu = true; + int32_t gpu_device = 0; + + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::string response_format = json_format; +}; + +static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params, const server_params & sparams) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -ng, --no-gpu [%-7s] do not use GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID\n", params.gpu_device); + fprintf(stderr, "\n"); + fprintf(stderr, " --host HOST, [%-7s] Hostname/IP address for the server\n", sparams.hostname.c_str()); + fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); + fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); + fprintf(stderr, " --request-path PATH, [%-7s] Request path prefix\n", sparams.request_path.c_str()); + fprintf(stderr, " --inference-path PATH, [%-7s] Inference endpoint path\n", sparams.inference_path.c_str()); + fprintf(stderr, " --convert, [%-7s] Convert audio to WAV via ffmpeg\n", sparams.ffmpeg_converter ? "true" : "false"); + fprintf(stderr, " --keep-input-audio, [%-7s] Keep input audio in --tmp-dir\n", sparams.keep_input_audio ? "true" : "false"); + fprintf(stderr, " --tmp-dir PATH, [%-7s] Temporary directory for converted files\n", sparams.tmp_dir.c_str()); + fprintf(stderr, "\n"); +} + +static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params, server_params & sparams) { + if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-h" || arg == "--help") { + parakeet_print_usage(argc, argv, params, sparams); + exit(0); + } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } + else if (arg == "--host") { sparams.hostname = argv[++i]; } + else if (arg == "--port") { sparams.port = std::stoi(argv[++i]); } + else if (arg == "--public") { sparams.public_path = argv[++i]; } + else if (arg == "--request-path") { sparams.request_path = argv[++i]; } + else if (arg == "--inference-path") { sparams.inference_path = argv[++i]; } + else if (arg == "--convert") { sparams.ffmpeg_converter = true; } + else if (arg == "--keep-input-audio") { sparams.keep_input_audio = true; } + else if (arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + parakeet_print_usage(argc, argv, params, sparams); + exit(1); + } + } + + return true; +} + +static void get_req_parameters(const Request & req, parakeet_params & params) { + if (req.has_file("response_format")) { + params.response_format = req.get_file_value("response_format").content; + } +} + +struct parakeet_result : transcription_result { + parakeet_context * ctx; + + explicit parakeet_result(parakeet_context * c) : ctx(c) {} + + int n_segments() const override { + return parakeet_full_n_segments(ctx); + } + + segment_info get_segment(int i) const override { + segment_info seg; + seg.text = parakeet_full_get_segment_text(ctx, i); + seg.t0 = parakeet_full_get_segment_t0(ctx, i); + seg.t1 = parakeet_full_get_segment_t1(ctx, i); + seg.no_speech_prob = 0.0f; + + const int n_tokens = parakeet_full_n_tokens(ctx, i); + seg.tokens.reserve(n_tokens); + for (int j = 0; j < n_tokens; ++j) { + parakeet_token_data tok = parakeet_full_get_token_data(ctx, i, j); + seg.tokens.push_back({tok.id, parakeet_full_get_token_text(ctx, i, j), + tok.t0, tok.t1, tok.p}); + } + return seg; + } + + std::string get_language() const override { return "N/A"; } +}; + +static std::string generate_index_page(const server_params & sparams) { + std::ostringstream oss; + oss << R"( + + + Parakeet.cpp Server + + + + + +

Parakeet.cpp Server

+ +

)" << sparams.request_path << sparams.inference_path << R"(

+
+    curl 127.0.0.1:)" << sparams.port << sparams.request_path << sparams.inference_path << R"( \
+    -H "Content-Type: multipart/form-data" \
+    -F file="@" \
+    -F response_format="json"
+        
+ +
+

Try it out

+
+ +
+ + +
+ + +
+
+ + + )"; + return oss.str(); +} + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + parakeet_params params; + server_params sparams; + + std::mutex model_mutex; + + if (!parakeet_params_parse(argc, argv, params, sparams)) { + parakeet_print_usage(argc, argv, params, sparams); + return 1; + } + + if (sparams.ffmpeg_converter) { + check_ffmpeg_availability(); + } + + parakeet_context_params cparams = parakeet_context_default_params(); + cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; + + std::unique_ptr svr = std::make_unique(); + std::atomic state{SERVER_STATE_LOADING_MODEL}; + + struct parakeet_context * ctx = parakeet_init_from_file_with_params(params.model.c_str(), cparams); + if (ctx == nullptr) { + fprintf(stderr, "error: failed to initialize parakeet context from '%s'\n", params.model.c_str()); + return 1; + } + + state.store(SERVER_STATE_READY); + + printf("Successfully loaded Parakeet model from: %s\n", params.model.c_str()); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); + + const std::string default_content = generate_index_page(sparams); + + parakeet_params default_params = params; + + auto inference_handler = [&](const Request & req, Response & res) { + std::lock_guard lock(model_mutex); + + if (!req.has_file("file")) { + fprintf(stderr, "error: no 'file' field in the request\n"); + res.status = 400; + res.set_content("{\"error\":\"no 'file' field in the request\"}", "application/json"); + return; + } + + auto audio_file = req.get_file_value("file"); + parakeet_params cur_params = default_params; + get_req_parameters(req, cur_params); + + std::string filename{audio_file.filename}; + printf("Received request: %s\n", filename.c_str()); + + std::vector pcmf32; + std::vector> pcmf32s; + + std::string temp_filename; + + if (sparams.keep_input_audio || sparams.ffmpeg_converter) { + temp_filename = generate_temp_filename(sparams.tmp_dir, "parakeet-server", ".wav"); + + std::ofstream temp_file{temp_filename, std::ios::binary}; + temp_file.write(audio_file.content.data(), + static_cast(audio_file.content.size())); + } + + if (sparams.ffmpeg_converter) { + std::string error_resp; + if (!convert_to_wav(temp_filename, error_resp, false)) { + res.status = 500; + res.set_content(error_resp, "application/json"); + return; + } + + if (!::read_audio_data(temp_filename, pcmf32, pcmf32s, false)) { + fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); + res.status = 400; + res.set_content("{\"error\":\"failed to read WAV file\"}", "application/json"); + if (!sparams.keep_input_audio) { + std::remove(temp_filename.c_str()); + } + return; + } + } else { + if (!::read_audio_data(audio_file.content.data(), audio_file.content.size(), pcmf32, pcmf32s, false)) { + fprintf(stderr, "error: failed to read audio data\n"); + res.status = 400; + res.set_content("{\"error\":\"failed to read audio data\"}", "application/json"); + return; + } + } + + if (!sparams.keep_input_audio) { + std::remove(temp_filename.c_str()); + } + + printf("Successfully loaded %s\n", filename.c_str()); + + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads ...\n", + __func__, filename.c_str(), (int)pcmf32.size(), + float(pcmf32.size()) / PARAKEET_SAMPLE_RATE, cur_params.n_threads); + + { + printf("Running parakeet.cpp inference on %s\n", filename.c_str()); + + parakeet_full_params fparams = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + fparams.n_threads = cur_params.n_threads; + fparams.no_context = true; + + // Abort callback for HTTP disconnect + fparams.abort_callback = [](void * user_data) { + auto req_ptr = static_cast(user_data); + return req_ptr->is_connection_closed(); + }; + fparams.abort_callback_user_data = (void *) &req; + + int ret = parakeet_full(ctx, fparams, pcmf32.data(), (int)pcmf32.size()); + if (ret != 0) { + if (req.is_connection_closed()) { + fprintf(stderr, "client disconnected, aborted processing\n"); + res.status = 499; + res.set_content("{\"error\":\"client disconnected\"}", "application/json"); + return; + } + fprintf(stderr, "%s: failed to process audio\n", argv[0]); + res.status = 500; + res.set_content("{\"error\":\"failed to process audio\"}", "application/json"); + return; + } + } + + // Format response + parakeet_result result{ctx}; + + if (cur_params.response_format == text_format) { + std::string text = format_text(result); + res.set_content(text, "text/plain; charset=utf-8"); + } else if (cur_params.response_format == srt_format) { + std::string srt = format_srt(result, 0); + res.set_content(srt, "application/x-subrip"); + } else if (cur_params.response_format == vtt_format) { + std::string vtt = format_vtt(result); + res.set_content(vtt, "text/vtt"); + } else if (cur_params.response_format == vjson_format) { + float duration = float(pcmf32.size()) / PARAKEET_SAMPLE_RATE; + std::string vjson = format_verbose_json(result, 0.0f, duration, false, true); + res.set_content(vjson, "application/json"); + } else { + std::string j = format_json(result); + res.set_content(j, "application/json"); + } + }; + + auto load_handler = [&](const Request & req, Response & res) { + std::lock_guard lock(model_mutex); + state.store(SERVER_STATE_LOADING_MODEL); + + if (!req.has_file("model")) { + fprintf(stderr, "error: no 'model' field in the request\n"); + res.status = 400; + res.set_content("{\"error\":\"no 'model' field in the request\"}", "application/json"); + return; + } + + std::string model = req.get_file_value("model").content; + + parakeet_free(ctx); + + ctx = parakeet_init_from_file_with_params(model.c_str(), cparams); + if (ctx == nullptr) { + fprintf(stderr, "error: failed to load model '%s'\n", model.c_str()); + res.status = 500; + res.set_content("{\"error\":\"failed to load model\"}", "application/json"); + return; + } + + state.store(SERVER_STATE_READY); + res.set_content("Load was successful!", "text/plain"); + }; + + setup_server_common(*svr, sparams, state, load_handler, inference_handler, default_content, "parakeet.cpp"); + + setup_signal_handler([&]() { + printf("\nShutting down gracefully...\n"); + svr->stop(); + }); + + if (!svr->bind_to_port(sparams.hostname, sparams.port)) { + fprintf(stderr, "couldn't bind to server socket: hostname=%s port=%d\n", + sparams.hostname.c_str(), sparams.port); + return 1; + } + + svr->set_base_dir(sparams.public_path); + + printf("\nparakeet server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); + + auto clean_up = [&]() { + parakeet_print_timings(ctx); + parakeet_free(ctx); + }; + + std::thread t([&]() { + if (!svr->listen_after_bind()) { + fprintf(stderr, "error: server listen failed\n"); + } + }); + + svr->wait_until_ready(); + t.join(); + + clean_up(); + + return 0; +} diff --git a/examples/server-common.cpp b/examples/server-common.cpp new file mode 100644 index 00000000000..2d78a2bcd6c --- /dev/null +++ b/examples/server-common.cpp @@ -0,0 +1,350 @@ +#include "server-common.h" +#include "common-whisper.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined (_WIN32) +#include +#endif + +const std::string json_format = "json"; +const std::string text_format = "text"; +const std::string srt_format = "srt"; +const std::string vjson_format = "verbose_json"; +const std::string vtt_format = "vtt"; + +namespace { + std::function g_shutdown_callback; + std::atomic_flag g_is_terminating = ATOMIC_FLAG_INIT; + + void signal_handler(int /*signal*/) { + if (g_is_terminating.test_and_set()) { + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + if (g_shutdown_callback) { + g_shutdown_callback(); + } + } +} + +bool parse_str_to_bool(const std::string & s) { + if (s == "true" || s == "1" || s == "yes" || s == "y") { + return true; + } + return false; +} + +bool check_ffmpeg_availability() { + int result = system("ffmpeg -version"); + if (result == 0) { + std::cout << "ffmpeg is available." << std::endl; + } else { + std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed " + << "and that its executable is included in your system's PATH. "; + exit(0); + } + return true; +} + +std::string generate_temp_filename(const std::string & path, const std::string & prefix, const std::string & extension) { + auto now = std::chrono::system_clock::now(); + auto now_time_t = std::chrono::system_clock::to_time_t(now); + + static std::mt19937 rng{std::random_device{}()}; + std::uniform_int_distribution dist(0, 1e9); + + std::stringstream ss; + ss << path + << std::filesystem::path::preferred_separator + << prefix + << "-" + << std::put_time(std::localtime(&now_time_t), "%Y%m%d-%H%M%S") + << "-" + << dist(rng) + << extension; + + return ss.str(); +} + +bool convert_to_wav(const std::string & temp_filename, std::string & error_resp, bool stereo) { + std::ostringstream cmd_stream; + std::string converted_filename_temp = temp_filename + "_temp.wav"; + cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac " << (stereo ? 2 : 1) << " -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; + std::string cmd = cmd_stream.str(); + + int status = std::system(cmd.c_str()); + if (status != 0) { + error_resp = "{\"error\":\"FFmpeg conversion failed.\"}"; + return false; + } + + if (remove(temp_filename.c_str()) != 0) { + error_resp = "{\"error\":\"Failed to remove the original file.\"}"; + return false; + } + + if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) { + error_resp = "{\"error\":\"Failed to rename the temporary file.\"}"; + return false; + } + return true; +} + +void setup_signal_handler(std::function shutdown_callback) { + g_shutdown_callback = std::move(shutdown_callback); + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset(&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif +} + +static std::string ms_to_timestamp(int64_t t_ms, bool comma = false) { + // to_timestamp expects centiseconds, our adapter uses milliseconds + return ::to_timestamp(t_ms / 10, comma); +} + + +std::string format_text(const transcription_result & result) { + std::stringstream ss; + const int n_segments = result.n_segments(); + for (int i = 0; i < n_segments; ++i) { + auto seg = result.get_segment(i); + auto speaker = result.get_speaker(i); + ss << speaker << seg.text << "\n"; + } + return ss.str(); +} + +std::string format_srt(const transcription_result & result, int offset_n) { + std::stringstream ss; + const int n_segments = result.n_segments(); + for (int i = 0; i < n_segments; ++i) { + auto seg = result.get_segment(i); + auto speaker = result.get_speaker(i); + + ss << i + 1 + offset_n << "\n"; + ss << ms_to_timestamp(seg.t0, true) << " --> " << ms_to_timestamp(seg.t1, true) << "\n"; + ss << speaker << seg.text << "\n\n"; + } + return ss.str(); +} + +std::string format_vtt(const transcription_result & result) { + std::stringstream ss; + ss << "WEBVTT\n\n"; + + const int n_segments = result.n_segments(); + for (int i = 0; i < n_segments; ++i) { + auto seg = result.get_segment(i); + std::string speaker_tag; + + auto speaker_id = result.get_speaker(i); + if (!speaker_id.empty()) { + speaker_tag = ""; + } + + ss << ms_to_timestamp(seg.t0) << " --> " << ms_to_timestamp(seg.t1) << "\n"; + ss << speaker_tag << seg.text << "\n\n"; + } + return ss.str(); +} + +std::string format_json(const transcription_result & result) { + std::string text = format_text(result); + json jres = json{{"text", text}}; + return jres.dump(-1, ' ', false, json::error_handler_t::replace); +} + +std::string format_verbose_json( + const transcription_result & result, + float temperature, + float duration, + bool no_timestamps, + bool token_timestamps) { + std::string text = format_text(result); + std::string task = result.get_task(); + std::string language = result.get_language(); + + json jres = json{ + {"task", task}, + {"language", language}, + {"duration", duration}, + {"text", text}, + {"segments", json::array()} + }; + + // Merge language probability data into the top-level response. + // Adapters return a json object whose keys are merged directly, allowing + // model-specific fields (e.g. whisper's detected_language) to appear at + // the top level alongside the standard language_probabilities map. + json lang_data = result.get_language_probabilities(); + for (auto & [key, val] : lang_data.items()) { + jres[key] = val; + } + + const int n_segments = result.n_segments(); + for (int i = 0; i < n_segments; ++i) { + auto seg = result.get_segment(i); + + json segment = json{ + {"id", i}, + {"text", seg.text}, + }; + + if (!no_timestamps) { + segment["start"] = seg.t0 * 0.001f; // ms -> seconds + segment["end"] = seg.t1 * 0.001f; + } + + auto speaker_id = result.get_speaker(i); + if (!speaker_id.empty()) { + segment["speaker"] = speaker_id; + } + + // Build word-level tokens by merging partial UTF-8 tokens + std::vector words; + int n_tokens = (int)seg.tokens.size(); + float total_logprob = 0.0f; + + for (int j = 0; j < n_tokens; ++j) { + auto & tok = seg.tokens[j]; + + // Merge trailing partial UTF-8 bytes into complete words + std::string word_text = tok.text; + int64_t word_t1 = tok.t1; + + while (j + 1 < n_tokens) { + int trailing = utf8_trailing_bytes_needed(word_text); + if (trailing <= 0) break; + + ++j; + auto & next_tok = seg.tokens[j]; + word_text += next_tok.text; + if (next_tok.t1 > word_t1) { + word_t1 = next_tok.t1; + } + } + + json word = json{{"word", word_text}}; + if (!no_timestamps && token_timestamps) { + word["start"] = tok.t0 * 0.001f; + word["end"] = word_t1 * 0.001f; + } + word["probability"] = tok.prob; + + // Approximate logprob from probability + float logprob = tok.prob > 0.0f ? std::log(tok.prob + 1e-10f) : -1e10f; + total_logprob += logprob; + + words.push_back(word); + } + + segment["words"] = words; + segment["tokens"] = json::array(); + for (auto & tok : seg.tokens) { + segment["tokens"].push_back(tok.id); + } + + segment["temperature"] = temperature; + int n_word_tokens = (int)seg.tokens.size(); + segment["avg_logprob"] = n_word_tokens > 0 ? total_logprob / n_word_tokens : 0.0f; + segment["no_speech_prob"] = seg.no_speech_prob; + + jres["segments"].push_back(segment); + } + + return jres.dump(-1, ' ', false, json::error_handler_t::replace); +} + +void setup_server_common( + httplib::Server & svr, + const server_params & sparams, + std::atomic & state, + std::function load_handler, + std::function inference_handler, + const std::string & default_content, + const std::string & server_name) { + + svr.set_default_headers({ + {"Server", server_name}, + {"Access-Control-Allow-Origin", "*"}, + {"Access-Control-Allow-Headers", "content-type, authorization"} + }); + + // Default index page + svr.Get(sparams.request_path + "/", [&](const httplib::Request &, httplib::Response & res) { + res.set_content(default_content, "text/html"); + return false; + }); + + // CORS preflight + svr.Options(sparams.request_path + sparams.inference_path, + [&](const httplib::Request &, httplib::Response &) {}); + + // Inference endpoint + svr.Post(sparams.request_path + sparams.inference_path, inference_handler); + + // Model reload endpoint + if (load_handler) { + svr.Post(sparams.request_path + "/load", load_handler); + } + + // Health check + svr.Get(sparams.request_path + "/health", [&](const httplib::Request &, httplib::Response & res) { + server_state current_state = state.load(); + if (current_state == SERVER_STATE_READY) { + res.set_content("{\"status\":\"ok\"}", "application/json"); + } else { + res.set_content("{\"status\":\"loading model\"}", "application/json"); + res.status = 503; + } + }); + + // Exception handler + svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + const char fmt[] = "500 Internal Server Error\n%s"; + char buf[BUFSIZ]; + try { + std::rethrow_exception(std::move(ep)); + } catch (std::exception & e) { + snprintf(buf, sizeof(buf), fmt, e.what()); + } catch (...) { + snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); + } + res.set_content(buf, "text/plain"); + res.status = 500; + }); + + // Error handler + svr.set_error_handler([](const httplib::Request & req, httplib::Response & res) { + if (res.status == 400) { + res.set_content("Invalid request", "text/plain"); + } else if (res.status != 500) { + res.set_content("File Not Found (" + req.path + ")", "text/plain"); + res.status = 404; + } + }); + + svr.set_read_timeout(sparams.read_timeout); + svr.set_write_timeout(sparams.write_timeout); +} diff --git a/examples/server-common.h b/examples/server-common.h new file mode 100644 index 00000000000..1bb99a013d9 --- /dev/null +++ b/examples/server-common.h @@ -0,0 +1,100 @@ +// Common server utilities for whisper.cpp and parakeet.cpp servers +// Extracts shared HTTP infrastructure, response formatting, and request handling. +#pragma once + +#include "httplib.h" +#include "json.hpp" + +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +enum server_state { + SERVER_STATE_LOADING_MODEL, + SERVER_STATE_READY, +}; + +struct server_params { + std::string hostname = "127.0.0.1"; + std::string public_path = "examples/server/public"; + std::string request_path = ""; + std::string inference_path = "/inference"; + std::string tmp_dir = "."; + + int32_t port = 8080; + int32_t read_timeout = 600; + int32_t write_timeout = 600; + + bool ffmpeg_converter = false; + bool keep_input_audio = false; +}; + +struct segment_token { + int32_t id; + std::string text; + int64_t t0; // in ms + int64_t t1; // in ms + float prob; +}; + +struct segment_info { + std::string text; + int64_t t0; // in ms + int64_t t1; // in ms + float no_speech_prob; + std::vector tokens; +}; + +struct transcription_result { + virtual ~transcription_result() = default; + + virtual int n_segments() const = 0; + virtual segment_info get_segment(int i) const = 0; + + virtual std::string get_speaker(int /*i*/) const { return {}; } + virtual std::string get_language() const { return {}; } + virtual json get_language_probabilities() const { return {}; } + virtual std::string get_task() const { return "transcribe"; } +}; + +extern const std::string json_format; +extern const std::string text_format; +extern const std::string srt_format; +extern const std::string vjson_format; +extern const std::string vtt_format; + +std::string format_text(const transcription_result & result); +std::string format_srt(const transcription_result & result, int offset_n = 0); +std::string format_vtt(const transcription_result & result); +std::string format_json(const transcription_result & result); +std::string format_verbose_json( + const transcription_result & result, + float temperature, + float duration, + bool no_timestamps, + bool token_timestamps); + + +bool parse_str_to_bool(const std::string & s); + +bool check_ffmpeg_availability(); + +std::string generate_temp_filename(const std::string & path, const std::string & prefix, const std::string & extension); + +bool convert_to_wav(const std::string & temp_filename, std::string & error_resp, bool stereo); + +void setup_signal_handler(std::function shutdown_callback); + +// Set up common server configuration (CORS, error handlers, timeouts) +void setup_server_common( + httplib::Server & svr, + const server_params & sparams, + std::atomic & state, + std::function load_handler, + std::function inference_handler, + const std::string & default_content, + const std::string & server_name); diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index c082546bdf1..ced41c93864 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -2,7 +2,9 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(TARGET whisper-server) -add_executable(${TARGET} server.cpp httplib.h) +add_executable(${TARGET} server.cpp ../httplib.h) + +target_sources(${TARGET} PRIVATE ../server-common.cpp) include(DefaultTargetOptions) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b87ef27375f..f58df80e6db 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,5 +1,6 @@ #include "common.h" #include "common-whisper.h" +#include "server-common.h" #include "whisper.h" #include "httplib.h" @@ -17,7 +18,6 @@ #include #include #include -#include #include #if defined (_WIN32) #include @@ -26,49 +26,8 @@ using namespace httplib; using json = nlohmann::ordered_json; -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - namespace { -// output formats -const std::string json_format = "json"; -const std::string text_format = "text"; -const std::string srt_format = "srt"; -const std::string vjson_format = "verbose_json"; -const std::string vtt_format = "vtt"; - -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); - } - - shutdown_handler(signal); -} - -struct server_params -{ - std::string hostname = "127.0.0.1"; - std::string public_path = "examples/server/public"; - std::string request_path = ""; - std::string inference_path = "/inference"; - std::string tmp_dir = "."; - - int32_t port = 8080; - int32_t read_timeout = 600; - int32_t write_timeout = 600; - - bool ffmpeg_converter = false; -}; - struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_processors = 1; @@ -282,65 +241,6 @@ struct whisper_print_user_data { int progress_prev; }; -void check_ffmpeg_availibility() { - int result = system("ffmpeg -version"); - - if (result == 0) { - std::cout << "ffmpeg is available." << std::endl; - } else { - // ffmpeg is not available - std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed "; - std::cout << "and that its executable is included in your system's PATH. "; - exit(0); - } -} - -std::string generate_temp_filename(const std::string &path, const std::string &prefix, const std::string &extension) { - auto now = std::chrono::system_clock::now(); - auto now_time_t = std::chrono::system_clock::to_time_t(now); - - static std::mt19937 rng{std::random_device{}()}; - std::uniform_int_distribution dist(0, 1e9); - - std::stringstream ss; - ss << path - << std::filesystem::path::preferred_separator - << prefix - << "-" - << std::put_time(std::localtime(&now_time_t), "%Y%m%d-%H%M%S") - << "-" - << dist(rng) - << extension; - - return ss.str(); -} - -bool convert_to_wav(const std::string & temp_filename, std::string & error_resp, bool stereo) { - std::ostringstream cmd_stream; - std::string converted_filename_temp = temp_filename + "_temp.wav"; - cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac " << (stereo ? 2 : 1) << " -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; - std::string cmd = cmd_stream.str(); - - int status = std::system(cmd.c_str()); - if (status != 0) { - error_resp = "{\"error\":\"FFmpeg conversion failed.\"}"; - return false; - } - - // Remove the original file - if (remove(temp_filename.c_str()) != 0) { - error_resp = "{\"error\":\"Failed to remove the original file.\"}"; - return false; - } - - // Rename the temporary file to match the original filename - if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) { - error_resp = "{\"error\":\"Failed to rename the temporary file.\"}"; - return false; - } - return true; -} - std::string estimate_diarization_speaker(const std::vector> & pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { std::string speaker = ""; const int64_t n_samples = pcmf32s[0].size(); @@ -451,32 +351,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } -std::string output_str(struct whisper_context * ctx, const whisper_params & params, const std::vector> & pcmf32s) { - std::stringstream result; - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } - - result << speaker << text << "\n"; - } - return result.str(); -} - -bool parse_str_to_bool(const std::string & s) { - if (s == "true" || s == "1" || s == "yes" || s == "y") { - return true; - } - return false; -} - void get_req_parameters(const Request & req, whisper_params & params) { if (req.has_file("offset_t")) @@ -627,6 +501,87 @@ void get_req_parameters(const Request & req, whisper_params & params) } } +struct whisper_result : transcription_result { + whisper_context * ctx; + const whisper_params & params; + const std::vector> & pcmf32s; + + whisper_result(whisper_context * c, const whisper_params & p, + const std::vector> & s) + : ctx(c), params(p), pcmf32s(s) {} + + int n_segments() const override { + return whisper_full_n_segments(ctx); + } + + segment_info get_segment(int i) const override { + segment_info seg; + seg.text = whisper_full_get_segment_text(ctx, i); + seg.t0 = whisper_full_get_segment_t0(ctx, i) * 10; // centiseconds -> ms + seg.t1 = whisper_full_get_segment_t1(ctx, i) * 10; + seg.no_speech_prob = whisper_full_get_segment_no_speech_prob(ctx, i); + + const int n_tokens = whisper_full_n_tokens(ctx, i); + seg.tokens.reserve(n_tokens); + for (int j = 0; j < n_tokens; ++j) { + whisper_token_data tok = whisper_full_get_token_data(ctx, i, j); + if (tok.id >= whisper_token_eot(ctx)) { + continue; + } + segment_token st; + st.id = tok.id; + st.text = whisper_full_get_token_text(ctx, i, j); + st.t0 = tok.t0 * 10; // centiseconds -> ms + st.t1 = tok.t1 * 10; + st.prob = tok.p; + seg.tokens.push_back(st); + } + return seg; + } + + std::string get_speaker(int i) const override { + if (params.diarize && pcmf32s.size() == 2) { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + return estimate_diarization_speaker(pcmf32s, t0, t1); + } + if (params.tinydiarize) { + return whisper_full_get_segment_speaker_turn_next(ctx, i) + ? params.tdrz_speaker_turn : std::string{}; + } + return {}; + } + + std::string get_language() const override { + return whisper_lang_str_full(whisper_full_lang_id(ctx)); + } + + json get_language_probabilities() const override { + if (params.no_language_probabilities) { + return {}; + } + std::vector lang_probs(whisper_lang_max_id() + 1, 0.0f); + const auto detected_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data()); + + json lang_prob_map = json::object(); + for (int i = 0; i <= whisper_lang_max_id(); ++i) { + if (lang_probs[i] > 0.001f) { + lang_prob_map[whisper_lang_str(i)] = lang_probs[i]; + } + } + + return json{ + {"detected_language", whisper_lang_str_full(detected_id)}, + {"detected_language_probability", lang_probs[detected_id]}, + {"language_probabilities", lang_prob_map}, + }; + } + + std::string get_task() const override { + return params.translate ? "translate" : "transcribe"; + } +}; + } // namespace int main(int argc, char ** argv) { @@ -655,7 +610,7 @@ int main(int argc, char ** argv) { } if (sparams.ffmpeg_converter) { - check_ffmpeg_availibility(); + check_ffmpeg_availability(); } // whisper init struct whisper_context_params cparams = whisper_context_default_params(); @@ -725,11 +680,6 @@ int main(int argc, char ** argv) { whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); state.store(SERVER_STATE_READY); - - svr->set_default_headers({{"Server", "whisper.cpp"}, - {"Access-Control-Allow-Origin", "*"}, - {"Access-Control-Allow-Headers", "content-type, authorization"}}); - std::string const default_content = R"( @@ -805,16 +755,7 @@ int main(int argc, char ** argv) { // store default params so we can reset after each inference request whisper_params default_params = params; - // this is only called if no index.html is found in the public --path - svr->Get(sparams.request_path + "/", [&](const Request &, Response &res){ - res.set_content(default_content, "text/html"); - return false; - }); - - svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){ - }); - - svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){ + auto inference_handler = [&](const Request & req, Response & res) { // acquire whisper model mutex lock std::lock_guard lock(whisper_mutex); @@ -1004,163 +945,31 @@ int main(int argc, char ** argv) { } } + whisper_result result{ctx, params, pcmf32s}; + // return results to user if (params.response_format == text_format) { - std::string results = output_str(ctx, params, pcmf32s); - res.set_content(results.c_str(), "text/html; charset=utf-8"); + res.set_content(format_text(result), "text/html; charset=utf-8"); } else if (params.response_format == srt_format) { - std::stringstream ss; - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1); - } - - ss << i + 1 + params.offset_n << "\n"; - ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; - ss << speaker << text << "\n\n"; - } - res.set_content(ss.str(), "application/x-subrip"); + res.set_content(format_srt(result, params.offset_n), "application/x-subrip"); } else if (params.response_format == vtt_format) { - std::stringstream ss; - - ss << "WEBVTT\n\n"; - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - std::string speaker = ""; - - if (params.diarize && pcmf32s.size() == 2) - { - speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); - speaker.insert(0, ""); - } - - ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; - ss << speaker << text << "\n\n"; - } - res.set_content(ss.str(), "text/vtt"); + res.set_content(format_vtt(result), "text/vtt"); } else if (params.response_format == vjson_format) { - /* try to match openai/whisper's Python format */ - std::string results = output_str(ctx, params, pcmf32s); - json jres = json{ - {"task", params.translate ? "translate" : "transcribe"}, - {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, - {"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE}, - {"text", results}, - {"segments", json::array()} - }; - // Only compute language probabilities if requested (expensive operation) - if (!params.no_language_probabilities) { - std::vector lang_probs(whisper_lang_max_id() + 1, 0.0f); - const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data()); - jres["detected_language"] = whisper_lang_str_full(detected_lang_id); - jres["detected_language_probability"] = lang_probs[detected_lang_id]; - jres["language_probabilities"] = json::object(); - // Add all language probabilities - for (int i = 0; i <= whisper_lang_max_id(); ++i) { - if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities - jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i]; - } - } - } - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) - { - json segment = json{ - {"id", i}, - {"text", whisper_full_get_segment_text(ctx, i)}, - }; - - if (!params.no_timestamps) { - segment["start"] = whisper_full_get_segment_t0(ctx, i) * 0.01; - segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; - } - - if (params.diarize && pcmf32s.size() == 2) { - segment["speaker"] = estimate_diarization_speaker( - pcmf32s, - whisper_full_get_segment_t0(ctx, i), - whisper_full_get_segment_t1(ctx, i), - true); - } - - float total_logprob = 0; - const int n_tokens = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n_tokens; ++j) { - whisper_token_data token = whisper_full_get_token_data(ctx, i, j); - if (token.id >= whisper_token_eot(ctx)) { - continue; - } - - segment["tokens"].push_back(token.id); - std::string word_text = whisper_full_get_token_text(ctx, i, j); - int64_t word_t1 = token.t1; - - while (j + 1 < n_tokens && utf8_trailing_bytes_needed(word_text) > 0) { - const whisper_token_data next_token = whisper_full_get_token_data(ctx, i, j + 1); - // Keep verbose_json tokens free of EOT ids, matching the pre-merge server behavior. - if (next_token.id >= whisper_token_eot(ctx)) { - break; - } - - ++j; - segment["tokens"].push_back(next_token.id); - word_text += whisper_full_get_token_text(ctx, i, j); - if (next_token.t1 > -1) { - word_t1 = next_token.t1; - } - total_logprob += next_token.plog; - } - - json word = json{{"word", word_text}}; - if (!params.no_timestamps && params.token_timestamps) { - word["start"] = token.t0 * 0.01; - word["end"] = word_t1 * 0.01; - word["t_dtw"] = token.t_dtw; - } - word["probability"] = token.p; - total_logprob += token.plog; - segment["words"].push_back(word); - } - - segment["temperature"] = params.temperature; - segment["avg_logprob"] = total_logprob / n_tokens; - - // TODO compression_ratio and no_speech_prob are not implemented yet - // segment["compression_ratio"] = 0; - segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i); - - jres["segments"].push_back(segment); - } - res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace), - "application/json"); + res.set_content( + format_verbose_json(result, params.temperature, float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.no_timestamps, params.token_timestamps), + "application/json"); } // TODO add more output formats else { - std::string results = output_str(ctx, params, pcmf32s); - json jres = json{ - {"text", results} - }; - res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace), - "application/json"); + res.set_content(format_json(result), "application/json"); } - }); - svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ + }; + + auto load_handler = [&](const Request & req, Response & res) { std::lock_guard lock(whisper_mutex); state.store(SERVER_STATE_LOADING_MODEL); if (!req.has_file("model")) @@ -1201,46 +1010,17 @@ int main(int argc, char ** argv) { res.set_content(success, "application/text"); // check if the model is in the file system - }); - - svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){ - server_state current_state = state.load(); - if (current_state == SERVER_STATE_READY) { - const std::string health_response = "{\"status\":\"ok\"}"; - res.set_content(health_response, "application/json"); - } else { - res.set_content("{\"status\":\"loading model\"}", "application/json"); - res.status = 503; - } - }); + }; - svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) { - const char fmt[] = "500 Internal Server Error\n%s"; - char buf[BUFSIZ]; - try { - std::rethrow_exception(std::move(ep)); - } catch (std::exception &e) { - snprintf(buf, sizeof(buf), fmt, e.what()); - } catch (...) { - snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); - } - res.set_content(buf, "text/plain"); - res.status = 500; - }); + setup_server_common(*svr, sparams, state, load_handler, inference_handler, default_content, "whisper.cpp"); - svr->set_error_handler([](const Request &req, Response &res) { - if (res.status == 400) { - res.set_content("Invalid request", "text/plain"); - } else if (res.status != 500) { - res.set_content("File Not Found (" + req.path + ")", "text/plain"); - res.status = 404; + setup_signal_handler([&]() { + printf("\nCaught shutdown signal, shutting down gracefully...\n"); + if (svr) { + svr->stop(); } }); - // set timeouts and change hostname and port - svr->set_read_timeout(sparams.read_timeout); - svr->set_write_timeout(sparams.write_timeout); - if (!svr->bind_to_port(sparams.hostname, sparams.port)) { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", @@ -1254,27 +1034,6 @@ int main(int argc, char ** argv) { // to make it ctrl+clickable: printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); - shutdown_handler = [&](int signal) { - printf("\nCaught signal %d, shutting down gracefully...\n", signal); - if (svr) { - svr->stop(); - } - }; - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = signal_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); - sigaction(SIGTERM, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif - // clean up function, to be called before exit auto clean_up = [&]() { whisper_print_timings(ctx);