Skip to content

Commit 909307c

Browse files
buxukulinxiaodong
andauthored
whisper : make voice_length() utf-8 aware for CJK (#3915)
* whisper : make voice_length() utf-8 aware for CJK voice_length() weights each token by how long its text takes to say, which drives how a segment's time is shared between its tokens. It looped over raw bytes, so every CJK character (3 bytes) was counted ~3x and full-width punctuation never matched, skewing token timestamps for Chinese/Japanese. Decode one utf-8 code point at a time and give full-width ,。!? etc. the same weights as their ASCII counterparts. Pure-ASCII text is unaffected. * whisper : one statement per line in voice_length() --------- Co-authored-by: linxiaodong <calm.lin@wukongsch.com>
1 parent 0874de3 commit 909307c

1 file changed

Lines changed: 76 additions & 14 deletions

File tree

src/whisper.cpp

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8397,24 +8397,86 @@ static int64_t sample_to_timestamp(int i_sample) {
83978397

83988398
// a cost-function / heuristic that is high for text that takes longer to pronounce
83998399
// obviously, can be improved
8400+
//
8401+
// iterate over utf-8 code points rather than raw bytes: a CJK glyph is 3 bytes, so the
8402+
// old per-byte loop counted every Han/kana/hangul character ~3x and never matched
8403+
// full-width punctuation, skewing how a segment's time is shared between its tokens for
8404+
// Chinese/Japanese. full-width punctuation gets the same weight as its ASCII form and
8405+
// pure-ASCII text decodes to the same weights as before.
84008406
static float voice_length(const std::string & text) {
84018407
float res = 0.0f;
84028408

8403-
for (char c : text) {
8404-
if (c == ' ') {
8405-
res += 0.01f;
8406-
} else if (c == ',') {
8407-
res += 2.00f;
8408-
} else if (c == '.') {
8409-
res += 3.00f;
8410-
} else if (c == '!') {
8411-
res += 3.00f;
8412-
} else if (c == '?') {
8413-
res += 3.00f;
8414-
} else if (c >= '0' && c <= '9') {
8415-
res += 3.00f;
8409+
const unsigned char * s = (const unsigned char *) text.data();
8410+
const size_t n = text.size();
8411+
8412+
for (size_t i = 0; i < n; ) {
8413+
const unsigned char c = s[i];
8414+
uint32_t cp = c;
8415+
int len = 1;
8416+
if (c < 0x80) {
8417+
len = 1;
8418+
} else if ((c >> 5) == 0x6) {
8419+
cp = c & 0x1F;
8420+
len = 2;
8421+
} else if ((c >> 4) == 0xE) {
8422+
cp = c & 0x0F;
8423+
len = 3;
8424+
} else if ((c >> 3) == 0x1E) {
8425+
cp = c & 0x07;
8426+
len = 4;
84168427
} else {
8417-
res += 1.00f;
8428+
cp = c; // stray continuation / invalid lead byte
8429+
len = 1;
8430+
}
8431+
if (i + (size_t) len <= n) {
8432+
bool ok = true;
8433+
for (int k = 1; k < len; ++k) {
8434+
const unsigned char cc = s[i + k];
8435+
if ((cc & 0xC0) != 0x80) {
8436+
ok = false;
8437+
break;
8438+
}
8439+
cp = (cp << 6) | (cc & 0x3F);
8440+
}
8441+
if (!ok) {
8442+
cp = c;
8443+
len = 1;
8444+
}
8445+
} else {
8446+
cp = c;
8447+
len = 1;
8448+
}
8449+
i += (size_t) len;
8450+
8451+
switch (cp) {
8452+
case ' ':
8453+
case 0x3000: // ideographic space
8454+
res += 0.01f;
8455+
break;
8456+
case ',':
8457+
case 0xFF0C: //
8458+
case 0x3001: //
8459+
case 0xFF1B: //
8460+
case 0xFF1A: //
8461+
res += 2.00f;
8462+
break;
8463+
case '.':
8464+
case '!':
8465+
case '?':
8466+
case 0x3002: //
8467+
case 0xFF0E: //
8468+
case 0xFF01: //
8469+
case 0xFF1F: //
8470+
case 0x2026: //
8471+
res += 3.00f;
8472+
break;
8473+
default:
8474+
if ((cp >= '0' && cp <= '9') || (cp >= 0xFF10 && cp <= 0xFF19)) {
8475+
res += 3.00f; // half/full-width digits
8476+
} else {
8477+
res += 1.00f; // letters, CJK ideographs, kana, hangul, ...
8478+
}
8479+
break;
84188480
}
84198481
}
84208482

0 commit comments

Comments
 (0)