44from .processor import register_processor
55
66
7- @register_processor ("internlm3" )
8- class InternLMProcessor (BasicLLMProcessor ):
7+ class SentencePieceProcessor (BasicLLMProcessor ):
8+ """Generic processor for models using SentencePiece tokenizer
9+ (InternLM, Mistral, FM9G, etc.).
10+
11+ Fixes leading space loss during incremental decoding with Fast tokenizers.
12+ """
13+
914 def __init__ (self , model_dir_path : str ):
1015 super ().__init__ (model_dir_path )
1116 self ._fix_tokenizer_decode (self .tokenizer )
1217
1318 @staticmethod
1419 def _fix_tokenizer_decode (tokenizer ):
15- """Fix InternLM tokenizer incremental decoding space loss .
20+ """Fix leading space loss when tokenizer decodes incrementally .
1621
17- Similar to Mistral, InternLM uses a Fast tokenizer whose Rust backend
18- trims leading spaces derived from ▁ (U+2581) during single-token
19- decoding, causing words to concatenate.
22+ Problem: Fast tokenizer's Rust backend trims leading spaces derived
23+ from ▁ (U+2581) during single-token decoding, causing words to concatenate.
2024
2125 Fix: patch tokenizer.decode() to manually replace ▁ → space
2226 and handle byte fallback.
2327 """
28+
2429 def patched_decode (self_tok , token_ids , skip_special_tokens = False , ** kwargs ):
30+ # 1. Normalize input to list of token IDs
2531 if isinstance (token_ids , int ):
2632 token_ids = [token_ids ]
27-
33+
34+ # 2. Convert token IDs to raw token strings (preserving ▁)
2835 tokens = self_tok .convert_ids_to_tokens (
2936 token_ids , skip_special_tokens = skip_special_tokens
3037 )
3138 if isinstance (tokens , str ):
3239 tokens = [tokens ]
3340
41+ # 3. Remove special tokens if requested
3442 if skip_special_tokens :
3543 special = set (self_tok .all_special_tokens )
3644 tokens = [t for t in tokens if t not in special ]
3745
46+ # 4. Join and replace ▁ (U+2581) with space
3847 text = "" .join (tokens ).replace ("\u2581 " , " " )
3948
49+ # 5. Handle SentencePiece byte fallback: consecutive <0xHH> → UTF-8
4050 def byte_fallback_replace (match ):
4151 hex_strs = re .findall (r"<0x([0-9A-Fa-f]{2})>" , match .group (0 ))
4252 byte_values = bytes ([int (h , 16 ) for h in hex_strs ])
@@ -48,3 +58,18 @@ def byte_fallback_replace(match):
4858
4959 tokenizer .decode = types .MethodType (patched_decode , tokenizer )
5060
61+
62+ @register_processor ("fm9g" )
63+ class FM9GProcessor (SentencePieceProcessor ):
64+ pass
65+
66+
67+ # Register model-specific aliases
68+ @register_processor ("internlm3" )
69+ class InternLMProcessor (SentencePieceProcessor ):
70+ pass
71+
72+
73+ @register_processor ("mistral" )
74+ class MistralProcessor (SentencePieceProcessor ):
75+ pass
0 commit comments