Skip to content

Commit af334f1

Browse files
Merge pull request #4010 from AI-Hypercomputer:gemma4_sft
PiperOrigin-RevId: 923592408
2 parents 30b0d2d + dd610f2 commit af334f1

3 files changed

Lines changed: 139 additions & 96 deletions

File tree

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -321,7 +321,6 @@ def preprocessing_pipeline(
321321
)
322322
operations = []
323323
if use_sft:
324-
input_pipeline_utils.verify_chat_template_generation_prompt_logic(tokenizer)
325324
operations.append(
326325
input_pipeline_utils.SFTPromptMasking(
327326
text_column_name=data_column_names[0],

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 30 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,6 @@
1919
from threading import current_thread
2020
from typing import Any, Iterable, TYPE_CHECKING
2121

22-
from jinja2 import TemplateError
23-
2422
if TYPE_CHECKING:
2523
import datasets
2624
import tensorflow as tf
@@ -210,83 +208,24 @@ def extract_token_ids(tokens):
210208
raise ValueError(f"Can't extract token_ids from type {type(tokens)}")
211209

212210

213-
def verify_chat_template_generation_prompt_logic(tokenizer_model):
214-
"""Verifies the tokenizer's chat template for correct SFT loss masking.
215-
216-
This function ensures that the tokens added by `add_generation_prompt=True`
217-
are identical to the tokens that begin an assistant's turn in a complete
218-
conversation, which is critical for masking prompt tokens during SFT loss
219-
calculation.
220-
221-
Example of a mismatch:
222-
A `ValueError` is raised if the generation prompt and the actual
223-
assistant prefix do not match. For example:
224-
225-
- `add_generation_prompt=True` on a user message produces a prompt ending in:
226-
`...<|im_start|>generation\n`
227-
- A full turn with an assistant message starts the reply with:
228-
`...<|im_start|>assistant\n...`
229-
230-
This function would fail because the tokens for "generation" do not
231-
match the tokens for "assistant".
232-
233-
Args:
234-
tokenizer_model: The Hugging Face tokenizer instance to verify.
235-
236-
Raises:
237-
ValueError: If the `add_generation_prompt` tokens do not exactly
238-
match the beginning of an assistant message in the template.
239-
"""
240-
dummy_msgs = [{"role": "system", "content": "System message"}, {"role": "user", "content": "Test message"}]
241-
242-
try:
243-
prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
244-
except TemplateError:
245-
max_logging.info(
246-
"Tokenizer failed to apply chat template with 'system' role. "
247-
"Falling back to 'user' role only for chat template verification."
248-
)
249-
dummy_msgs.pop(0)
250-
prompt_wo_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=False, tokenize=True)
251-
prompt_wo_gen_ids = extract_token_ids(prompt_wo_gen_tokens)
252-
253-
prompt_w_gen_tokens = tokenizer_model.apply_chat_template(dummy_msgs, add_generation_prompt=True, tokenize=True)
254-
prompt_w_gen_ids = extract_token_ids(prompt_w_gen_tokens)
255-
256-
if prompt_w_gen_ids[: len(prompt_wo_gen_ids)] != prompt_wo_gen_ids:
257-
raise ValueError("Unable to extract generation prompt tokens.")
258-
# Extract the tokenized generation prompt (the expected assistant prefix)
259-
assistant_prefix = prompt_w_gen_ids[len(prompt_wo_gen_ids) :]
260-
full_turn_tokens = extract_token_ids(
261-
tokenizer_model.apply_chat_template(
262-
dummy_msgs + [{"role": "assistant", "content": "Dummy response"}], add_generation_prompt=False, tokenize=True
263-
)
264-
)
265-
full_turn_ids = extract_token_ids(full_turn_tokens)
266-
# Extract the actual tokens that appear right after the user message in the full turn
267-
actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]
268-
269-
if actual_prefix_in_full_turn != assistant_prefix:
270-
expected_str = tokenizer_model.decode(assistant_prefix)
271-
actual_str = tokenizer_model.decode(actual_prefix_in_full_turn)
272-
raise ValueError(
273-
"Chat template generation prompt mismatch!\n"
274-
f"Expected assistant prefix tokens: {assistant_prefix} ('{expected_str}')\n"
275-
f"Actual prefix tokens found: {actual_prefix_in_full_turn} ('{actual_str}')\n"
276-
"This means the tokenizer's chat template will break the sft masking logic."
277-
)
211+
def _get_completion_in_chat_template(tokenizer_model, round_msgs):
212+
"""Calculates the completion part of a conversation turn formatted with a chat template.
278213
214+
Uses the longest-common-prefix between the full conversation tokens and the
215+
generation-prompt tokens to locate where the completion starts.
279216
280-
def _get_completion_in_chat_template(tokenizer_model, round_msgs):
281-
"""
282-
Calculates the completion part of a conversation turn when formatted with a chat template.
217+
For most models (Llama, Qwen, …) the generation prompt is an exact prefix of the
218+
full conversation, so common_len == len(prompt_ids).
283219
284-
This function handles both older and current Hugging Face tokenizers. Modern tokenizers
285-
may return a `BatchEncoding` object instead of a simple list of token IDs.
220+
For Gemma4, add_generation_prompt=True emits thinking-channel tokens
221+
(<|channel>thought\\n<channel|>) that diverge from the plain conversation
222+
at the model-turn boundary. The common prefix ends just before that
223+
divergence, and the completion correctly captures the thinking content
224+
and response tokens.
286225
287226
Args:
288227
tokenizer_model: The tokenizer instance.
289-
round_msgs: A list of messages for the current conversational turn, including the assistant's response.
228+
round_msgs: Messages for the current conversational turn including the assistant response.
290229
291230
Returns:
292231
A string representing the completion formatted by the chat template.
@@ -298,9 +237,24 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
298237
prompt_completion_ids = extract_token_ids(prompt_completion_tokens)
299238
prompt_ids = extract_token_ids(prompt_tokens)
300239

