
Commit 2154693
Emit chat-template generation warning once, not per tokenization worker
Hoist the `{% generation %}` detection out of the tokenize closure in make_chat_tokenize_fn so the heuristic-mode warning fires once in the main process instead of ~N times (once per num_proc worker).

Drop return_assistant_tokens_mask=True on the heuristic path to silence the matching transformers-internal warning_once at the source.

Also repoint the LAQ recipe test at configs/quantize/experimental/.

Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 215f643 commit 2154693
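
Why the warning repeated: `datasets.Dataset.map(..., num_proc=N)` pickles the tokenize closure into each worker process, so any "warn once" flag captured inside the closure is duplicated per worker. A toy sketch (illustrative only, not the repo's code) that reproduces the behavior:

from datasets import Dataset

def make_tokenize_fn():
    warned = {"done": False}  # closure state: each map worker gets its own copy

    def tokenize(sample):
        if not warned["done"]:
            warned["done"] = True
            print("WARNING: falling back to heuristic masking")
        return {"n_chars": len(sample["text"])}

    return tokenize

if __name__ == "__main__":
    ds = Dataset.from_dict({"text": ["a", "bb", "ccc", "dddd"]})
    # With num_proc=4 the warning can print up to four times, once per worker;
    # hoisting the check before map() makes it fire exactly once.
    ds.map(make_tokenize_fn(), num_proc=4)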

2 files changed: 59 additions & 30 deletions

examples/llm_qat/dataset_utils.py (51 additions & 29 deletions)

@@ -37,6 +37,7 @@
 import concurrent.futures
 import hashlib
 import os
+import re
 import shutil
 import tempfile
 import time
@@ -209,6 +210,20 @@ def _supports_chatml_heuristic(tokenizer: PreTrainedTokenizerBase) -> bool:
     return False


+# Mirrors the regex transformers uses to detect {% generation %} in chat templates
+# (see transformers/utils/chat_template_utils.py).
+_GENERATION_KEYWORD_RE = re.compile(r"\{\%-?\s*generation\s*-?\%\}")
+
+
+def _chat_template_has_generation(tokenizer: PreTrainedTokenizerBase) -> bool:
+    tpl = getattr(tokenizer, "chat_template", None)
+    if isinstance(tpl, dict):
+        tpl = tpl.get("default")
+    if not isinstance(tpl, str):
+        return False
+    return bool(_GENERATION_KEYWORD_RE.search(tpl))
+
+
 def _encode_role(tokenizer: PreTrainedTokenizerBase, role: str) -> list[int]:
     """Encode a role string, returning only the role tokens (no special tokens)."""
     return tokenizer.encode(role, add_special_tokens=False)
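
The pattern tolerates Jinja whitespace-control markers around the tag. A standalone sanity check (illustrative snippet, runnable on its own):

import re

# Same pattern as _GENERATION_KEYWORD_RE in the diff above.
pattern = re.compile(r"\{\%-?\s*generation\s*-?\%\}")

for tpl in ("{% generation %}", "{%- generation -%}", "{% generate %}"):
    print(repr(tpl), "->", bool(pattern.search(tpl)))
# '{% generation %}' -> True
# '{%- generation -%}' -> True
# '{% generate %}' -> False

# Some tokenizers expose chat_template as a dict of named templates, which is
# why _chat_template_has_generation falls back to the "default" entry before searching.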
@@ -295,7 +310,19 @@ def make_chat_tokenize_fn(
     Tested model families (ChatML format): Qwen2, Qwen2.5, Qwen3, Qwen3.5, Nemotron 3.
     """
     _check_model_family(tokenizer)
-    _heuristic_checked = {"done": False}
+    use_heuristic = not _chat_template_has_generation(tokenizer)
+    if use_heuristic:
+        if not _supports_chatml_heuristic(tokenizer):
+            model_name = getattr(tokenizer, "name_or_path", "unknown")
+            raise ValueError(
+                f"Chat template for '{model_name}' does not support "
+                f"{{% generation %}} and does not use ChatML format. "
+                f"Use make_pretrain_tokenize_fn instead."
+            )
+        warn_rank_0(
+            "Chat template lacks {% generation %} support. "
+            "Using heuristic ChatML-based assistant masking."
+        )

     def tokenize(sample):
         messages = sample.get(chat_key)
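
`warn_rank_0` itself is a repo helper that this diff does not show; a common shape for such a rank-guarded warning is sketched below (an assumption about its implementation, not the actual one):

import warnings

import torch.distributed as dist

def warn_rank_0(msg: str) -> None:
    # Warn only on the global rank-0 process (or when not running distributed),
    # so N ranks do not each emit the same message.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        warnings.warn(msg)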
@@ -308,15 +335,25 @@ def tokenize(sample):
             }

         try:
-            result = tokenizer.apply_chat_template(
-                messages,
-                tokenize=True,
-                return_dict=True,
-                return_assistant_tokens_mask=True,
-                padding="max_length",
-                truncation=True,
-                max_length=max_length,
-            )
+            if use_heuristic:
+                result = tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=True,
+                    return_dict=True,
+                    padding="max_length",
+                    truncation=True,
+                    max_length=max_length,
+                )
+            else:
+                result = tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=True,
+                    return_dict=True,
+                    return_assistant_tokens_mask=True,
+                    padding="max_length",
+                    truncation=True,
+                    max_length=max_length,
+                )
         except (ValueError, TypeError, KeyError) as e:
             print_rank_0(f"WARNING: Failed to tokenize sample: {e}. Skipping.")
             pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 0
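
The heuristic-path call can be exercised standalone; the checkpoint below is an illustrative ChatML-style model, not one named in this commit:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]

# No return_assistant_tokens_mask here, so transformers never reaches its
# internal "{% generation %} missing" warning_once.
out = tok.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    padding="max_length",
    truncation=True,
    max_length=64,
)
print(len(out["input_ids"]))  # 64: padded/truncated to max_length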
@@ -327,25 +364,10 @@
             }

         input_ids = result["input_ids"]
-        assistant_masks = result["assistant_masks"]
-
-        # Fallback: if native masks are all zeros, use heuristic ChatML masking
-        if any(m == "assistant" for m in (msg.get("role") for msg in messages)):
-            if not any(assistant_masks):
-                if not _heuristic_checked["done"]:
-                    _heuristic_checked["done"] = True
-                    if not _supports_chatml_heuristic(tokenizer):
-                        model_name = getattr(tokenizer, "name_or_path", "unknown")
-                        raise ValueError(
-                            f"Chat template for '{model_name}' does not support "
-                            f"{{% generation %}} and does not use ChatML format. "
-                            f"Use make_pretrain_tokenize_fn instead."
-                        )
-                    print_rank_0(
-                        "WARNING: Chat template lacks {% generation %} support. "
-                        "Using heuristic ChatML-based assistant masking."
-                    )
-                assistant_masks = _chatml_assistant_mask(input_ids, tokenizer)
+        if use_heuristic:
+            assistant_masks = _chatml_assistant_mask(input_ids, tokenizer)
+        else:
+            assistant_masks = result["assistant_masks"]

         labels = [tid if mask else IGNORE_TOKEN_ID for tid, mask in zip(input_ids, assistant_masks)]
         return {
tests/unit/recipe/test_laq_recipes.py (8 additions & 1 deletion)

@@ -20,7 +20,14 @@
 import pytest
 import yaml

-CONFIGS_DIR = Path(__file__).resolve().parents[3] / "examples" / "llm_qat" / "configs" / "quantize"
+CONFIGS_DIR = (
+    Path(__file__).resolve().parents[3]
+    / "examples"
+    / "llm_qat"
+    / "configs"
+    / "quantize"
+    / "experimental"
+)

 # (filename, expected learnable_amax, expected tied_amax)
 _LAQ_RECIPES = [
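
A hypothetical sketch of a test consuming the relocated directory (the actual test body is not shown in this diff; the function below is illustrative only):

from pathlib import Path

import yaml

CONFIGS_DIR = (
    Path(__file__).resolve().parents[3]
    / "examples" / "llm_qat" / "configs" / "quantize" / "experimental"
)

def test_recipe_files_parse():
    # Every LAQ recipe under the experimental directory should be valid YAML.
    for path in sorted(CONFIGS_DIR.glob("*.yaml")):
        cfg = yaml.safe_load(path.read_text())
        assert isinstance(cfg, dict), f"{path.name} did not parse to a mapping"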
