@@ -495,29 +495,48 @@ class SFTChatTemplateLogicTest(unittest.TestCase):
def setUpClass(cls):
  """Best-effort download of the Llama 2 chat tokenizer before the tests run.

  When ``cls.LLAMA_TOKENIZER_PATH`` does not exist, the tokenizer directory is
  copied from GCS into ``MAXTEXT_ASSETS_ROOT`` with ``gcloud storage cp -r``.
  A failed download is logged instead of raised, so the class can still be set
  up on machines without ``gcloud`` or without access to the bucket.
  """
  super().setUpClass()
  if os.path.exists(cls.LLAMA_TOKENIZER_PATH):
    return

  import logging  # pylint: disable=import-outside-toplevel

  command = [
      "gcloud",
      "storage",
      "cp",
      "-r",
      "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
      # The trailing separator makes the destination an explicit directory.
      os.path.join(MAXTEXT_ASSETS_ROOT, ""),
  ]
  try:
    exit_code = subprocess.call(command)
  except (OSError, subprocess.SubprocessError) as err:
    # Raised when the process cannot be launched at all (e.g. gcloud is not
    # installed); subprocess.call never raises for a non-zero exit status.
    logging.warning("Could not run gcloud to download llama tokenizer: %s", err)
    return
  if exit_code != 0:
    logging.warning("Failed to download llama tokenizer (gcloud exit code %d)", exit_code)
510511
def setUp(self):
  """Load the tokenizers used by the chat-template tests.

  Sets ``self.qwen3_tokenizer``, ``self.llama2_tokenizer`` and
  ``self.gemma4_tokenizer``. The Llama 2 tokenizer is read from
  ``self.LLAMA_TOKENIZER_PATH`` first; if that fails for any reason, the
  "NousResearch/Llama-2-7b-chat-hf" tokenizer is loaded instead and given an
  explicit chat template. The original failure is logged so the fallback is
  visible in the test output.
  """
  super().setUp()
  self.qwen3_tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
  try:
    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained(self.LLAMA_TOKENIZER_PATH)
  except Exception as err:  # pylint: disable=broad-except
    import logging  # pylint: disable=import-outside-toplevel

    # Deliberately broad: any failure to load the local copy should fall back
    # rather than abort the test, but the cause must not be lost.
    logging.warning(
        "Could not load llama tokenizer from %s (%s); falling back to NousResearch/Llama-2-7b-chat-hf",
        self.LLAMA_TOKENIZER_PATH,
        err,
    )
    self.llama2_tokenizer = transformers.AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
    # The template is assigned explicitly so the fallback tokenizer formats
    # user / system / assistant turns in a known way.
    self.llama2_tokenizer.chat_template = (
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{{ bos_token + '[INST] ' + message['content'] | trim + ' [/INST]' }}"
        "{% elif message['role'] == 'system' %}"
        "{{ '<<SYS>>\\n' + message['content'] | trim + '\\n<</SYS>>\\n\\n' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ ' ' + message['content'] | trim + ' ' + eos_token }}"
        "{% endif %}"
        "{% endfor %}"
    )
  self.gemma4_tokenizer = transformers.AutoTokenizer.from_pretrained("google/gemma-4-26b-a4b-it")
515531
def test_tokenizer_w_generation_prompt(self):
  """Asserts that the generation-prompt check returns a truthy result for the Qwen3 tokenizer."""
  tokenizer = self.qwen3_tokenizer
  outcome = verify_chat_template_generation_prompt_logic(tokenizer)
  self.assertTrue(outcome)
518534
def test_tokenizer_wo_generation_prompt(self):
  """Asserts that the generation-prompt check returns a truthy result for the Llama 2 tokenizer."""
  tokenizer = self.llama2_tokenizer
  outcome = verify_chat_template_generation_prompt_logic(tokenizer)
  self.assertTrue(outcome)
537+
def test_tokenizer_gemma4_w_thought_channel(self):
  """Asserts that the generation-prompt check returns a truthy result for the Gemma 4 tokenizer."""
  tokenizer = self.gemma4_tokenizer
  outcome = verify_chat_template_generation_prompt_logic(tokenizer)
  self.assertTrue(outcome)
521540
522541 def test_failure_path_with_modified_template (self ):
523542 """Verifies the function correctly raises a ValueError on a bad template."""
0 commit comments