|
22 | 22 | import os.path |
23 | 23 | import numpy as np |
24 | 24 | import jax |
25 | | -import re |
26 | 25 | from jax.sharding import Mesh |
27 | 26 | from jax.experimental import mesh_utils |
28 | 27 | from datasets import Dataset |
29 | 28 | import transformers |
30 | 29 | from parameterized import parameterized_class |
31 | | -from unittest.mock import patch |
32 | 30 | from maxtext.configs import pyconfig |
33 | 31 | from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT |
34 | 32 | from maxtext.input_pipeline import hf_data_processing |
35 | 33 | from maxtext.input_pipeline import input_pipeline_interface |
36 | 34 | from maxtext.input_pipeline.hf_data_processing import _get_pad_id |
37 | | -from maxtext.input_pipeline.input_pipeline_utils import verify_chat_template_generation_prompt_logic |
| 35 | +from maxtext.input_pipeline.input_pipeline_utils import apply_chat_template, SFTPromptMasking, tokenization |
38 | 36 |
|
39 | 37 | PROMPT_DATA = [ |
40 | 38 | [ |
@@ -512,26 +510,118 @@ def setUp(self): |
512 | 510 | super().setUp() |
513 | 511 | self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") |
514 | 512 | self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH) |
| 513 | + self.gemma4_tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-4-26B-A4B-it") |
| 514 | + |
| 515 | + def _apply_chat_template(self, tokenizer): |
| 516 | + """Helper function to apply the chat template to a sample input and return the result for testing.""" |
| 517 | + messages = [ |
| 518 | + {"role": "user", "content": "Q1"}, |
| 519 | + {"role": "assistant", "content": "A1"}, |
| 520 | + {"role": "user", "content": "Q2"}, |
| 521 | + {"role": "assistant", "content": "A2"}, |
| 522 | + ] |
| 523 | + example = {"messages": messages} |
| 524 | + return apply_chat_template(example, tokenizer, "messages") |
| 525 | + |
| 526 | + def test_apply_chat_template_with_qwen3_tokenizer(self): |
| 527 | + """Verifies that apply_chat_template correctly applies Qwen3's chat template.""" |
| 528 | + result = self._apply_chat_template(self.qwen3_tokenizer) |
| 529 | + self.assertEqual(result["is_prompt"], [True, False, True, False]) |
| 530 | + self.assertEqual(len(result["messages"]), 4) |
| 531 | + self.assertIn("<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\n", result["messages"][0]) |
| 532 | + self.assertIn("<think>\n\n</think>\n\nA1<|im_end|>\n", result["messages"][1]) |
| 533 | + self.assertIn("<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n", result["messages"][2]) |
| 534 | + self.assertIn("<think>\n\n</think>\n\nA2<|im_end|>\n", result["messages"][3]) |
| 535 | + |
| 536 | + def test_apply_chat_template_with_llama2_tokenizer(self): |
| 537 | + """Verifies that apply_chat_template correctly applies Llama2's chat template.""" |
| 538 | + result = self._apply_chat_template(self.llama2_tokenizer) |
| 539 | + self.assertEqual(result["is_prompt"], [True, False, True, False]) |
| 540 | + self.assertEqual(len(result["messages"]), 4) |
| 541 | + self.assertIn("<s>[INST] Q1 [/INST]", result["messages"][0]) |
| 542 | + self.assertIn("A1 </s>", result["messages"][1]) |
| 543 | + self.assertIn("<s>[INST] Q2 [/INST]", result["messages"][2]) |
| 544 | + self.assertIn("A2 </s>", result["messages"][3]) |
| 545 | + |
| 546 | + def test_apply_chat_template_with_gemma4_tokenizer(self): |
| 547 | + """Verifies that apply_chat_template correctly applies Gemma4's chat template.""" |
| 548 | + result = self._apply_chat_template(self.gemma4_tokenizer) |
| 549 | + self.assertEqual(result["is_prompt"], [True, False, True, False]) |
| 550 | + self.assertEqual(len(result["messages"]), 4) |
| 551 | + self.assertIn("<|turn>user\nQ1<turn|>\n<|turn>model\n<|channel>thought\n<channel|>", result["messages"][0]) |
| 552 | + self.assertIn("A1<turn|>\n", result["messages"][1]) |
| 553 | + self.assertIn("<|turn>user\nQ2<turn|>\n<|turn>model\n<|channel>thought\n<channel|>", result["messages"][2]) |
| 554 | + self.assertIn("A2<turn|>\n", result["messages"][3]) |
515 | 555 |
|
516 | | - def test_tokenizer_w_generation_prompt(self): |
517 | | - verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer) |
518 | 556 |
|
519 | | - def test_tokenizer_wo_generation_prompt(self): |
520 | | - verify_chat_template_generation_prompt_logic(self.llama2_tokenizer) |
| 557 | +@pytest.mark.external_training |
| 558 | +class SFTPromptMaskingTest(unittest.TestCase): |
521 | 559 |
|
522 | | - def test_failure_path_with_modified_template(self): |
523 | | - """Verifies the function correctly raises a ValueError on a bad template.""" |
524 | | - # Replace the role within the existing add_generation_prompt block with a deliberately faulty one. |
525 | | - fault_chat_template = re.sub( |
526 | | - r"(\{%-?\s*if add_generation_prompt\s*%\}.*?<\|im_start\|>)assistant(.*?\{%-?\s*endif\s*%\})", |
527 | | - r"\1wrong_role\2", |
528 | | - self.qwen3_tokenizer.chat_template, |
529 | | - flags=re.DOTALL, |
| 560 | + def setUp(self): |
| 561 | + super().setUp() |
| 562 | + self.max_target_length = 50 |
| 563 | + self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") |
| 564 | + self.gemma4_tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-4-26B-A4B-it") |
| 565 | + |
| 566 | + def _apply_prompt_masking(self, tokenizer, unk_id, completion_only=True): |
| 567 | + """Helper function to apply the prompt masking to a sample input and return the result for testing.""" |
| 568 | + messages = [ |
| 569 | + {"role": "user", "content": "Q1"}, |
| 570 | + {"role": "assistant", "content": "A1"}, |
| 571 | + {"role": "user", "content": "Q2"}, |
| 572 | + {"role": "assistant", "content": "A2"}, |
| 573 | + ] |
| 574 | + example = {"messages": messages} |
| 575 | + modified_example = apply_chat_template(example, tokenizer, "messages") |
| 576 | + tokenized_example = tokenization(modified_example, tokenizer, False, self.max_target_length, ["messages"]) |
| 577 | + op = SFTPromptMasking( |
| 578 | + text_column_name="messages", |
| 579 | + completion_only=completion_only, |
| 580 | + max_target_length=self.max_target_length, |
| 581 | + unk_id=unk_id, |
530 | 582 | ) |
531 | | - with patch.object(self.qwen3_tokenizer, "chat_template", fault_chat_template): |
532 | | - # Verify that our function catches the mismatch and raises the expected error |
533 | | - with self.assertRaisesRegex(ValueError, "Chat template generation prompt mismatch!"): |
534 | | - verify_chat_template_generation_prompt_logic(self.qwen3_tokenizer) |
| 583 | + return op.map({"messages": tokenized_example["messages"], "is_prompt": modified_example["is_prompt"]}) |
| 584 | + |
| 585 | + def _verify_prompt_masking(self, tokenizer, inputs, targets, unk_id): |
| 586 | + """Helper function to verify that the prompt masking was applied correctly.""" |
| 587 | + # Unmasked positions must match inputs exactly |
| 588 | + np.testing.assert_array_equal(inputs[targets != unk_id], targets[targets != unk_id]) |
| 589 | + |
| 590 | + # Some tokens must be masked |
| 591 | + self.assertTrue(np.any(targets == unk_id)) |
| 592 | + |
| 593 | + # Decoding unmasked tokens yields completions, not prompts |
| 594 | + completion = tokenizer.decode(targets[targets != unk_id], skip_special_tokens=False) |
| 595 | + self.assertIn("A1", completion) |
| 596 | + self.assertIn("A2", completion) |
| 597 | + self.assertNotIn("Q1", completion) |
| 598 | + self.assertNotIn("Q2", completion) |
| 599 | + |
| 600 | + def test_sft_prompt_masking_with_qwen3_tokenizer(self): |
| 601 | + """Verifies that SFTPromptMasking correctly applies masking for Qwen3's chat template.""" |
| 602 | + unk_id = _get_pad_id(self.qwen3_tokenizer) |
| 603 | + result = self._apply_prompt_masking(self.qwen3_tokenizer, unk_id) |
| 604 | + inputs, targets = result["inputs"], result["targets"] |
| 605 | + self._verify_prompt_masking(self.qwen3_tokenizer, inputs, targets, unk_id) |
| 606 | + |
| 607 | + def test_sft_prompt_masking_with_gemma4_tokenizer(self): |
| 608 | + """Verifies that SFTPromptMasking correctly applies masking for Gemma4's chat template.""" |
| 609 | + unk_id = _get_pad_id(self.gemma4_tokenizer) |
| 610 | + result = self._apply_prompt_masking(self.gemma4_tokenizer, unk_id) |
| 611 | + inputs, targets = result["inputs"], result["targets"] |
| 612 | + self._verify_prompt_masking(self.gemma4_tokenizer, inputs, targets, unk_id) |
| 613 | + |
| 614 | + def test_sft_no_prompt_masking_with_qwen3_tokenizer(self): |
| 615 | + """Verifies that prompt masking is not applied when completion_only=False with Qwen3 tokenizer.""" |
| 616 | + unk_id = _get_pad_id(self.qwen3_tokenizer) |
| 617 | + result = self._apply_prompt_masking(self.qwen3_tokenizer, unk_id, completion_only=False) |
| 618 | + np.testing.assert_array_equal(result["inputs"], result["targets"]) |
| 619 | + |
| 620 | + def test_sft_no_prompt_masking_with_gemma4_tokenizer(self): |
| 621 | + """Verifies that prompt masking is not applied when completion_only=False with Gemma4 tokenizer.""" |
| 622 | + unk_id = _get_pad_id(self.gemma4_tokenizer) |
| 623 | + result = self._apply_prompt_masking(self.gemma4_tokenizer, unk_id, completion_only=False) |
| 624 | + np.testing.assert_array_equal(result["inputs"], result["targets"]) |
535 | 625 |
|
536 | 626 |
|
537 | 627 | if __name__ == "__main__": |
|
0 commit comments