Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions examples/cli/cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,6 @@ static void replace_all(std::string & s, const std::string & search, const std::
}
}

// Reports how many more continuation bytes `s` requires before its last UTF-8
// codepoint is complete. The result is 0 when the final codepoint is already
// whole, and also when the tail cannot be interpreted (no lead byte found, or
// an unrecognized lead byte), in which case the caller should stop merging.
// Whisper tokens may cut a multi-byte character (e.g. CJK) in the middle; this
// helper lets the caller glue such tokens back together so the JSON output
// remains valid UTF-8. See https://github.com/ggml-org/whisper.cpp/issues/1798.
static int utf8_trailing_bytes_needed(const std::string & s) {
    const int len = (int) s.size();

    // scan backwards for the last byte that is not a continuation byte (10xxxxxx)
    int lead_pos = len - 1;
    for (; lead_pos >= 0; --lead_pos) {
        if (((unsigned char) s[lead_pos] & 0xC0) != 0x80) {
            break;
        }
    }
    if (lead_pos < 0) {
        // empty string or continuation bytes only: no lead byte to reason about
        return 0;
    }

    // the lead byte announces the total length of the sequence
    const unsigned char lead = (unsigned char) s[lead_pos];
    int seq_len = 0;
    if (lead < 0x80) {
        seq_len = 1; // ASCII
    } else if (lead >= 0xC0 && lead <= 0xDF) {
        seq_len = 2; // 110xxxxx
    } else if (lead >= 0xE0 && lead <= 0xEF) {
        seq_len = 3; // 1110xxxx
    } else if (lead >= 0xF0 && lead <= 0xF7) {
        seq_len = 4; // 11110xxx
    } else {
        return 0; // not a recognized lead byte, give up
    }

    // bytes present from the lead byte to the end of the string
    const int present = len - lead_pos;
    const int missing = seq_len - present;
    return missing > 0 ? missing : 0;
}

// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
Expand Down
28 changes: 28 additions & 0 deletions examples/common-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,34 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100)));
}

// Returns the number of trailing continuation bytes still needed for `s` to
// end on a complete UTF-8 codepoint.
//
// Only the last codepoint of `s` is inspected. Returns 0 when:
//   - `s` is empty,
//   - the last codepoint is already complete,
//   - the tail is malformed: no lead byte within the last 4 bytes, or a lead
//     byte that can never start a valid sequence (0xC0, 0xC1, 0xF5..0xFF).
//
// A return value of 0 therefore means "do not append more bytes", not
// necessarily "valid UTF-8".
int utf8_trailing_bytes_needed(const std::string & s) {
    const size_t n = s.size();

    // count trailing continuation bytes (10xxxxxx); a codepoint has at most 3,
    // so stop as soon as a 4th is seen instead of scanning the whole string
    size_t n_cont = 0;
    while (n_cont < n && n_cont < 4 && ((unsigned char) s[n - 1 - n_cont] & 0xC0) == 0x80) {
        ++n_cont;
    }
    if (n_cont == n || n_cont > 3) {
        // empty, continuation bytes only, or too many of them for any lead byte
        return 0;
    }

    const unsigned char c = (unsigned char) s[n - 1 - n_cont];
    int expected;
    if (c < 0x80) {
        expected = 1; // ASCII
    } else if (c >= 0xC2 && c <= 0xDF) {
        expected = 2; // 0xC0 and 0xC1 would only encode overlong forms
    } else if (c >= 0xE0 && c <= 0xEF) {
        expected = 3;
    } else if (c >= 0xF0 && c <= 0xF4) {
        expected = 4; // 0xF5..0xF7 would encode values above U+10FFFF
    } else {
        return 0; // byte that never appears in valid UTF-8
    }

    const int have = (int) n_cont + 1; // lead byte + continuation bytes present
    return have >= expected ? 0 : (expected - have);
}

bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id) {
std::ofstream speak_file(path.c_str());
if (speak_file.fail()) {
Expand Down
3 changes: 3 additions & 0 deletions examples/common-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,8 @@ std::string to_timestamp(int64_t t, bool comma = false);
// given a timestamp get the sample
int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate);

// Returns the number of trailing bytes still needed for s to end on a complete UTF-8 codepoint.
// Only the last codepoint of s is examined. The result is 0 when that codepoint is already
// complete, when s is empty or consists solely of continuation bytes, or when the lead byte
// is not a recognized UTF-8 lead byte (i.e. 0 means "nothing more to append").
// Intended for merging consecutive tokens whose bytes split a multi-byte character.
int utf8_trailing_bytes_needed(const std::string & s);

// write text to file, and call system("command voice_id file")
bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id);
23 changes: 21 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,10 +1107,29 @@ int main(int argc, char ** argv) {
}

