Commit 7db6d71

Merge pull request #3350 from AI-Hypercomputer:jimmytsai/add-template-check-for-sft
PiperOrigin-RevId: 893538846
2 parents d245255 + 5bd32d7

3 files changed: 151 additions & 13 deletions


src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -302,6 +302,7 @@ def preprocessing_pipeline(
   )
   operations = []
   if use_sft:
+    input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer)
     operations.append(
         input_pipeline_utils.SFTPromptMasking(
             text_column_name=data_column_names[0],
```
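This single guard makes a bad chat template fail at pipeline-construction time instead of silently producing mis-masked SFT examples. A minimal sketch of the same check invoked directly — assuming `transformers` is installed; the Qwen3 checkpoint name is borrowed from the tests below:

```python
import transformers

from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
try:
  # The same call preprocessing_pipeline now makes when use_sft=True.
  verify_chat_template_generation_prompt_logic(tokenizer)
  print("chat template is safe for SFT prompt masking")
except ValueError as err:
  # A mismatched template surfaces here, before any data is tokenized.
  print(f"template check failed: {err}")
```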

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 101 additions & 12 deletions
```diff
@@ -19,6 +19,8 @@
 from threading import current_thread
 from typing import Any, Iterable, TYPE_CHECKING

+from jinja2 import TemplateError
+
 if TYPE_CHECKING:
   import datasets
   import tensorflow as tf
```
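`TemplateError` is what `apply_chat_template` propagates when a template calls Jinja's `raise_exception` on a role it does not support; the verifier added below catches it and retries without the system message. A sketch of that failure mode with a deliberately system-hostile template — the template body is hypothetical, and the Qwen3 tokenizer is only a convenient host:

```python
import transformers
from jinja2 import TemplateError

tok = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
# Hypothetical template that rejects the system role, mimicking templates
# that call raise_exception for roles they do not handle.
tok.chat_template = (
    "{% for m in messages %}"
    "{% if m['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}"
    "{{ '<|im_start|>' + m['role'] + '\\n' + m['content'] + '<|im_end|>\\n' }}"
    "{% endfor %}"
)
try:
  tok.apply_chat_template([{"role": "system", "content": "s"}, {"role": "user", "content": "u"}], tokenize=True)
except TemplateError as err:
  print(err)  # System role not supported
```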
```diff
@@ -178,6 +180,103 @@ def is_conversational(features, data_columns):
   return False


+def _extract_token_ids(tokens):
+  """Extracts token IDs from various tokenizer output formats.
+
+  This helper function standardizes the extraction of tokenized integer IDs
+  from common return types of Hugging Face tokenizers, including
+  `BatchEncoding` objects, dictionaries, or simple lists.
+
+  Args:
+    tokens: The object containing token IDs. Supported types include:
+      - A list of integers.
+      - A dictionary containing the `INPUT_TOKENS_KEY`.
+      - An object (e.g., `BatchEncoding`) with an attribute named `INPUT_TOKENS_KEY`.
+
+  Returns:
+    A list of integer token IDs.
+
+  Raises:
+    ValueError: If the input type is not supported or does not contain the expected key.
+  """
+  # attention masks in BatchEncoding are effectively ignored
+  if hasattr(tokens, INPUT_TOKENS_KEY):
+    return getattr(tokens, INPUT_TOKENS_KEY)
+  elif isinstance(tokens, dict) and INPUT_TOKENS_KEY in tokens:
+    return tokens[INPUT_TOKENS_KEY]
+  elif isinstance(tokens, list):
+    return tokens
+  else:
+    raise ValueError(f"Can't extract token_ids from type {type(tokens)}")
```
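A quick illustration of the three supported shapes — a sketch assuming `INPUT_TOKENS_KEY` is importable from the module (the diff uses it as an existing constant naming the token-ID field, e.g. `"input_ids"`):

```python
from types import SimpleNamespace

from maxtext.input_pipeline.input_pipeline_utils import INPUT_TOKENS_KEY, _extract_token_ids

ids = [101, 2023, 102]
assert _extract_token_ids(ids) == ids  # plain list passes through
assert _extract_token_ids({INPUT_TOKENS_KEY: ids}) == ids  # dict output
# Attribute-style access, standing in for a transformers BatchEncoding.
assert _extract_token_ids(SimpleNamespace(**{INPUT_TOKENS_KEY: ids})) == ids
```

The same hunk continues with the verifier built on this helper: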
```diff
+def verify_chat_template_generation_prompt_logic(tokenizer_model):
+  """Verifies the tokenizer's chat template for correct SFT loss masking.
+
+  This function ensures that the tokens added by `add_generation_prompt=True`
+  are identical to the tokens that begin an assistant's turn in a complete
+  conversation, which is critical for masking prompt tokens during SFT loss
+  calculation.
+
+  Example of a mismatch:
+    A `ValueError` is raised if the generation prompt and the actual
+    assistant prefix do not match. For example:
+
+    - `add_generation_prompt=True` on a user message produces a prompt ending in:
+      `...<|im_start|>generation\n`
+    - A full turn with an assistant message starts the reply with:
+      `...<|im_start|>assistant\n...`
+
+    This function would fail because the tokens for "generation" do not
+    match the tokens for "assistant".
+
+  Args:
+    tokenizer_model: The Hugging Face tokenizer instance to verify.
+
+  Raises:
+    ValueError: If the `add_generation_prompt` tokens do not exactly
+      match the beginning of an assistant message in the template.
+  """
+  dummy_msgs = [{"role": "system", "content": "System message"}, {"role": "user", "content": "Test message"}]
+
+  try:
+    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
+  except TemplateError:
+    max_logging.info(
+        "Tokenizer failed to apply chat template with 'system' role. "
+        "Falling back to 'user' role only for chat template verification."
+    )
+    dummy_msgs.pop(0)
+    prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
+  prompt_wo_gen_ids = _extract_token_ids(prompt_wo_gen_tokens)
+
+  prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
+  prompt_w_gen_ids = _extract_token_ids(prompt_w_gen_tokens)
+
+  if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
+    raise ValueError("Unable to extract generation prompt tokens.")
+  # Extract the tokenized generation prompt (the expected assistant prefix)
+  assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
+  full_turn_ids = _extract_token_ids(
+      tokenizer_model.apply_chat_template(
+          dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
+      )
+  )
+  # Extract the actual tokens that appear right after the user message in the full turn
+  actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]
+
+  if actual_prefix_in_full_turn != assistant_prefix:
+    expected_str = tokenizer_model.decode(assistant_prefix)
+    actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
+    raise ValueError(
+        "Chat template generation prompt mismatch!\n"
+        f"Expected assistant prefix tokens: {assistant_prefix} ('{expected_str}')\n"
+        f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
+        "This means the tokenizer's chat template will break the SFT masking logic."
+    )
+
+
 def _get_completion_in_chat_template(tokenizer_model, round_msgs):
   """
   Calculates the completion part of a conversation turn when formatted with a chat template.
```
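Before the next hunk, it is worth seeing what the verifier above actually compares. A string-level sketch of the same check, using `tokenize=False` for readability — the real check compares token IDs, and the ChatML marker in the comments assumes a Qwen-style template:

```python
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
msgs = [{"role": "user", "content": "Test message"}]

without_gen = tokenizer.apply_chat_template(msgs, add_generation_prompt=False, tokenize=False)
with_gen = tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
# Whatever add_generation_prompt=True appended, e.g. "<|im_start|>assistant\n".
assistant_prefix = with_gen[len(without_gen):]

full_turn = tokenizer.apply_chat_template(
    msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=False
)
# For SFT prompt masking to line up, the full conversation must continue the
# user turn with exactly that prefix.
assert full_turn[len(without_gen):].startswith(assistant_prefix)
```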
```diff
@@ -196,18 +295,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
   # include generation_prompt as part of the prompt tokens
   prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)

-  # attention masks in BatchEncoding are effectively ignored
-  if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY):
-    prompt_completion_ids = getattr(prompt_completion_tokens, INPUT_TOKENS_KEY)
-    prompt_ids = getattr(prompt_tokens, INPUT_TOKENS_KEY)
-  elif isinstance(prompt_completion_tokens, dict) and INPUT_TOKENS_KEY in prompt_completion_tokens:
-    prompt_completion_ids = prompt_completion_tokens[INPUT_TOKENS_KEY]
-    prompt_ids = prompt_tokens[INPUT_TOKENS_KEY]
-  elif isinstance(prompt_completion_tokens, list):
-    prompt_completion_ids = prompt_completion_tokens
-    prompt_ids = prompt_tokens
-  else:
-    raise ValueError(f"Can't handle the chat template output of type {type(prompt_completion_tokens)}")
+  prompt_completion_ids = _extract_token_ids(prompt_completion_tokens)
+  prompt_ids = _extract_token_ids(prompt_tokens)

   completion_tokens = prompt_completion_ids[len(prompt_ids) :]
   completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
```
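The refactored `_get_completion_in_chat_template` leans on the same prefix invariant as the new verifier: the templated prompt (generation prompt included) must be a strict prefix of the templated prompt-plus-completion, so the completion is everything after it. A toy illustration with made-up token IDs:

```python
prompt_ids = [1, 5, 9, 30]  # template(user turn) + generation prompt
prompt_completion_ids = [1, 5, 9, 30, 42, 7, 2]  # template(user turn + assistant reply)

completion_tokens = prompt_completion_ids[len(prompt_ids):]
assert completion_tokens == [42, 7, 2]  # exactly the assistant's reply tokens
```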

tests/post_training/unit/sft_data_processing_test.py

Lines changed: 49 additions & 1 deletion
```diff
@@ -22,17 +22,19 @@
 import os.path
 import numpy as np
 import jax
+import re
 from jax.sharding import Mesh
 from jax.experimental import mesh_utils
 from datasets import Dataset
 import transformers
 from parameterized import parameterized_class
-
+from unittest.mock import patch
 from maxtext.configs import pyconfig
 from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT
 from maxtext.input_pipeline import hf_data_processing
 from maxtext.input_pipeline import input_pipeline_interface
 from maxtext.input_pipeline.hf_data_processing import _get_pad_id
+from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic

 PROMPT_DATA = [
   [
```
```diff
@@ -484,5 +486,51 @@ def test_system_message_not_at_beginning(self):
       self.get_data_iterator(dataset, ["messages"])


+@pytest.mark.external_training
+class SFTChatTemplateLogicTest(unittest.TestCase):
+  LLAMA_TOKENIZER_PATH = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer")
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    if not os.path.exists(cls.LLAMA_TOKENIZER_PATH):
+      exit_code = subprocess.call(
+          [
+              "gsutil",
+              "cp",
+              "-r",
+              "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
+              os.path.join(MAXTEXT_ASSETS_ROOT, ""),
+          ]
+      )
+      if exit_code != 0:
+        raise ValueError("Failed to download llama tokenizer")
+
+  def setUp(self):
+    super().setUp()
+    self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
+    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)
+
+  def test_tokenizer_w_generation_prompt(self):
+    verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)
+
+  def test_tokenizer_wo_generation_prompt(self):
+    verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)
+
+  def test_failure_path_with_modified_template(self):
+    """Verifies the function correctly raises a ValueError on a bad template."""
+    # Replace the role within the existing add_generation_prompt block with a deliberately faulty one.
+    fault_chat_template = re.sub(
+        r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
+        r"\1wrong_role\2",
+        self.qwen3_tokenizer.chat_template,
+        flags=re.DOTALL,
+    )
+    with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template):
+      # Verify that our function catches the mismatch and raises the expected error
+      with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"):
+        verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)
+
+
 if __name__ == "__main__":
   unittest.main()
```
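The failure-path test rewrites only the `add_generation_prompt` block of the template, leaving any other occurrence of the assistant role intact. A standalone sketch of that substitution on a hypothetical minimal ChatML-style fragment:

```python
import re

# Hypothetical one-line template fragment, for illustration only.
template = "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
faulty = re.sub(
    r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
    r"\1wrong_role\2",
    template,
    flags=re.DOTALL,
)
print(faulty)
# {% if add_generation_prompt %}{{ '<|im_start|>wrong_role\n' }}{% endif %}
```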
