Skip to content

Commit cf900a1

Browse files
nhs000aldehir
authored andcommitted
unicode : add custom Qwen2 regex handler to fix segfault on long input (ggml-org#21257)
* unicode : add custom Qwen2 regex handler to fix segfault on long input std::regex uses recursive backtracking internally, which causes a stack overflow (segfault) when tokenizing long sequences of repeated characters (e.g. 43K 'A's). The Qwen2 tokenizer regex differs from Llama3 only in the digit pattern (\p{N} vs \p{N}{1,3}), so it was falling through to the std::regex fallback path instead of using a custom handler. Add unicode_regex_split_custom_qwen2() following the established pattern used by gpt2, llama3, kimi_k2, and afmoe custom handlers. Closes: ggml-org#21113 * cont : remove TODO comment * cont : update comment to reflect original regex * use the correct regex in the comment this time... [no ci] --------- Co-authored-by: Aldehir Rojas <hello@alde.dev>
1 parent a6476ed commit cf900a1

1 file changed

Lines changed: 138 additions & 1 deletion

File tree

src/unicode.cpp

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,141 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
470470
return bpe_offsets;
471471
}
472472

473+
// Qwen2 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
474+
static std::vector<size_t> unicode_regex_split_custom_qwen2(const std::string & text, const std::vector<size_t> & offsets) {
475+
std::vector<size_t> bpe_offsets; // store the offset of each word
476+
bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
477+
478+
const auto cpts = unicode_cpts_from_utf8(text);
479+
480+
size_t start = 0;
481+
for (auto offset : offsets) {
482+
const size_t offset_ini = start;
483+
const size_t offset_end = start + offset;
484+
assert(offset_end <= cpts.size());
485+
start = offset_end;
486+
487+
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
488+
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
489+
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
490+
};
491+
492+
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
493+
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
494+
};
495+
496+
size_t _prev_end = offset_ini;
497+
auto _add_token = [&] (const size_t end) -> size_t {
498+
assert(_prev_end <= end && end <= offset_end);
499+
size_t len = end - _prev_end;
500+
if (len > 0) {
501+
bpe_offsets.push_back(len);
502+
}
503+
_prev_end = end;
504+
//if (len > 0) {
505+
// std::string s = "";
506+
// for(size_t p = end-len; p < end; p++)
507+
// s += unicode_cpt_to_utf8(cpts[p]);
508+
// printf(">>> '%s'\n", s.c_str());
509+
//}
510+
return len;
511+
};
512+
513+
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
514+
const uint32_t cpt = _get_cpt(pos);
515+
const auto flags = _get_flags(pos);
516+
517+
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
518+
if (cpt == '\'' && pos+1 < offset_end) {
519+
uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
520+
if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
521+
pos += _add_token(pos+2);
522+
continue;
523+
}
524+
if (pos+2 < offset_end) {
525+
uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
526+
if ((cpt_next == 'r' && cpt_next_next == 'e') ||
527+
(cpt_next == 'v' && cpt_next_next == 'e') ||
528+
(cpt_next == 'l' && cpt_next_next == 'l')) {
529+
pos += _add_token(pos+3);
530+
continue;
531+
}
532+
}
533+
}
534+
535+
// regex: [^\r\n\p{L}\p{N}]?\p{L}+
536+
if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
537+
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
538+
pos++;
539+
while (_get_flags(pos).is_letter) {
540+
pos++;
541+
}
542+
_add_token(pos);
543+
continue;
544+
}
545+
}
546+
547+
// regex: \p{N}
548+
if (flags.is_number) {
549+
pos++;
550+
_add_token(pos);
551+
continue;
552+
}
553+
554+
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
555+
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
556+
if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
557+
pos += (cpt == ' ');
558+
while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
559+
flags2 = _get_flags(++pos);
560+
}
561+
uint32_t cpt2 = _get_cpt(pos);
562+
while (cpt2 == '\r' || cpt2 == '\n') {
563+
cpt2 = _get_cpt(++pos);
564+
}
565+
_add_token(pos);
566+
continue;
567+
}
568+
569+
size_t num_whitespaces = 0;
570+
size_t last_end_r_or_n = 0;
571+
while (_get_flags(pos+num_whitespaces).is_whitespace) {
572+
uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
573+
if (cpt2 == '\r' || cpt2 == '\n') {
574+
last_end_r_or_n = pos + num_whitespaces + 1;
575+
}
576+
num_whitespaces++;
577+
}
578+
579+
// regex: \s*[\r\n]+
580+
if (last_end_r_or_n > 0) {
581+
pos = last_end_r_or_n;
582+
_add_token(pos);
583+
continue;
584+
}
585+
586+
// regex: \s+(?!\S)
587+
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
588+
pos += num_whitespaces - 1;
589+
_add_token(pos);
590+
continue;
591+
}
592+
593+
// regex: \s+
594+
if (num_whitespaces > 0) {
595+
pos += num_whitespaces;
596+
_add_token(pos);
597+
continue;
598+
}
599+
600+
// no matches
601+
_add_token(++pos);
602+
}
603+
}
604+
605+
return bpe_offsets;
606+
}
607+
473608
template <typename CharT>
474609
static std::vector<size_t> unicode_regex_split_stl(const std::basic_string<CharT> & text, const std::basic_string<CharT> & regex, const std::vector<size_t> & offsets) {
475610
using BidirIt = typename std::basic_string<CharT>::const_iterator;
@@ -790,8 +925,10 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
790925
} else if (
791926
regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
792927
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
793-
794928
bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
929+
} else if (
930+
regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
931+
bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets);
795932
} else if (regex_expr == "\\p{Han}+") {
796933
// K2's first pattern - handle all K2 patterns together
797934
bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);

0 commit comments

Comments
 (0)