segment["tokens"].push_back(token.id);
json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}};
std::string word_text = whisper_full_get_token_text(ctx, i, j);
int64_t word_t1 = token.t1;

while (j + 1 < n_tokens && utf8_trailing_bytes_needed(word_text) > 0) {
const whisper_token_data next_token = whisper_full_get_token_data(ctx, i, j + 1);
// Keep verbose_json tokens free of EOT ids, matching the pre-merge server behavior.
if (next_token.id >= whisper_token_eot(ctx)) {
break;
}

++j;
segment["tokens"].push_back(next_token.id);
word_text += whisper_full_get_token_text(ctx, i, j);
if (next_token.t1 > -1) {
word_t1 = next_token.t1;
}
total_logprob += next_token.plog;
}

json word = json{{"word", word_text}};
if (!params.no_timestamps && params.token_timestamps) {
word["start"] = token.t0 * 0.01;
word["end"] = token.t1 * 0.01;
word["end"] = word_t1 * 0.01;
word["t_dtw"] = token.t_dtw;
}
word["probability"] = token.p;
Expand Down
8 changes: 8 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ if (WHISPER_COMMON_FFMPEG)
set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3")
endif()

# UTF-8 helper unit test
# Builds test-common-utf8 from test-common-utf8.cpp and registers it with CTest
# under the "unit" label. The test takes no command-line arguments.
set(UTF8_TEST test-common-utf8)
add_executable(${UTF8_TEST} ${UTF8_TEST}.cpp)
# the test includes "common-whisper.h", which lives in ../examples
target_include_directories(${UTF8_TEST} PRIVATE ../examples)
# NOTE(review): presumably the `common` target provides utf8_trailing_bytes_needed - confirm
target_link_libraries(${UTF8_TEST} PRIVATE common)
add_test(NAME ${UTF8_TEST} COMMAND ${UTF8_TEST})
set_tests_properties(${UTF8_TEST} PROPERTIES LABELS "unit")

# VAD test tests VAD in isolation
set(VAD_TEST test-vad)
add_executable(${VAD_TEST} ${VAD_TEST}.cpp)
Expand Down
34 changes: 34 additions & 0 deletions tests/test-common-utf8.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "common-whisper.h"

#include <cstdlib>
#include <cstdio>
#include <string>

// Checks that utf8_trailing_bytes_needed(input) yields `expected`.
// On mismatch, prints both values to stderr and aborts the process so the
// test run is reported as failed.
static void expect_needed(const std::string & input, int expected) {
    const int got = utf8_trailing_bytes_needed(input);
    if (got == expected) {
        return;
    }

    fprintf(stderr, "expected %d trailing UTF-8 bytes, got %d\n", expected, got);
    std::abort();
}

// Runs utf8_trailing_bytes_needed against a fixed table of inputs; any
// mismatch aborts inside expect_needed. Returns 0 when every case passes.
int main() {
    const std::string cjk   = "\xE4\xBD\xA0";     // U+4F60, 3-byte sequence
    const std::string emoji = "\xF0\x9F\x98\x80"; // U+1F600, 4-byte sequence

    struct utf8_case {
        std::string input;
        int         needed;
    };

    const utf8_case cases[] = {
        // empty and pure-ASCII input is always complete
        { "",                 0 },
        { "plain ascii",      0 },

        // every prefix of a 3-byte codepoint
        { cjk.substr(0, 1),   2 },
        { cjk.substr(0, 2),   1 },
        { cjk,                0 },

        // every prefix of a 4-byte codepoint
        { emoji.substr(0, 1), 3 },
        { emoji.substr(0, 2), 2 },
        { emoji.substr(0, 3), 1 },
        { emoji,              0 },

        // malformed tails: continuation bytes only, and an invalid lead byte
        { "\x80\x80",         0 },
        { "\xFF",             0 },
    };

    for (const utf8_case & tc : cases) {
        expect_needed(tc.input, tc.needed);
    }

    return 0;
}
Loading