Skip to content

Commit c7e8420

Browse files
authored
Merge pull request #399 from rubik-hua/mistral
issue/398 Bugfix: add MistralProcessor and InternLMProcessor to fix missing spaces in streaming output
2 parents 1b44040 + 76cbd3e commit c7e8420

2 files changed

Lines changed: 108 additions & 0 deletions

File tree

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import re
2+
import types
3+
from .basic_llm_processor import BasicLLMProcessor
4+
from .processor import register_processor
5+
6+
7+
@register_processor("internlm3")
8+
class InternLMProcessor(BasicLLMProcessor):
9+
def __init__(self, model_dir_path: str):
10+
super().__init__(model_dir_path)
11+
self._fix_tokenizer_decode(self.tokenizer)
12+
13+
@staticmethod
14+
def _fix_tokenizer_decode(tokenizer):
15+
"""Fix InternLM tokenizer incremental decoding space loss.
16+
17+
Similar to Mistral, InternLM uses a Fast tokenizer whose Rust backend
18+
trims leading spaces derived from ▁ (U+2581) during single-token
19+
decoding, causing words to concatenate.
20+
21+
Fix: patch tokenizer.decode() to manually replace ▁ → space
22+
and handle byte fallback.
23+
"""
24+
def patched_decode(self_tok, token_ids, skip_special_tokens=False, **kwargs):
25+
if isinstance(token_ids, int):
26+
token_ids = [token_ids]
27+
28+
tokens = self_tok.convert_ids_to_tokens(
29+
token_ids, skip_special_tokens=skip_special_tokens
30+
)
31+
if isinstance(tokens, str):
32+
tokens = [tokens]
33+
34+
if skip_special_tokens:
35+
special = set(self_tok.all_special_tokens)
36+
tokens = [t for t in tokens if t not in special]
37+
38+
text = "".join(tokens).replace("\u2581", " ")
39+
40+
def byte_fallback_replace(match):
41+
hex_strs = re.findall(r"<0x([0-9A-Fa-f]{2})>", match.group(0))
42+
byte_values = bytes([int(h, 16) for h in hex_strs])
43+
return byte_values.decode("utf-8", errors="replace")
44+
45+
text = re.sub(r"(<0x[0-9A-Fa-f]{2}>)+", byte_fallback_replace, text)
46+
47+
return text
48+
49+
tokenizer.decode = types.MethodType(patched_decode, tokenizer)
50+
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import re
2+
import types
3+
from .basic_llm_processor import BasicLLMProcessor
4+
from .processor import register_processor
5+
6+
7+
@register_processor("mistral")
8+
class MistralProcessor(BasicLLMProcessor):
9+
def __init__(self, model_dir_path: str):
10+
super().__init__(model_dir_path)
11+
self._fix_tokenizer_decode(self.tokenizer)
12+
13+
@staticmethod
14+
def _fix_tokenizer_decode(tokenizer):
15+
"""Fix Mistral tokenizer incremental decoding space loss.
16+
17+
LlamaTokenizerFast.decode() calls Rust backend directly, which
18+
trims leading spaces derived from ▁ (U+2581) during single-token
19+
decoding, causing English words to concatenate.
20+
21+
Fix: patch tokenizer.decode() to:
22+
1. Convert token IDs to raw token strings (preserving ▁)
23+
2. Manually replace ▁ → space and handle byte fallback
24+
"""
25+
original_decode = tokenizer.decode
26+
27+
def patched_decode(self_tok, token_ids, skip_special_tokens=False, **kwargs):
28+
# 1. Get raw token strings (preserving ▁)
29+
if isinstance(token_ids, int):
30+
token_ids = [token_ids]
31+
tokens = self_tok.convert_ids_to_tokens(
32+
token_ids, skip_special_tokens=skip_special_tokens
33+
)
34+
if isinstance(tokens, str):
35+
tokens = [tokens]
36+
37+
# 2. Remove special tokens if requested
38+
if skip_special_tokens:
39+
special = set(self_tok.all_special_tokens)
40+
tokens = [t for t in tokens if t not in special]
41+
42+
# 3. Join + replace ▁ (U+2581) with space
43+
text = "".join(tokens).replace("\u2581", " ")
44+
45+
# 4. Handle SentencePiece byte fallback: consecutive <0xHH> → UTF-8
46+
def byte_fallback_replace(match):
47+
hex_strs = re.findall(r"<0x([0-9A-Fa-f]{2})>", match.group(0))
48+
byte_values = bytes([int(h, 16) for h in hex_strs])
49+
return byte_values.decode("utf-8", errors="replace")
50+
51+
text = re.sub(r"(<0x[0-9A-Fa-f]{2}>)+", byte_fallback_replace, text)
52+
53+
# 5. Strip leading/trailing whitespace only if ALL tokens were special
54+
# (preserve inter-word spaces from ▁)
55+
return text
56+
57+
tokenizer.decode = types.MethodType(patched_decode, tokenizer)
58+

0 commit comments

Comments
 (0)