Skip to content

Commit 764b574

Browse files
authored
Merge pull request #405 from InfiniTensor/refactor/sentencepiece-processor
refactor: use a common processor for fm9g, internlm, and mistral
2 parents c7e8420 + 5238acc commit 764b574

2 files changed

Lines changed: 32 additions & 65 deletions

File tree

python/infinilm/processors/mistral_processor.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

python/infinilm/processors/internlm_processor.py renamed to python/infinilm/processors/sentencepiece_processor.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,49 @@
44
from .processor import register_processor
55

66

7-
@register_processor("internlm3")
8-
class InternLMProcessor(BasicLLMProcessor):
7+
class SentencePieceProcessor(BasicLLMProcessor):
8+
"""Generic processor for models using SentencePiece tokenizer
9+
(InternLM, Mistral, FM9G, etc.).
10+
11+
Fixes leading space loss during incremental decoding with Fast tokenizers.
12+
"""
13+
914
def __init__(self, model_dir_path: str):
1015
super().__init__(model_dir_path)
1116
self._fix_tokenizer_decode(self.tokenizer)
1217

1318
@staticmethod
1419
def _fix_tokenizer_decode(tokenizer):
15-
"""Fix InternLM tokenizer incremental decoding space loss.
20+
"""Fix leading space loss when tokenizer decodes incrementally.
1621
17-
Similar to Mistral, InternLM uses a Fast tokenizer whose Rust backend
18-
trims leading spaces derived from ▁ (U+2581) during single-token
19-
decoding, causing words to concatenate.
22+
Problem: Fast tokenizer's Rust backend trims leading spaces derived
23+
from ▁ (U+2581) during single-token decoding, causing words to concatenate.
2024
2125
Fix: patch tokenizer.decode() to manually replace ▁ → space
2226
and handle byte fallback.
2327
"""
28+
2429
def patched_decode(self_tok, token_ids, skip_special_tokens=False, **kwargs):
30+
# 1. Normalize input to list of token IDs
2531
if isinstance(token_ids, int):
2632
token_ids = [token_ids]
27-
33+
34+
# 2. Convert token IDs to raw token strings (preserving ▁)
2835
tokens = self_tok.convert_ids_to_tokens(
2936
token_ids, skip_special_tokens=skip_special_tokens
3037
)
3138
if isinstance(tokens, str):
3239
tokens = [tokens]
3340

41+
# 3. Remove special tokens if requested
3442
if skip_special_tokens:
3543
special = set(self_tok.all_special_tokens)
3644
tokens = [t for t in tokens if t not in special]
3745

46+
# 4. Join and replace ▁ (U+2581) with space
3847
text = "".join(tokens).replace("\u2581", " ")
3948

49+
# 5. Handle SentencePiece byte fallback: consecutive <0xHH> → UTF-8
4050
def byte_fallback_replace(match):
4151
hex_strs = re.findall(r"<0x([0-9A-Fa-f]{2})>", match.group(0))
4252
byte_values = bytes([int(h, 16) for h in hex_strs])
@@ -48,3 +58,18 @@ def byte_fallback_replace(match):
4858

4959
tokenizer.decode = types.MethodType(patched_decode, tokenizer)
5060

61+
62+
@register_processor("fm9g")
63+
class FM9GProcessor(SentencePieceProcessor):
64+
pass
65+
66+
67+
# Register model-specific aliases
68+
@register_processor("internlm3")
69+
class InternLMProcessor(SentencePieceProcessor):
70+
pass
71+
72+
73+
@register_processor("mistral")
74+
class MistralProcessor(SentencePieceProcessor):
75+
pass

0 commit comments

Comments
 (0)