301-
completion_tokens = prompt_completion_ids[len(prompt_ids) :]
302-
completion_in_chat_template = tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
303-
return completion_in_chat_template
240+
# Walk forward until the two sequences diverge
241+
common_len = 0
242+
for full_id, prompt_id in zip(prompt_completion_ids, prompt_ids):
243+
if full_id == prompt_id:
244+
common_len += 1
245+
else:
246+
break
247+
248+
if common_len == 0:
249+
raise ValueError(
250+
"Chat template generation prompt mismatch: no common prefix tokens found.\n"
251+
f"Full conversation tokens: {prompt_completion_ids} ('{tokenizer_model.decode(prompt_completion_ids)}')\n"
252+
f"Generation prompt tokens: {prompt_ids} ('{tokenizer_model.decode(prompt_ids)}')\n"
253+
"Cannot determine completion boundary."
254+
)
255+
256+
completion_tokens = prompt_completion_ids[common_len:]
257+
return tokenizer_model.decode(completion_tokens, skip_special_tokens=False)
304258

305259

306260
def apply_chat_template(example, tokenizer_model, data_column_name):

tests/post_training/unit/sft_data_processing_test.py

Lines changed: 109 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -22,19 +22,17 @@
2222
import os.path
2323
import numpy as np
2424
import jax
25-
import re
2625
from jax.sharding import Mesh
2726
from jax.experimental import mesh_utils
2827
from datasets import Dataset
2928
import transformers
3029
from parameterized import parameterized_class
31-
from unittest.mock import patch
3230
from maxtext.configs import pyconfig
3331
from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT
3432
from maxtext.input_pipeline import hf_data_processing
3533
from maxtext.input_pipeline import input_pipeline_interface
3634
from maxtext.input_pipeline.hf_data_processing import _get_pad_id
37-
from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic
35+
from maxtext.input_pipeline.input_pipeline_utils import apply_chat_template, SFTPromptMasking, tokenization
3836

