
Commit 5f3dc2b

fix Gemma 4 multimodal chat-template markers in processor_gemma4
The Gemma 4 multimodal SFT path was emitting Gemma 3 chat-template markers ("<start_of_turn>", "<end_of_turn>"), which are NOT special tokens in the Gemma 4 tokenizer. They BPE-tokenize into 7-token noise sequences each, so a training label like "A<end_of_turn>" became an 8-token sequence:

  [236776 'A', 236820 '<', 643 'end', 236779 '_', 1340 'of', 236779 '_', 887 'turn', 236813 '>']

With sft_train_on_completion_only=true, the model learned to reproduce this noise sequence after every answer, producing severe response-format collapse post-SFT (e.g. "A<B<C<D<...").

The Gemma 4 chat template uses different special tokens:

  <bos>   (id 2)
  <|turn> (id 105)
  <turn|> (id 106)

This CL switches the prompt and response formatters to use them.

PiperOrigin-RevId: 931396545
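The token counts in the message can be reproduced with a toy tokenizer. The sketch below is not the Gemma 4 tokenizer and does not implement real BPE: it is a greedy longest-match encoder over a tiny vocabulary built only from the ids quoted in this commit message. The names `SPECIAL_TOKENS`, `ORDINARY_PIECES`, and `toy_encode` are hypothetical and exist only for this illustration.

```python
# Toy illustration only: NOT the Gemma 4 tokenizer and NOT real BPE.
# Every id below is copied from the commit message; nothing else is assumed.
SPECIAL_TOKENS = {"<bos>": 2, "<|turn>": 105, "<turn|>": 106}
ORDINARY_PIECES = {
    "A": 236776,
    "<": 236820,
    "end": 643,
    "_": 236779,
    "of": 1340,
    "turn": 887,
    ">": 236813,
}


def toy_encode(text):
    """Greedy longest-match encoding over the toy vocabulary."""
    vocab = {**ORDINARY_PIECES, **SPECIAL_TOKENS}
    # Try longer pieces first so a special token wins over its fragments.
    pieces = sorted(vocab, key=len, reverse=True)
    ids = []
    pos = 0
    while pos < len(text):
        for piece in pieces:
            if text.startswith(piece, pos):
                ids.append(vocab[piece])
                pos += len(piece)
                break
        else:
            raise ValueError(f"no toy vocabulary piece matches at offset {pos}")
    return ids


# The Gemma 3 marker is not in the special-token set, so it falls apart.
legacy_label = toy_encode("A<end_of_turn>")
# The Gemma 4 marker is matched as a single special token.
fixed_label = toy_encode("A<turn|>")
print(len(legacy_label), len(fixed_label))
```

Under these assumptions the legacy label costs eight ids while the fixed label costs two, which is the gap that completion-only training was turning into learned noise.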
1 parent 3190805 commit 5f3dc2b

2 files changed

Lines changed: 2 additions & 2 deletions


src/maxtext/multimodal/processor.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def reformat_response(response, model_name):
     formatted_response = f"{response}<end_of_turn>"
     return formatted_response
   elif model_name in ["gemma4-26b", "gemma4-31b", "gemma4-e2b", "gemma4-e4b"]:
-    formatted_response = f"{response}<end_of_turn>"
+    formatted_response = f"{response}<turn|>"
     return formatted_response
   elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
     formatted_response = f"{response}<|im_end|>"

src/maxtext/multimodal/processor_gemma4.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def reformat_prompt_gemma4(prompt, image_placeholder, num_images):
   image_placeholder_count = prompt.count(GEMMA4_IMAGE_PLACEHOLDER_IN_PROMPT)
   if image_placeholder_count < num_images:
     prompt = GEMMA4_IMAGE_PLACEHOLDER_IN_PROMPT * (num_images - image_placeholder_count) + prompt
-  formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+  formatted_prompt = f"<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n"
   return formatted_prompt
 
 
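A small regression check could help keep Gemma 3 markers from returning to the Gemma 4 path. The sketch below is not part of this commit: the format strings are copied from the two changed lines, while the helper names (`format_gemma4_prompt`, `format_gemma4_response`, `find_legacy_markers`) are hypothetical. Image-placeholder handling is deliberately left out.

```python
# Hypothetical helpers for illustration; only the format strings come from the diff.
GEMMA3_MARKERS = ("<start_of_turn>", "<end_of_turn>")


def format_gemma4_prompt(prompt: str) -> str:
    """Wrap a user prompt in the Gemma 4 turn markers from the fixed line."""
    return f"<bos><|turn>user\n{prompt}<turn|>\n<|turn>model\n"


def format_gemma4_response(response: str) -> str:
    """Close a model response with the Gemma 4 end-of-turn marker."""
    return f"{response}<turn|>"


def find_legacy_markers(text: str) -> list:
    """Return every Gemma 3 marker that still appears in a formatted string."""
    return [marker for marker in GEMMA3_MARKERS if marker in text]


example = format_gemma4_prompt("Which option is correct?") + format_gemma4_response("A")
leftovers = find_legacy_markers(example)
if leftovers:
    raise AssertionError(f"Gemma 3 markers leaked into Gemma 4 text: {leftovers}")
```

A substring check like this only catches the literal legacy markers. Confirming that the new markers map to single ids would still need the real tokenizer.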
