diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 2570f5c915..73e28c33a4 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -645,6 +645,7 @@ tokenizer_path: "" tokenizer_type: "sentencepiece" # Currently supporting: "tiktoken", "sentencepiece", "huggingface" use_chat_template: false chat_template_path: "" # path to chat template json file +chat_template: "" # Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template. tokenize_train_data: true # false if the dataset is pre-tokenized tokenize_eval_data: true # false if the dataset is pre-tokenized add_bos: true diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 133d66c730..94011ba5c1 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -208,7 +208,7 @@ reasoning_start_token: '' reasoning_end_token: '' solution_start_token: '' solution_end_token: '' -chat_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json' +data_template_path: 'maxtext/examples/chat_templates/gsm8k_rl.json' skip_jax_distributed_system: true # ====== Dataset Configuration ====== diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 8c17bc24dd..6cad0b6b8b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1081,6 +1081,10 @@ class Tokenizer(BaseModel): "", description="Chat template to use with HF tokenizers. 
It should be a valid Jinja2-formatted template.", ) + data_template_path: str = Field( + "", + description="Path to a chat template file to be used when tokenizing the dataset.", + ) tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.") tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.") add_bos: bool = Field(True, description="Whether to add a beginning-of-sentence token.") diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index d4826f73a0..997f021d4e 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -248,10 +248,10 @@ def prepare_datasets( model_tokenizer: AutoTokenizer, ) -> tuple[grain.IterDataset, grain.IterDataset | None]: """Setup and return train and test datasets.""" - template_config = load_data_template_from_file(trainer_config.chat_template_path) + template_config = load_data_template_from_file(trainer_config.data_template_path) if template_config is None: raise ValueError( - f"Chat template is required for processing dataset but failed to load from {trainer_config.chat_template_path}" + f"Chat template is required for processing dataset but failed to load from {trainer_config.data_template_path}" ) # Prepare train and test data from training data for certain datasets @@ -548,10 +548,10 @@ def _reward_fn(**kwargs): epsilon_high=trainer_config.rl.epsilon_high, ) # Instantiate the custom MaxText chat parser - template_config = load_data_template_from_file(trainer_config.chat_template_path) + template_config = load_data_template_from_file(trainer_config.data_template_path) if template_config is None: raise ValueError( - f"Chat template is required for AgenticGRPOLearner but failed to load from {trainer_config.chat_template_path}" + f"Chat template is required for AgenticGRPOLearner but failed to load from 
def configure_tokenizer_chat_template(model_tokenizer: Any, trainer_config: Any) -> None:
  """Populates the tokenizer's chat_template from config when it has none.

  A tokenizer that already carries a non-empty chat_template is left untouched.
  Otherwise the template is resolved in this order:
    1. trainer_config.chat_template, used as the template string itself.
    2. trainer_config.chat_template_path, loaded via load_chat_template_from_file.

  Args:
    model_tokenizer: Object exposing a writable `chat_template` attribute.
    trainer_config: Config object. `chat_template`, `chat_template_path` and
      `tokenizer_path` are all read with getattr, so any of them may be absent.

  Raises:
    ValueError: If the tokenizer has no template and both config options are
      empty, or if loading from chat_template_path yields no template.
  """
  # Truthiness rather than `is None`: an empty-string template is as unusable as
  # a missing one, so it must not block the config fallback.
  if getattr(model_tokenizer, "chat_template", None):
    return

  inline_template = getattr(trainer_config, "chat_template", None)
  if inline_template:
    model_tokenizer.chat_template = inline_template
    return

  template_path = getattr(trainer_config, "chat_template_path", None)
  if template_path:
    # Imported lazily, as in the original, so the input pipeline module is only
    # pulled in when a template file actually has to be read.
    from maxtext.input_pipeline.instruction_data_processing import (  # pylint: disable=import-outside-toplevel
        load_chat_template_from_file,
    )

    loaded_template = load_chat_template_from_file(template_path)
    # NOTE(review): assumes the loader can return None/empty on failure (the
    # sibling load_data_template_from_file is None-checked elsewhere in this
    # file) -- confirm. Failing here beats a late crash in apply_chat_template.
    if not loaded_template:
      raise ValueError(f"Chat template is required for the tokenizer but failed to load from {template_path}")
    model_tokenizer.chat_template = loaded_template
    return

  raise ValueError(
      f"Tokenizer {getattr(trainer_config, 'tokenizer_path', None)!r} has no chat_template "
      "and config.chat_template / config.chat_template_path "
      "are both empty. Either pick an instruction-tuned tokenizer that "
      "ships with a chat_template, set config.chat_template to a Jinja "
      "string, or set config.chat_template_path to a JSON file "
      "with a 'chat_template' key."
  )
@@ -616,6 +638,7 @@ def _rl_train_impl(argv: Sequence[str], kwargs: dict): trainer_config.tokenizer_path, token=trainer_config.hf_access_token or None, ) + configure_tokenizer_chat_template(model_tokenizer, trainer_config) reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes( trainer_config, diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index ede2029dea..56ba2a7dcf 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self): # Following the pattern in distillation_checkpointing_test.py for mocking jax objects with ( mock.patch.object(jax, "devices", return_value=mock_devices), - mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config), + mock.patch( + "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", + return_value=mock_config, + ), ): trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices( ["dummy", "dummy"] @@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self): with ( mock.patch.object(jax, "devices", return_value=mock_devices), - mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config), + mock.patch( + "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", + return_value=mock_config, + ), ): _, _, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(["dummy", "dummy"]) @@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self): "tensor_parallel_size": 2, "expert_parallel_size": 4, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 16), + expected_result, + ) 
@pytest.mark.cpu_only def test_get_rollout_kwargs_auto_tp(self): @@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self): "tensor_parallel_size": 2, "expert_parallel_size": 1, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), + expected_result, + ) @pytest.mark.cpu_only def test_get_rollout_kwargs_fixed_tp_dp(self): @@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self): "tensor_parallel_size": 2, "expert_parallel_size": 1, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 4), + expected_result, + ) @pytest.mark.cpu_only def test_get_rollout_kwargs_auto_ep(self): @@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self): "tensor_parallel_size": 2, "expert_parallel_size": 2, } - self.assertEqual(train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), expected_result) + self.assertEqual( + train_rl.get_rollout_kwargs_for_parallelism(sampler_config, 8), + expected_result, + ) @pytest.mark.cpu_only def test_get_rollout_kwargs_errors(self): @@ -307,7 +325,10 @@ def tokenize_side_effect(text): {"question": "short", "answer": "a3"}, {"question": "long", "answer": "a4"}, ] - test_data = [{"question": "short", "answer": "a5"}, {"question": "long", "answer": "a6"}] + test_data = [ + {"question": "short", "answer": "a5"}, + {"question": "long", "answer": "a6"}, + ] train_map_ds = grain.MapDataset.source(train_data) test_map_ds = grain.MapDataset.source(test_data) @@ -334,7 +355,7 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config eval_split="eval", hf_train_files=None, hf_eval_files=None, - chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", 
data_shuffle_seed=42, max_prefill_predict_length=10, batch_size=2, @@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config ) with ( - mock.patch("maxtext.trainers.post_train.rl.train_rl.get_dataset", side_effect=get_dataset_side_effect), - mock.patch("maxtext.trainers.post_train.rl.utils_rl.process_data", side_effect=get_filtered_data_side_effect), + mock.patch( + "maxtext.trainers.post_train.rl.train_rl.get_dataset", + side_effect=get_dataset_side_effect, + ), + mock.patch( + "maxtext.trainers.post_train.rl.utils_rl.process_data", + side_effect=get_filtered_data_side_effect, + ), ): train_dataset, test_dataset = train_rl.prepare_datasets(trainer_config, mock_tokenizer) @@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config def test_prepare_datasets_with_split(self, mock_load): mock_ds = mock.MagicMock() mock_split_result = { - "train": [{"question": "q1", "answer": "a1"}, {"question": "q2", "answer": "a2"}], + "train": [ + {"question": "q1", "answer": "a1"}, + {"question": "q2", "answer": "a2"}, + ], "test": [{"question": "q3", "answer": "a3"}], } mock_ds.train_test_split.return_value = mock_split_result @@ -389,7 +419,7 @@ def test_prepare_datasets_with_split(self, mock_load): eval_dataset_name="open-r1/OpenR1-Math-220k", train_split="train", hf_train_files="hf://open-r1/OpenR1-Math-220k/data/dummy.parquet", - chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", data_shuffle_seed=42, num_batches=1, batch_size=5, @@ -435,7 +465,7 @@ def test_prepare_datasets_without_split(self, mock_load): eval_split="test", hf_train_files="hf://openai/gsm8k/data/dummy.parquet", hf_eval_files="hf://openai/gsm8k/data/dummy.parquet", - chat_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", + data_template_path="maxtext/examples/chat_templates/gsm8k_rl.json", data_shuffle_seed=42, num_batches=1, 
class TokenizerChatTemplateTest(unittest.TestCase):
  """Exercises train_rl.configure_tokenizer_chat_template with stub tokenizers and configs."""

  @staticmethod
  def _config(template=None, template_path=None, tokenizer_path="dummy-base-model"):
    """Builds a minimal stand-in for the trainer config with the three attributes the tests set."""
    return SimpleNamespace(
        chat_template=template,
        chat_template_path=template_path,
        tokenizer_path=tokenizer_path,
    )

  @staticmethod
  def _tokenizer(template=None):
    """Returns a MagicMock tokenizer whose chat_template is preset to `template`."""
    tokenizer = mock.MagicMock()
    tokenizer.chat_template = template
    return tokenizer

  @pytest.mark.cpu_only
  def test_chat_template_populated_from_config_string(self):
    """The inline config.chat_template string lands on a tokenizer that has no template."""
    tokenizer = self._tokenizer()
    config = self._config(template="{{ messages[0].content }}")

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, "{{ messages[0].content }}")

  @pytest.mark.cpu_only
  def test_chat_template_populated_from_config_file(self):
    """With only chat_template_path set, the template comes from load_chat_template_from_file."""
    jinja_source = "{% for message in messages %}{{ message.content }}{% endfor %}"
    tokenizer = self._tokenizer()
    config = self._config(template_path="/path/to/jinja_template.json")

    # Patch the loader on its home module; the function under test imports it at call time.
    with mock.patch(
        "maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file",
        return_value=jinja_source,
    ) as mock_load:
      train_rl.configure_tokenizer_chat_template(tokenizer, config)

    mock_load.assert_called_once_with("/path/to/jinja_template.json")
    self.assertEqual(tokenizer.chat_template, jinja_source)

  @pytest.mark.cpu_only
  def test_chat_template_raises_value_error_when_empty(self):
    """No tokenizer template and no config template or path results in a ValueError."""
    tokenizer = self._tokenizer()
    config = self._config()

    with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"):
      train_rl.configure_tokenizer_chat_template(tokenizer, config)

  @pytest.mark.cpu_only
  def test_chat_template_unchanged_when_already_exists(self):
    """A template already present on the tokenizer wins over config.chat_template."""
    tokenizer = self._tokenizer(template="{{ existing_template }}")
    config = self._config(
        template="{{ overridden_template }}",
        tokenizer_path="dummy-instruction-tuned-model",
    )

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, "{{ existing_template }}")

  @pytest.mark.cpu_only
  def test_apply_chat_template_works_after_configuration(self):
    """apply_chat_template fails before configuration and renders the message content afterwards."""

    class DummyTokenizer:
      """Tiny tokenizer stand-in that renders its chat_template with jinja2."""

      def __init__(self):
        self.chat_template = None

      def apply_chat_template(self, conversation, tokenize=False):
        if self.chat_template is None:
          raise ValueError("Cannot apply chat template because chat_template is None")
        import jinja2  # pylint: disable=import-outside-toplevel

        return jinja2.Environment().from_string(self.chat_template).render(messages=conversation)

    conversation = [{"role": "user", "content": "Hello!"}]
    tokenizer = DummyTokenizer()
    config = self._config(template="{{ messages[0].content }}")

    # Before configuration the stub refuses to render, mirroring a tokenizer with no template.
    with self.assertRaises(ValueError):
      tokenizer.apply_chat_template(conversation)

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    # After configuration the same call renders the first message's content.
    self.assertEqual(tokenizer.apply_chat_template(conversation), "Hello!")