3937
PROMPT_DATA = [
4038
[
@@ -512,26 +510,118 @@ def setUp(self):
512510
super().setUp()
513511
self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
514512
self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)
513+
self.gemma4_tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-4-26B-A4B-it")
514+
515+
def _apply_chat_template(self, tokenizer):
516+
"""Helper function to apply the chat template to a sample input and return the result for testing."""
517+
messages = [
518+
{"role": "user", "content": "Q1"},
519+
{"role": "assistant", "content": "A1"},
520+
{"role": "user", "content": "Q2"},
521+
{"role": "assistant", "content": "A2"},
522+
]
523+
example = {"messages": messages}
524+
return apply_chat_template(example, tokenizer, "messages")
525+
526+
def test_apply_chat_template_with_qwen3_tokenizer(self):
527+
"""Verifies that apply_chat_template correctly applies Qwen3's chat template."""
528+
result = self._apply_chat_template(self.qwen3_tokenizer)
529+
self.assertEqual(result["is_prompt"], [True, False, True, False])
530+
self.assertEqual(len(result["messages"]), 4)
531+
self.assertIn("<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n", result["messages"][0])
532+
self.assertIn("<think>\n\n</think>\n\nA1<|im_end|>\n", result["messages"][1])
533+
self.assertIn("<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n", result["messages"][2])
534+
self.assertIn("<think>\n\n</think>\n\nA2<|im_end|>\n", result["messages"][3])
535+
536+
def test_apply_chat_template_with_llama2_tokenizer(self):
537+
"""Verifies that apply_chat_template correctly applies Llama2's chat template."""
538+
result = self._apply_chat_template(self.llama2_tokenizer)
539+
self.assertEqual(result["is_prompt"], [True, False, True, False])
540+
self.assertEqual(len(result["messages"]), 4)
541+
self.assertIn("<s>[INST] Q1 [/INST]", result["messages"][0])
542+
self.assertIn("A1 </s>", result["messages"][1])
543+
self.assertIn("<s>[INST] Q2 [/INST]", result["messages"][2])
544+
self.assertIn("A2 </s>", result["messages"][3])
545+
546+
def test_apply_chat_template_with_gemma4_tokenizer(self):
547+
"""Verifies that apply_chat_template correctly applies Gemma4's chat template."""
548+
result = self._apply_chat_template(self.gemma4_tokenizer)
549+
self.assertEqual(result["is_prompt"], [True, False, True, False])
550+
self.assertEqual(len(result["messages"]), 4)
551+
self.assertIn("<|turn>user\nQ1<turn|>\n<|turn>model\n<|channel>thought\n<channel|>", result["messages"][0])
552+
self.assertIn("A1<turn|>\n", result["messages"][1])
553+
self.assertIn("<|turn>user\nQ2<turn|>\n<|turn>model\n<|channel>thought\n<channel|>", result["messages"][2])
554+
self.assertIn("A2<turn|>\n", result["messages"][3])
515555

516-
def test_tokenizer_w_generation_prompt(self):
517-
verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)
518556

519-
def test_tokenizer_wo_generation_prompt(self):
520-
verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)
557+
@pytest.mark.external_training
558+
class SFTPromptMaskingTest(unittest.TestCase):
521559

