@@ -58,7 +58,10 @@ def test_setup_configs_and_devices_pathways_split(self):
5858 # Following the pattern in distillation_checkpointing_test.py for mocking jax objects
5959 with (
6060 mock .patch .object (jax , "devices" , return_value = mock_devices ),
61- mock .patch ("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" , return_value = mock_config ),
61+ mock .patch (
62+ "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" ,
63+ return_value = mock_config ,
64+ ),
6265 ):
6366 trainer_config , sampler_config , trainer_devices , sampler_devices = model_creation_utils .setup_configs_and_devices (
6467 ["dummy" , "dummy" ]
@@ -87,7 +90,10 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):
8790
8891 with (
8992 mock .patch .object (jax , "devices" , return_value = mock_devices ),
90- mock .patch ("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" , return_value = mock_config ),
93+ mock .patch (
94+ "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic" ,
95+ return_value = mock_config ,
96+ ),
9197 ):
9298 _ , _ , trainer_devices , sampler_devices = model_creation_utils .setup_configs_and_devices (["dummy" , "dummy" ])
9399
@@ -189,7 +195,10 @@ def test_get_rollout_kwargs_no_dp(self):
189195 "tensor_parallel_size" : 2 ,
190196 "expert_parallel_size" : 4 ,
191197 }
192- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 16 ), expected_result )
198+ self .assertEqual (
199+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 16 ),
200+ expected_result ,
201+ )
193202
194203 @pytest .mark .cpu_only
195204 def test_get_rollout_kwargs_auto_tp (self ):
@@ -204,7 +213,10 @@ def test_get_rollout_kwargs_auto_tp(self):
204213 "tensor_parallel_size" : 2 ,
205214 "expert_parallel_size" : 1 ,
206215 }
207- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ), expected_result )
216+ self .assertEqual (
217+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ),
218+ expected_result ,
219+ )
208220
209221 @pytest .mark .cpu_only
210222 def test_get_rollout_kwargs_fixed_tp_dp (self ):
@@ -219,7 +231,10 @@ def test_get_rollout_kwargs_fixed_tp_dp(self):
219231 "tensor_parallel_size" : 2 ,
220232 "expert_parallel_size" : 1 ,
221233 }
222- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ), expected_result )
234+ self .assertEqual (
235+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 4 ),
236+ expected_result ,
237+ )
223238
224239 @pytest .mark .cpu_only
225240 def test_get_rollout_kwargs_auto_ep (self ):
@@ -235,7 +250,10 @@ def test_get_rollout_kwargs_auto_ep(self):
235250 "tensor_parallel_size" : 2 ,
236251 "expert_parallel_size" : 2 ,
237252 }
238- self .assertEqual (train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 ), expected_result )
253+ self .assertEqual (
254+ train_rl .get_rollout_kwargs_for_parallelism (sampler_config , 8 ),
255+ expected_result ,
256+ )
239257
240258 @pytest .mark .cpu_only
241259 def test_get_rollout_kwargs_errors (self ):
@@ -307,7 +325,10 @@ def tokenize_side_effect(text):
307325 {"question" : "short" , "answer" : "a3" },
308326 {"question" : "long" , "answer" : "a4" },
309327 ]
310- test_data = [{"question" : "short" , "answer" : "a5" }, {"question" : "long" , "answer" : "a6" }]
328+ test_data = [
329+ {"question" : "short" , "answer" : "a5" },
330+ {"question" : "long" , "answer" : "a6" },
331+ ]
311332 train_map_ds = grain .MapDataset .source (train_data )
312333 test_map_ds = grain .MapDataset .source (test_data )
313334
@@ -334,7 +355,7 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
334355 eval_split = "eval" ,
335356 hf_train_files = None ,
336357 hf_eval_files = None ,
337- chat_template_path = "maxtext/examples/chat_templates/gsm8k_rl.json" ,
358+ data_template_path = "maxtext/examples/chat_templates/gsm8k_rl.json" ,
338359 data_shuffle_seed = 42 ,
339360 max_prefill_predict_length = 10 ,
340361 batch_size = 2 ,
@@ -346,8 +367,14 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
346367 )
347368
348369 with (
349- mock .patch ("maxtext.trainers.post_train.rl.train_rl.get_dataset" , side_effect = get_dataset_side_effect ),
350- mock .patch ("maxtext.trainers.post_train.rl.utils_rl.process_data" , side_effect = get_filtered_data_side_effect ),
370+ mock .patch (
371+ "maxtext.trainers.post_train.rl.train_rl.get_dataset" ,
372+ side_effect = get_dataset_side_effect ,
373+ ),
374+ mock .patch (
375+ "maxtext.trainers.post_train.rl.utils_rl.process_data" ,
376+ side_effect = get_filtered_data_side_effect ,
377+ ),
351378 ):
352379 train_dataset , test_dataset = train_rl .prepare_datasets (trainer_config , mock_tokenizer )
353380
@@ -378,7 +405,10 @@ def get_filtered_data_side_effect(dataset_name, model_tokenizer, template_config
378405 def test_prepare_datasets_with_split (self , mock_load ):
379406 mock_ds = mock .MagicMock ()
380407 mock_split_result = {
381- "train" : [{"question" : "q1" , "answer" : "a1" }, {"question" : "q2" , "answer" : "a2" }],
408+ "train" : [
409+ {"question" : "q1" , "answer" : "a1" },
410+ {"question" : "q2" , "answer" : "a2" },
411+ ],
382412 "test" : [{"question" : "q3" , "answer" : "a3" }],
383413 }
384414 mock_ds .train_test_split .return_value = mock_split_result
@@ -389,7 +419,7 @@ def test_prepare_datasets_with_split(self, mock_load):
389419 eval_dataset_name = "open-r1/OpenR1-Math-220k" ,
390420 train_split = "train" ,
391421 hf_train_files = "hf://open-r1/OpenR1-Math-220k/data/dummy.parquet" ,
392- chat_template_path = "maxtext/examples/chat_templates/gsm8k_rl.json" ,
422+ data_template_path = "maxtext/examples/chat_templates/gsm8k_rl.json" ,
393423 data_shuffle_seed = 42 ,
394424 num_batches = 1 ,
395425 batch_size = 5 ,
@@ -435,7 +465,7 @@ def test_prepare_datasets_without_split(self, mock_load):
435465 eval_split = "test" ,
436466 hf_train_files = "hf://openai/gsm8k/data/dummy.parquet" ,
437467 hf_eval_files = "hf://openai/gsm8k/data/dummy.parquet" ,
438- chat_template_path = "maxtext/examples/chat_templates/gsm8k_rl.json" ,
468+ data_template_path = "maxtext/examples/chat_templates/gsm8k_rl.json" ,
439469 data_shuffle_seed = 42 ,
440470 num_batches = 1 ,
441471 batch_size = 5 ,
@@ -496,5 +526,100 @@ def test_rl_train_invalid_optimizer_memory_host_offload(self, mock_setup):
496526 train_rl ._rl_train_impl ([], {}) # pylint: disable=protected-access
497527
498528
529+ class TokenizerChatTemplateTest (unittest .TestCase ):
530+ """Unit tests for configure_tokenizer_chat_template."""
531+
532+ @pytest .mark .cpu_only
533+ def test_chat_template_populated_from_config_string (self ):
534+ """Test that chat_template is set from config.chat_template when tokenizer lacks one."""
535+ mock_tokenizer = mock .MagicMock ()
536+ mock_tokenizer .chat_template = None  # tokenizer starts without a template
537+ trainer_config = SimpleNamespace (
538+ chat_template = "{{ messages[0].content }}" ,
539+ chat_template_path = None ,
540+ tokenizer_path = "dummy-base-model" ,
541+ )
542+ train_rl .configure_tokenizer_chat_template (mock_tokenizer , trainer_config )
543+ self .assertEqual (mock_tokenizer .chat_template , "{{ messages[0].content }}" )  # inline config string must land on the tokenizer
544+
545+ @pytest .mark .cpu_only
546+ @mock .patch ("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file" )  # replace the template file loader with a mock (injected as mock_load)
547+ def test_chat_template_populated_from_config_file (self , mock_load ):
548+ """Test that chat_template is loaded from chat_template_path when tokenizer lacks one."""
549+ mock_tokenizer = mock .MagicMock ()
550+ mock_tokenizer .chat_template = None  # tokenizer starts without a template
551+ mock_load .return_value = "{% for message in messages %}{{ message.content }}{% endfor %}"
552+ trainer_config = SimpleNamespace (
553+ chat_template = None ,
554+ chat_template_path = "/path/to/jinja_template.json" ,
555+ tokenizer_path = "dummy-base-model" ,
556+ )
557+ train_rl .configure_tokenizer_chat_template (mock_tokenizer , trainer_config )
558+ mock_load .assert_called_once_with ("/path/to/jinja_template.json" )  # loader must receive the configured path exactly once
559+ self .assertEqual (
560+ mock_tokenizer .chat_template ,
561+ "{% for message in messages %}{{ message.content }}{% endfor %}" ,
562+ )
563+
564+ @pytest .mark .cpu_only
565+ def test_chat_template_raises_value_error_when_empty (self ):
566+ """Test that ValueError is raised when tokenizer lacks chat_template and both config options are empty."""
567+ mock_tokenizer = mock .MagicMock ()
568+ mock_tokenizer .chat_template = None  # tokenizer starts without a template
569+ trainer_config = SimpleNamespace (
570+ chat_template = None ,
571+ chat_template_path = None ,
572+ tokenizer_path = "dummy-base-model" ,
573+ )
574+ with self .assertRaisesRegex (ValueError , "Tokenizer 'dummy-base-model' has no chat_template" ):  # message is expected to name the tokenizer_path value
575+ train_rl .configure_tokenizer_chat_template (mock_tokenizer , trainer_config )
576+
577+ @pytest .mark .cpu_only
578+ def test_chat_template_unchanged_when_already_exists (self ):
579+ """Test that an existing chat_template on the tokenizer is preserved (backward compatibility)."""
580+ mock_tokenizer = mock .MagicMock ()
581+ mock_tokenizer .chat_template = "{{ existing_template }}"  # tokenizer already carries its own template
582+ trainer_config = SimpleNamespace (
583+ chat_template = "{{ overridden_template }}" ,
584+ chat_template_path = None ,
585+ tokenizer_path = "dummy-instruction-tuned-model" ,
586+ )
587+ train_rl .configure_tokenizer_chat_template (mock_tokenizer , trainer_config )
588+ self .assertEqual (mock_tokenizer .chat_template , "{{ existing_template }}" )  # config value must not overwrite the existing template
589+
590+ @pytest .mark .cpu_only
591+ def test_apply_chat_template_works_after_configuration (self ):
592+ """Verifies apply_chat_template succeeds and produces the expected format after our code path runs."""
593+
594+ class DummyTokenizer : # pylint: disable=missing-class-docstring
595+
596+ def __init__ (self ):
597+ self .chat_template = None
598+
599+ def apply_chat_template (self , conversation , tokenize = False ):  # tokenize is accepted but not used by this stub
600+ if self .chat_template is None :
601+ raise ValueError ("Cannot apply chat template because chat_template is None" )
602+ import jinja2 # pylint: disable=import-outside-toplevel
603+
604+ env = jinja2 .Environment ()
605+ template = env .from_string (self .chat_template )
606+ return template .render (messages = conversation )  # the conversation is exposed to the template as "messages"
607+
608+ tokenizer = DummyTokenizer ()
609+ trainer_config = SimpleNamespace (
610+ chat_template = "{{ messages[0].content }}" ,
611+ chat_template_path = None ,
612+ tokenizer_path = "dummy-base-model" ,
613+ )
614+ # Initially, apply_chat_template fails (simulating HF tokenizer crash when chat_template is None)
615+ with self .assertRaises (ValueError ):
616+ tokenizer .apply_chat_template ([{"role" : "user" , "content" : "Hello!" }])
617+ # Run the function under test so the template from trainer_config is installed on the tokenizer
618+ train_rl .configure_tokenizer_chat_template (tokenizer , trainer_config )
619+ # Verify apply_chat_template now runs successfully and renders correct content
620+ rendered = tokenizer .apply_chat_template ([{"role" : "user" , "content" : "Hello!" }])
621+ self .assertEqual (rendered , "Hello!" )
623+
499624if __name__ == "__main__" :
500625 unittest .main ()  # run this module's tests when it is executed directly as a script
0 commit comments