@@ -150,6 +150,13 @@ def construct_prompt(
150150 question = parsed_dataset_example .question ,
151151 choices = choices_text if choices_text else "N/A" ,
152152 )
153+ if config .use_multimodal and "qwen3-omni" in config .model_name :
154+ prompt = mm_processor .reformat_prompt (
155+ prompt ,
156+ image_placeholder ,
157+ config .model_name ,
158+ num_images = 1 ,
159+ )
153160 elif local_args .ckpt_type == "sft" :
154161 prompt = mm_processor .reformat_prompt (
155162 parsed_dataset_example .question ,
@@ -200,11 +207,31 @@ def main(config, local_args):
200207 print ("\n " + "*" * 50 )
201208
202209 # Tokenize the input
203- tokens , true_length = tokenizer .encode (prompt , is_bos = True , prefill_lengths = [prefill_length ])
210+ is_bos = config .add_bos and getattr (tokenizer , "bos_id" , None ) is not None
211+ tokens , true_length = tokenizer .encode (prompt , is_bos = is_bos , prefill_lengths = [prefill_length ])
212+ position_ids = None
213+ mrope_position_deltas = None
214+
204215 if config .use_multimodal :
205216 tokens = mm_processor .prepare_text_for_image_fusion (tokens = tokens , config = config , processor_output = processor_output )
206217 image_offsets = mm_processor .get_image_offsets (config = config , processor_output = processor_output )
207218 true_length += image_offsets
219+
220+ if config .use_mrope :
221+ from maxtext .multimodal import processor_qwen3_omni # pylint: disable=import-outside-toplevel
222+
223+ position_ids , mrope_position_deltas = processor_qwen3_omni .get_rope_index (
224+ input_ids = tokens [np .newaxis , :], # Add batch dimension for processing
225+ image_grid_thw = processor_output .pixel_grid_thw , # pytype: disable=attribute-error
226+ video_grid_thw = processor_output .video_grid_thw , # pytype: disable=attribute-error
227+ attention_mask = np .ones_like (tokens )[np .newaxis , :],
228+ use_audio_in_video = config .use_audio and getattr (processor_output , "num_videos" , 0 ) > 0 ,
229+ audio_lengths = processor_output .audio_lengths , # pytype: disable=attribute-error
230+ second_per_grids = processor_output .video_second_per_grid , # pytype: disable=attribute-error
231+ spatial_merge_size = config .spatial_merge_size_for_vit , # pytype: disable=attribute-error
232+ position_id_per_seconds = config .position_id_per_seconds ,
233+ )
234+
208235 if true_length > max_prefill_predict_length :
209236 max_logging .log (
210237 f"Warning: Prompt length { true_length } exceeds max prefill length" f" { max_prefill_predict_length } . Truncating."
@@ -216,7 +243,18 @@ def main(config, local_args):
216243
217244 # Perform prefill
218245 prefill_result , first_token = engine .prefill (
219- params = params , padded_tokens = tokens , images = processor_output .pixel_values , true_length = true_length
246+ params = params ,
247+ padded_tokens = tokens ,
248+ positions = position_ids ,
249+ mrope_deltas = mrope_position_deltas ,
250+ images = processor_output .pixel_values if config .use_multimodal else None ,
251+ image_masks = getattr (processor_output , "pixel_mask" , None )
252+ if config .use_multimodal and "llama4" in config .model_name
253+ else None ,
254+ audio_values = getattr (processor_output , "audio_values" , None ) if config .use_audio else None ,
255+ audio_masks = getattr (processor_output , "audio_mask" , None ) if config .use_audio else None ,
256+ true_length = true_length ,
257+ slot = 0 ,
220258 )
221259 slot = 0
222260
@@ -243,8 +281,9 @@ def main(config, local_args):
243281 break
244282
245283 correct_answer = parsed_dataset_example .answer
284+ # If fails to parse answer, use the raw output as the predicted answer for correctness checking
246285 if predicted_answer == utils_rl .FALLBACK_ANSWER :
247- predicted_answer = utils_rl . extract_answer ( output , tmvp_config )
286+ predicted_answer = output
248287
249288 exact_correct , _ = utils_rl .check_correctness (predicted_answer , [correct_answer ], tmvp_config )
250289 is_correct = exact_correct
0 commit comments