522-
def test_failure_path_with_modified_template(self):
523-
"""Verifies the function correctly raises a ValueError on a bad template."""
524-
# Replace the role within the existing add_generation_prompt block with a deliberately faulty one.
525-
fault_chat_template = re.sub(
526-
r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})",
527-
r"\1wrong_role\2",
528-
self.qwen3_tokenizer.chat_template,
529-
flags=re.DOTALL,
560+
def setUp(self):
561+
super().setUp()
562+
self.max_target_length = 50
563+
self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
564+
self.gemma4_tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-4-26B-A4B-it")
565+
566+
def _apply_prompt_masking(self, tokenizer, unk_id, completion_only=True):
567+
"""Helper function to apply the prompt masking to a sample input and return the result for testing."""
568+
messages = [
569+
{"role": "user", "content": "Q1"},
570+
{"role": "assistant", "content": "A1"},
571+
{"role": "user", "content": "Q2"},
572+
{"role": "assistant", "content": "A2"},
573+
]
574+
example = {"messages": messages}
575+
modified_example = apply_chat_template(example, tokenizer, "messages")
576+
tokenized_example = tokenization(modified_example, tokenizer, False, self.max_target_length, ["messages"])
577+
op = SFTPromptMasking(
578+
text_column_name="messages",
579+
completion_only=completion_only,
580+
max_target_length=self.max_target_length,
581+
unk_id=unk_id,
530582
)
531-
with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template):
532-
# Verify that our function catches the mismatch and raises the expected error
533-
with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"):
534-
verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer)
583+
return op.map({"messages": tokenized_example["messages"], "is_prompt": modified_example["is_prompt"]})
584+
585+
def _verify_prompt_masking(self, tokenizer, inputs, targets, unk_id):
586+
"""Helper function to verify that the prompt masking was applied correctly."""
587+
# Unmasked positions must match inputs exactly
588+
np.testing.assert_array_equal(inputs[targets != unk_id], targets[targets != unk_id])
589+
590+
# Some tokens must be masked
591+
self.assertTrue(np.any(targets == unk_id))
592+
593+
# Decoding unmasked tokens yields completions, not prompts
594+
completion = tokenizer.decode(targets[targets != unk_id], skip_special_tokens=False)
595+
self.assertIn("A1", completion)
596+
self.assertIn("A2", completion)
597+
self.assertNotIn("Q1", completion)
598+
self.assertNotIn("Q2", completion)
599+
600+
def test_sft_prompt_masking_with_qwen3_tokenizer(self):
601+
"""Verifies that SFTPromptMasking correctly applies masking for Qwen3's chat template."""
602+
unk_id = _get_pad_id(self.qwen3_tokenizer)
603+
result = self._apply_prompt_masking(self.qwen3_tokenizer, unk_id)
604+
inputs, targets = result["inputs"], result["targets"]
605+
self._verify_prompt_masking(self.qwen3_tokenizer, inputs, targets, unk_id)
606+
607+
def test_sft_prompt_masking_with_gemma4_tokenizer(self):
608+
"""Verifies that SFTPromptMasking correctly applies masking for Gemma4's chat template."""
609+
unk_id = _get_pad_id(self.gemma4_tokenizer)
610+
result = self._apply_prompt_masking(self.gemma4_tokenizer, unk_id)
611+
inputs, targets = result["inputs"], result["targets"]
612+
self._verify_prompt_masking(self.gemma4_tokenizer, inputs, targets, unk_id)
613+
614+
def test_sft_no_prompt_masking_with_qwen3_tokenizer(self):
615+
"""Verifies that prompt masking is not applied when completion_only=False with Qwen3 tokenizer."""
616+
unk_id = _get_pad_id(self.qwen3_tokenizer)
617+
result = self._apply_prompt_masking(self.qwen3_tokenizer, unk_id, completion_only=False)
618+
np.testing.assert_array_equal(result["inputs"], result["targets"])
619+
620+
def test_sft_no_prompt_masking_with_gemma4_tokenizer(self):
621+
"""Verifies that prompt masking is not applied when completion_only=False with Gemma4 tokenizer."""
622+
unk_id = _get_pad_id(self.gemma4_tokenizer)
623+
result = self._apply_prompt_masking(self.gemma4_tokenizer, unk_id, completion_only=False)
624+
np.testing.assert_array_equal(result["inputs"], result["targets"])
535625

536626

537627
if __name__ == "__main__":

0 commit comments

Comments
 (0)