Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 43 additions & 15 deletions python/minisgl/tokenizer/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,53 @@
from transformers import PreTrainedTokenizerBase


class TokenizeManager:
    """Turn a batch of tokenize requests into per-request token-id tensors."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Store the tokenizer used for chat templating and encoding."""
        self.tokenizer = tokenizer

    def tokenize(self, msgs: List[TokenizeMsg]) -> List[torch.Tensor]:
        """Tokenize every message in one batched tokenizer call.

        A message whose ``text`` is a list is treated as a chat conversation
        and rendered to a prompt string with the tokenizer's chat template
        (``add_generation_prompt=True``); any other ``text`` is encoded as-is.

        Args:
            msgs: Messages to tokenize; may be empty.

        Returns:
            One 1-D ``torch.int32`` tensor of token ids per message, in the
            same order as ``msgs``.
        """
        if not msgs:
            return []

        # Collect one prompt string per message. Chat conversations get a
        # placeholder that is filled in after a single batched template call.
        prompts: List[str] = []
        chat_indices: List[int] = []
        chat_convs: List[List[dict]] = []
        for index, msg in enumerate(msgs):
            if isinstance(msg.text, list):
                chat_indices.append(index)
                chat_convs.append(msg.text)
                prompts.append("")
            else:
                prompts.append(msg.text)

        if chat_convs:
            rendered = self.tokenizer.apply_chat_template(
                chat_convs,
                tokenize=False,
                add_generation_prompt=True,
            )
            for index, prompt in zip(chat_indices, rendered):
                assert isinstance(prompt, str)
                prompts[index] = prompt

        # Encode without padding or return_tensors: the tokenizer then returns
        # one variable-length id list per prompt. This avoids stripping pad
        # tokens afterwards, which is only correct for right-padding
        # tokenizers, and avoids requiring a pad token at all.
        encoded = self.tokenizer(prompts)
        return [
            torch.tensor(ids, dtype=torch.int32) for ids in encoded["input_ids"]
        ]