Skip to content

Commit 7568408

Browse files
ChenhanYu and claude committed
fix: pass chat template explicitly in generation tag tests
Tests now load chat_template_train.jinja instead of relying on auto-injection (which was replaced by strict verification). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
1 parent 20a491c commit 7568408

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

tests/unit/torch/speculative/plugins/test_hf_dflash.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import os
2222
from copy import deepcopy
23+
from pathlib import Path
2324
from types import SimpleNamespace
2425
from unittest.mock import MagicMock
2526

@@ -414,15 +415,26 @@ def qwen3_tokenizer(self):
414415

415416
return AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
416417

417-
def test_chatml_think_template_produces_assistant_mask(self, qwen3_tokenizer):
418-
"""Qwen3 uses chatml_think style. Verify generation tags are injected and masks work."""
418+
@pytest.fixture
419+
def qwen3_chat_template(self):
420+
template_path = (
421+
Path(__file__).parents[5]
422+
/ "tools/launcher/examples/Qwen/Qwen3-8B/chat_template_train.jinja"
423+
)
424+
return template_path.read_text()
425+
426+
def test_chatml_think_template_produces_assistant_mask(
427+
self, qwen3_tokenizer, qwen3_chat_template
428+
):
429+
"""Verify generation-tagged chat template produces correct assistant masks."""
419430
from modelopt.torch.utils.plugins.transformers_dataset import LanguageDataCollator
420431

421432
collator = LanguageDataCollator(
422433
tokenizer=qwen3_tokenizer,
423434
train_len=128,
424435
return_labels=True,
425436
answer_only_loss=True,
437+
chat_template=qwen3_chat_template,
426438
)
427439

428440
# Verify template was replaced with generation-tagged version
@@ -451,7 +463,7 @@ def test_chatml_think_template_produces_assistant_mask(self, qwen3_tokenizer):
451463
decoded = qwen3_tokenizer.decode(non_masked)
452464
assert "The answer is 4." in decoded
453465

454-
def test_multi_turn_masks_only_assistant(self, qwen3_tokenizer):
466+
def test_multi_turn_masks_only_assistant(self, qwen3_tokenizer, qwen3_chat_template):
455467
"""Verify multi-turn: only assistant turns are unmasked."""
456468
from modelopt.torch.utils.plugins.transformers_dataset import LanguageDataCollator
457469

@@ -460,6 +472,7 @@ def test_multi_turn_masks_only_assistant(self, qwen3_tokenizer):
460472
train_len=256,
461473
return_labels=True,
462474
answer_only_loss=True,
475+
chat_template=qwen3_chat_template,
463476
)
464477

465478
messages = [

0 commit comments

Comments (0)