3737import concurrent .futures
3838import hashlib
3939import os
40+ import re
4041import shutil
4142import tempfile
4243import time
@@ -209,6 +210,20 @@ def _supports_chatml_heuristic(tokenizer: PreTrainedTokenizerBase) -> bool:
209210 return False
210211
211212
# Same pattern transformers applies when deciding whether a chat template
# can produce an assistant-tokens mask via {% generation %} blocks
# (see transformers/utils/chat_template_utils.py).
_GENERATION_KEYWORD_RE = re.compile(r"\{\%-?\s*generation\s*-?\%\}")


def _chat_template_has_generation(tokenizer: PreTrainedTokenizerBase) -> bool:
    """Return True if the tokenizer's chat template contains a ``{% generation %}`` tag."""
    template = getattr(tokenizer, "chat_template", None)
    # Tokenizers may carry several named templates in a dict; only the
    # "default" one is consulted here.
    if isinstance(template, dict):
        template = template.get("default")
    if not isinstance(template, str):
        return False
    return _GENERATION_KEYWORD_RE.search(template) is not None
225+
226+
def _encode_role(tokenizer: PreTrainedTokenizerBase, role: str) -> list[int]:
    """Tokenize a role string into its raw token ids.

    Special tokens are deliberately suppressed so the result contains only
    the tokens for the role text itself.
    """
    role_token_ids = tokenizer.encode(role, add_special_tokens=False)
    return role_token_ids
@@ -295,7 +310,19 @@ def make_chat_tokenize_fn(
295310 Tested model families (ChatML format): Qwen2, Qwen2.5, Qwen3, Qwen3.5, Nemotron 3.
296311 """
297312 _check_model_family (tokenizer )
298- _heuristic_checked = {"done" : False }
313+ use_heuristic = not _chat_template_has_generation (tokenizer )
314+ if use_heuristic :
315+ if not _supports_chatml_heuristic (tokenizer ):
316+ model_name = getattr (tokenizer , "name_or_path" , "unknown" )
317+ raise ValueError (
318+ f"Chat template for '{ model_name } ' does not support "
319+ f"{{% generation %}} and does not use ChatML format. "
320+ f"Use make_pretrain_tokenize_fn instead."
321+ )
322+ warn_rank_0 (
323+ "Chat template lacks {% generation %} support. "
324+ "Using heuristic ChatML-based assistant masking."
325+ )
299326
300327 def tokenize (sample ):
301328 messages = sample .get (chat_key )
@@ -308,15 +335,25 @@ def tokenize(sample):
308335 }
309336
310337 try :
311- result = tokenizer .apply_chat_template (
312- messages ,
313- tokenize = True ,
314- return_dict = True ,
315- return_assistant_tokens_mask = True ,
316- padding = "max_length" ,
317- truncation = True ,
318- max_length = max_length ,
319- )
338+ if use_heuristic :
339+ result = tokenizer .apply_chat_template (
340+ messages ,
341+ tokenize = True ,
342+ return_dict = True ,
343+ padding = "max_length" ,
344+ truncation = True ,
345+ max_length = max_length ,
346+ )
347+ else :
348+ result = tokenizer .apply_chat_template (
349+ messages ,
350+ tokenize = True ,
351+ return_dict = True ,
352+ return_assistant_tokens_mask = True ,
353+ padding = "max_length" ,
354+ truncation = True ,
355+ max_length = max_length ,
356+ )
320357 except (ValueError , TypeError , KeyError ) as e :
321358 print_rank_0 (f"WARNING: Failed to tokenize sample: { e } . Skipping." )
322359 pad_id = tokenizer .pad_token_id or tokenizer .eos_token_id or 0
@@ -327,25 +364,10 @@ def tokenize(sample):
327364 }
328365
329366 input_ids = result ["input_ids" ]
330- assistant_masks = result ["assistant_masks" ]
331-
332- # Fallback: if native masks are all zeros, use heuristic ChatML masking
333- if any (m == "assistant" for m in (msg .get ("role" ) for msg in messages )):
334- if not any (assistant_masks ):
335- if not _heuristic_checked ["done" ]:
336- _heuristic_checked ["done" ] = True
337- if not _supports_chatml_heuristic (tokenizer ):
338- model_name = getattr (tokenizer , "name_or_path" , "unknown" )
339- raise ValueError (
340- f"Chat template for '{ model_name } ' does not support "
341- f"{{% generation %}} and does not use ChatML format. "
342- f"Use make_pretrain_tokenize_fn instead."
343- )
344- print_rank_0 (
345- "WARNING: Chat template lacks {% generation %} support. "
346- "Using heuristic ChatML-based assistant masking."
347- )
348- assistant_masks = _chatml_assistant_mask (input_ids , tokenizer )
367+ if use_heuristic :
368+ assistant_masks = _chatml_assistant_mask (input_ids , tokenizer )
369+ else :
370+ assistant_masks = result ["assistant_masks" ]
349371
350372 labels = [tid if mask else IGNORE_TOKEN_ID for tid , mask in zip (input_ids , assistant_masks )]
351373 return {
0 commit comments