Skip to content

Commit a4fa5b6

Browse files
committed
feat: implement batch tokenization for TokenizeManager
- Use tokenizer() for batch encoding plain texts - Use apply_chat_template() for batch processing chat templates - Remove padding tokens using attention mask - Preserve original message order - Add comprehensive unit tests for batch tokenization
1 parent 20fcd7f commit a4fa5b6

2 files changed

Lines changed: 269 additions & 15 deletions

File tree

python/minisgl/tokenizer/tokenize.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,53 @@
77
from transformers import PreTrainedTokenizerBase
88

99

10-
class TokenizeManager:
10+
class TokenizeManager:
1111
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
1212
self.tokenizer = tokenizer
1313

1414
def tokenize(self, msgs: List[TokenizeMsg]) -> List[torch.Tensor]:
15-
results: List[torch.Tensor] = []
16-
# TODO: batch tokenization
17-
for msg in msgs:
15+
if not msgs:
16+
return []
17+
18+
# Separate plain text and chat template messages while preserving order
19+
plain_indices: List[int] = []
20+
plain_texts: List[str] = []
21+
chat_indices: List[int] = []
22+
chat_convs: List[List[dict]] = []
23+
24+
for i, msg in enumerate(msgs):
1825
if isinstance(msg.text, list):
19-
prompt = self.tokenizer.apply_chat_template(
20-
msg.text,
21-
tokenize=False,
22-
add_generation_prompt=True,
23-
)
24-
assert isinstance(prompt, str)
26+
chat_indices.append(i)
27+
chat_convs.append(msg.text)
2528
else:
26-
prompt = msg.text
27-
input_ids: torch.Tensor = ( # type: ignore
28-
self.tokenizer.encode(prompt, return_tensors="pt")
29+
plain_indices.append(i)
30+
plain_texts.append(msg.text)
31+
32+
results: List[torch.Tensor | None] = [None] * len(msgs)
33+
34+
# Batch encode plain texts
35+
if plain_texts:
36+
encoded = self.tokenizer(plain_texts, return_tensors="pt", padding=True)
37+
input_ids = encoded["input_ids"]
38+
attention_mask = encoded["attention_mask"]
39+
for i, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
40+
# Remove padding tokens
41+
length = mask.sum().item()
42+
results[plain_indices[i]] = ids[:length].to(torch.int32)
43+
44+
# Batch encode chat templates
45+
if chat_convs:
46+
prompts = self.tokenizer.apply_chat_template(
47+
chat_convs,
48+
tokenize=False,
49+
add_generation_prompt=True,
2950
)
30-
results.append(input_ids.view(-1).to(torch.int32))
31-
return results
51+
encoded = self.tokenizer(prompts, return_tensors="pt", padding=True)
52+
input_ids = encoded["input_ids"]
53+
attention_mask = encoded["attention_mask"]
54+
for i, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
55+
# Remove padding tokens
56+
length = mask.sum().item()
57+
results[chat_indices[i]] = ids[:length].to(torch.int32)
58+
59+
return results # type: ignore

tests/tokenizer/test_tokenize.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""Unit tests for TokenizeManager batch tokenization."""
2+
from __future__ import annotations
3+
4+
import torch
5+
from minisgl.core import SamplingParams
6+
from minisgl.message import TokenizeMsg
7+
from minisgl.tokenizer.tokenize import TokenizeManager
8+
from transformers import AutoTokenizer
9+
10+
11+
def get_test_tokenizer():
12+
"""Get a small tokenizer for testing."""
13+
return AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
14+
15+
16+
def test_single_plain_text():
17+
"""Test tokenization of a single plain text message."""
18+
tokenizer = get_test_tokenizer()
19+
manager = TokenizeManager(tokenizer)
20+
21+
msg = TokenizeMsg(
22+
uid=0,
23+
text="Hello, world!",
24+
sampling_params=SamplingParams(),
25+
)
26+
27+
results = manager.tokenize([msg])
28+
29+
assert len(results) == 1
30+
expected = tokenizer.encode("Hello, world!", return_tensors="pt").view(-1)
31+
assert torch.equal(results[0], expected.to(torch.int32))
32+
33+
34+
def test_batch_plain_text():
35+
"""Test batch tokenization of multiple plain text messages."""
36+
tokenizer = get_test_tokenizer()
37+
manager = TokenizeManager(tokenizer)
38+
39+
texts = [
40+
"Hello, world!",
41+
"How are you?",
42+
"This is a test.",
43+
"Mini-SGLang is awesome!",
44+
]
45+
46+
msgs = [
47+
TokenizeMsg(uid=i, text=text, sampling_params=SamplingParams())
48+
for i, text in enumerate(texts)
49+
]
50+
51+
results = manager.tokenize(msgs)
52+
53+
assert len(results) == len(texts)
54+
55+
for i, (result, text) in enumerate(zip(results, texts)):
56+
expected = tokenizer.encode(text, return_tensors="pt").view(-1)
57+
assert torch.equal(result, expected.to(torch.int32)), f"Mismatch at index {i}"
58+
59+
60+
def test_single_chat_template():
61+
"""Test tokenization of a single chat template message."""
62+
tokenizer = get_test_tokenizer()
63+
manager = TokenizeManager(tokenizer)
64+
65+
msg = TokenizeMsg(
66+
uid=0,
67+
text=[{"role": "user", "content": "Hello!"}],
68+
sampling_params=SamplingParams(),
69+
)
70+
71+
results = manager.tokenize([msg])
72+
73+
assert len(results) == 1
74+
75+
# Verify the result is valid tokens
76+
assert results[0].dtype == torch.int32
77+
assert len(results[0]) > 0
78+
79+
80+
def test_batch_chat_template():
81+
"""Test batch tokenization of multiple chat template messages."""
82+
tokenizer = get_test_tokenizer()
83+
manager = TokenizeManager(tokenizer)
84+
85+
conversations = [
86+
[{"role": "user", "content": "Hello!"}],
87+
[{"role": "user", "content": "How are you?"}],
88+
[
89+
{"role": "system", "content": "You are helpful."},
90+
{"role": "user", "content": "Hi!"},
91+
],
92+
]
93+
94+
msgs = [
95+
TokenizeMsg(uid=i, text=conv, sampling_params=SamplingParams())
96+
for i, conv in enumerate(conversations)
97+
]
98+
99+
results = manager.tokenize(msgs)
100+
101+
assert len(results) == len(conversations)
102+
103+
for result in results:
104+
assert result.dtype == torch.int32
105+
assert len(result) > 0
106+
107+
108+
def test_mixed_batch():
109+
"""Test batch tokenization with mixed plain text and chat template messages."""
110+
tokenizer = get_test_tokenizer()
111+
manager = TokenizeManager(tokenizer)
112+
113+
msgs = [
114+
TokenizeMsg(uid=0, text="Plain text message", sampling_params=SamplingParams()),
115+
TokenizeMsg(
116+
uid=1, text=[{"role": "user", "content": "Chat message"}], sampling_params=SamplingParams()
117+
),
118+
TokenizeMsg(uid=2, text="Another plain text", sampling_params=SamplingParams()),
119+
TokenizeMsg(
120+
uid=3,
121+
text=[
122+
{"role": "system", "content": "System prompt"},
123+
{"role": "user", "content": "User message"},
124+
],
125+
sampling_params=SamplingParams(),
126+
),
127+
]
128+
129+
results = manager.tokenize(msgs)
130+
131+
assert len(results) == 4
132+
133+
# Verify plain text results
134+
expected_0 = tokenizer.encode("Plain text message", return_tensors="pt").view(-1)
135+
assert torch.equal(results[0], expected_0.to(torch.int32))
136+
137+
expected_2 = tokenizer.encode("Another plain text", return_tensors="pt").view(-1)
138+
assert torch.equal(results[2], expected_2.to(torch.int32))
139+
140+
# Verify chat template results are valid
141+
assert results[1].dtype == torch.int32
142+
assert results[3].dtype == torch.int32
143+
144+
145+
def test_empty_batch():
146+
"""Test tokenization of an empty batch."""
147+
tokenizer = get_test_tokenizer()
148+
manager = TokenizeManager(tokenizer)
149+
150+
results = manager.tokenize([])
151+
152+
assert len(results) == 0
153+
154+
155+
def test_output_dtype():
156+
"""Verify that output tensors are int32 as expected by the system."""
157+
tokenizer = get_test_tokenizer()
158+
manager = TokenizeManager(tokenizer)
159+
160+
msgs = [
161+
TokenizeMsg(uid=0, text="Test", sampling_params=SamplingParams()),
162+
TokenizeMsg(uid=1, text=[{"role": "user", "content": "Test"}], sampling_params=SamplingParams()),
163+
]
164+
165+
results = manager.tokenize(msgs)
166+
167+
for result in results:
168+
assert result.dtype == torch.int32, f"Expected int32, got {result.dtype}"
169+
170+
171+
def test_consistency_with_original():
172+
"""Verify batch tokenization produces same results as individual tokenization."""
173+
tokenizer = get_test_tokenizer()
174+
manager = TokenizeManager(tokenizer)
175+
176+
texts = ["First message", "Second message", "Third message"]
177+
178+
msgs = [
179+
TokenizeMsg(uid=i, text=text, sampling_params=SamplingParams())
180+
for i, text in enumerate(texts)
181+
]
182+
183+
# Batch tokenization
184+
batch_results = manager.tokenize(msgs)
185+
186+
# Individual tokenization (original behavior)
187+
individual_results = []
188+
for text in texts:
189+
ids = tokenizer.encode(text, return_tensors="pt").view(-1).to(torch.int32)
190+
individual_results.append(ids)
191+
192+
# Compare
193+
for i, (batch, individual) in enumerate(zip(batch_results, individual_results)):
194+
assert torch.equal(batch, individual), f"Mismatch at index {i}"
195+
196+
197+
if __name__ == "__main__":
198+
import sys
199+
200+
failed = False
201+
tests = [
202+
test_single_plain_text,
203+
test_batch_plain_text,
204+
test_single_chat_template,
205+
test_batch_chat_template,
206+
test_mixed_batch,
207+
test_empty_batch,
208+
test_output_dtype,
209+
test_consistency_with_original,
210+
]
211+
212+
for test in tests:
213+
try:
214+
test()
215+
print(f"✓ {test.__name__}")
216+
except AssertionError as e:
217+
print(f"✗ {test.__name__}: {e}")
218+
failed = True
219+
except Exception as e:
220+
print(f"✗ {test.__name__}: {type(e).__name__}: {e}")
221+
failed = True
222+
223+
if failed:
224+
sys.exit(1)
225+
else:
226+
print("\nAll tests passed!")

0 commit comments

Comments
 (0)