@@ -472,5 +472,94 @@ def test_prepare_datasets_without_split(self, mock_load):
472472 assert mock_load .call_count == len (expected_calls )
473473
474474
class TokenizerChatTemplateTest(unittest.TestCase):
  """Exercises train_rl.configure_tokenizer_chat_template against stub tokenizers and configs."""

  @staticmethod
  def _make_config(chat_template, template_path, tokenizer_path="dummy-base-model"):
    """Returns a SimpleNamespace carrying the three config attributes every test here sets."""
    return SimpleNamespace(
        chat_template=chat_template,
        tokenizer_chat_template_path=template_path,
        tokenizer_path=tokenizer_path,
    )

  @staticmethod
  def _make_mock_tokenizer(chat_template):
    """Returns a MagicMock tokenizer whose chat_template attribute is preset to the given value."""
    tokenizer = mock.MagicMock()
    tokenizer.chat_template = chat_template
    return tokenizer

  @pytest.mark.cpu_only
  def test_chat_template_populated_from_config_string(self):
    """A tokenizer with chat_template=None ends up holding the inline config.chat_template string."""
    inline_template = "{{ messages[0].content }}"
    tokenizer = self._make_mock_tokenizer(None)
    config = self._make_config(inline_template, None)

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, inline_template)

  @pytest.mark.cpu_only
  @mock.patch("maxtext.input_pipeline.instruction_data_processing.load_chat_template_from_file")
  def test_chat_template_populated_from_config_file(self, mock_load):
    """A tokenizer with chat_template=None receives whatever the (patched) file loader returns."""
    template_path = "/path/to/jinja_template.json"
    loaded_template = "{% for message in messages %}{{ message.content }}{% endfor %}"
    mock_load.return_value = loaded_template
    tokenizer = self._make_mock_tokenizer(None)
    config = self._make_config(None, template_path)

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    # The loader must be consulted exactly once, with the configured path.
    mock_load.assert_called_once_with(template_path)
    self.assertEqual(tokenizer.chat_template, loaded_template)

  @pytest.mark.cpu_only
  def test_chat_template_raises_value_error_when_empty(self):
    """With no template on the tokenizer and neither config option set, a ValueError is expected."""
    tokenizer = self._make_mock_tokenizer(None)
    config = self._make_config(None, None)

    with self.assertRaisesRegex(ValueError, "Tokenizer 'dummy-base-model' has no chat_template"):
      train_rl.configure_tokenizer_chat_template(tokenizer, config)

  @pytest.mark.cpu_only
  def test_chat_template_unchanged_when_already_exists(self):
    """A template already present on the tokenizer survives even when the config supplies another."""
    preexisting_template = "{{ existing_template }}"
    tokenizer = self._make_mock_tokenizer(preexisting_template)
    config = self._make_config(
        "{{ overridden_template }}",
        None,
        tokenizer_path="dummy-instruction-tuned-model",
    )

    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    self.assertEqual(tokenizer.chat_template, preexisting_template)

  @pytest.mark.cpu_only
  def test_apply_chat_template_works_after_configuration(self):
    """apply_chat_template on a stub tokenizer fails before configuration and renders correctly after."""

    class DummyTokenizer:
      """Minimal stand-in that renders its chat_template with jinja2 and refuses to run without one."""

      def __init__(self):
        self.chat_template = None

      def apply_chat_template(self, conversation, tokenize=False):
        # `tokenize` is accepted but never read by this stub.
        if self.chat_template is None:
          raise ValueError("Cannot apply chat template because chat_template is None")
        import jinja2

        return jinja2.Environment().from_string(self.chat_template).render(messages=conversation)

    conversation = [{"role": "user", "content": "Hello!"}]
    tokenizer = DummyTokenizer()
    config = self._make_config("{{ messages[0].content }}", None)

    # Before configuration the stub has no template, so rendering must raise
    # (stands in for an HF tokenizer failing when chat_template is None).
    with self.assertRaises(ValueError):
      tokenizer.apply_chat_template(conversation)

    # Exercise the code path under test.
    train_rl.configure_tokenizer_chat_template(tokenizer, config)

    # After configuration the same call succeeds and yields the first message's content.
    rendered = tokenizer.apply_chat_template(conversation)
    self.assertEqual(rendered, "Hello!")
563+
if __name__ == "__main__":
  # When this file is executed directly rather than imported, hand control to unittest's runner.
  unittest.main()
0 commit comments