From 8ace28d964a6427802d0470d2ff6d91b0cf8acd3 Mon Sep 17 00:00:00 2001 From: Noah Lyons Date: Sun, 31 May 2026 18:52:50 -0400 Subject: [PATCH] server: merge split utf-8 token text in verbose json --- examples/cli/cli.cpp | 33 --------------------------------- examples/common-whisper.cpp | 28 ++++++++++++++++++++++++++++ examples/common-whisper.h | 3 +++ examples/server/server.cpp | 23 +++++++++++++++++++++-- tests/CMakeLists.txt | 8 ++++++++ tests/test-common-utf8.cpp | 34 ++++++++++++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 35 deletions(-) create mode 100644 tests/test-common-utf8.cpp diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 55cd71b4e55..7ca563dc250 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -31,39 +31,6 @@ static void replace_all(std::string & s, const std::string & search, const std:: } } -// Returns the number of trailing continuation bytes still needed for `s` to end -// on a complete UTF-8 codepoint. Returns 0 if the tail of `s` is already a -// complete codepoint (or if the tail looks malformed and we should stop merging). -// Used to merge whisper tokens whose bytes split a multi-byte UTF-8 character -// (e.g. CJK), so the JSON output stays valid UTF-8. See https://github.com/ggml-org/whisper.cpp/issues/1798. -static int utf8_trailing_bytes_needed(const std::string & s) { - const int n = (int) s.size(); - int i = n - 1; - // walk back past continuation bytes (10xxxxxx) - while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { - --i; - } - if (i < 0) { - // all continuation bytes, or empty — nothing we can do - return 0; - } - const unsigned char c = (unsigned char) s[i]; - int expected; - if ((c & 0x80) == 0x00) { - expected = 1; // ASCII - } else if ((c & 0xE0) == 0xC0) { - expected = 2; - } else if ((c & 0xF0) == 0xE0) { - expected = 3; - } else if ((c & 0xF8) == 0xF0) { - expected = 4; - } else { - return 0; // malformed lead, give up - } - const int have = n - i; - return have >= expected ? 
0 : (expected - have); -} - // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index c84e6843adc..b12481c013f 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -198,6 +198,34 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100))); } +int utf8_trailing_bytes_needed(const std::string & s) { + const int n = (int) s.size(); + int i = n - 1; + while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { + --i; + } + if (i < 0) { + return 0; + } + + const unsigned char c = (unsigned char) s[i]; + int expected; + if ((c & 0x80) == 0x00) { + expected = 1; + } else if ((c & 0xE0) == 0xC0) { + expected = 2; + } else if ((c & 0xF0) == 0xE0) { + expected = 3; + } else if ((c & 0xF8) == 0xF0) { + expected = 4; + } else { + return 0; + } + + const int have = n - i; + return have >= expected ? 0 : (expected - have); +} + bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id) { std::ofstream speak_file(path.c_str()); if (speak_file.fail()) { diff --git a/examples/common-whisper.h b/examples/common-whisper.h index 8714c381046..aec430d3635 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -28,5 +28,8 @@ std::string to_timestamp(int64_t t, bool comma = false); // given a timestamp get the sample int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate); +// Returns the number of trailing bytes still needed for s to end on a complete UTF-8 codepoint. 
+int utf8_trailing_bytes_needed(const std::string & s); + // write text to file, and call system("command voice_id file") bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index aae74c3d840..b87ef27375f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1107,10 +1107,29 @@ int main(int argc, char ** argv) { } segment["tokens"].push_back(token.id); - json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; + std::string word_text = whisper_full_get_token_text(ctx, i, j); + int64_t word_t1 = token.t1; + + while (j + 1 < n_tokens && utf8_trailing_bytes_needed(word_text) > 0) { + const whisper_token_data next_token = whisper_full_get_token_data(ctx, i, j + 1); + // Keep verbose_json tokens free of EOT ids, matching the pre-merge server behavior. + if (next_token.id >= whisper_token_eot(ctx)) { + break; + } + + ++j; + segment["tokens"].push_back(next_token.id); + word_text += whisper_full_get_token_text(ctx, i, j); + if (next_token.t1 > -1) { + word_t1 = next_token.t1; + } + total_logprob += next_token.plog; + } + + json word = json{{"word", word_text}}; if (!params.no_timestamps && params.token_timestamps) { word["start"] = token.t0 * 0.01; - word["end"] = token.t1 * 0.01; + word["end"] = word_t1 * 0.01; word["t_dtw"] = token.t_dtw; } word["probability"] = token.p; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0593b748d36..646f45f2ab7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -88,6 +88,14 @@ if (WHISPER_COMMON_FFMPEG) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3") endif() +# UTF-8 helper unit test +set(UTF8_TEST test-common-utf8) +add_executable(${UTF8_TEST} ${UTF8_TEST}.cpp) +target_include_directories(${UTF8_TEST} PRIVATE ../examples) +target_link_libraries(${UTF8_TEST} PRIVATE common) +add_test(NAME ${UTF8_TEST} COMMAND ${UTF8_TEST}) 
+set_tests_properties(${UTF8_TEST} PROPERTIES LABELS "unit") + # VAD test tests VAD in isolation set(VAD_TEST test-vad) add_executable(${VAD_TEST} ${VAD_TEST}.cpp) diff --git a/tests/test-common-utf8.cpp b/tests/test-common-utf8.cpp new file mode 100644 index 00000000000..91c73a7428d --- /dev/null +++ b/tests/test-common-utf8.cpp @@ -0,0 +1,34 @@ +#include "common-whisper.h" + +#include <cstdio> +#include <cstdlib> +#include <string> + +static void expect_needed(const std::string & input, int expected) { + const int actual = utf8_trailing_bytes_needed(input); + if (actual != expected) { + fprintf(stderr, "expected %d trailing UTF-8 bytes, got %d\n", expected, actual); + std::abort(); + } +} + +int main() { + expect_needed("", 0); + expect_needed("plain ascii", 0); + + const std::string cjk = "\xE4\xBD\xA0"; // U+4F60 + expect_needed(cjk.substr(0, 1), 2); + expect_needed(cjk.substr(0, 2), 1); + expect_needed(cjk, 0); + + const std::string emoji = "\xF0\x9F\x98\x80"; // U+1F600 + expect_needed(emoji.substr(0, 1), 3); + expect_needed(emoji.substr(0, 2), 2); + expect_needed(emoji.substr(0, 3), 1); + expect_needed(emoji, 0); + + expect_needed("\x80\x80", 0); + expect_needed("\xFF", 0); + + return 0; +}