Skip to content

Commit b277c40

Browse files
committed
Qwen3-Omni SFT+Eval
1 parent a1fb834 commit b277c40

5 files changed

Lines changed: 102 additions & 21 deletions

File tree

benchmarks/multimodal/multimodal_eval.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ def construct_prompt(
150150
question=parsed_dataset_example.question,
151151
choices=choices_text if choices_text else "N/A",
152152
)
153+
if config.use_multimodal and "qwen3-omni" in config.model_name:
154+
prompt = mm_processor.reformat_prompt(
155+
prompt,
156+
image_placeholder,
157+
config.model_name,
158+
num_images=1,
159+
)
153160
elif local_args.ckpt_type == "sft":
154161
prompt = mm_processor.reformat_prompt(
155162
parsed_dataset_example.question,
@@ -200,11 +207,31 @@ def main(config, local_args):
200207
print("\n" + "*" * 50)
201208

202209
# Tokenize the input
203-
tokens, true_length = tokenizer.encode(prompt, is_bos=True, prefill_lengths=[prefill_length])
210+
is_bos = config.add_bos and getattr(tokenizer, "bos_id", None) is not None
211+
tokens, true_length = tokenizer.encode(prompt, is_bos=is_bos, prefill_lengths=[prefill_length])
212+
position_ids = None
213+
mrope_position_deltas = None
214+
204215
if config.use_multimodal:
205216
tokens = mm_processor.prepare_text_for_image_fusion(tokens=tokens, config=config, processor_output=processor_output)
206217
image_offsets = mm_processor.get_image_offsets(config=config, processor_output=processor_output)
207218
true_length += image_offsets
219+
220+
if config.use_mrope:
221+
from maxtext.multimodal import processor_qwen3_omni # pylint: disable=import-outside-toplevel
222+
223+
position_ids, mrope_position_deltas = processor_qwen3_omni.get_rope_index(
224+
input_ids=tokens[np.newaxis, :], # Add batch dimension for processing
225+
image_grid_thw=processor_output.pixel_grid_thw, # pytype: disable=attribute-error
226+
video_grid_thw=processor_output.video_grid_thw, # pytype: disable=attribute-error
227+
attention_mask=np.ones_like(tokens)[np.newaxis, :],
228+
use_audio_in_video=config.use_audio and getattr(processor_output, "num_videos", 0) > 0,
229+
audio_lengths=processor_output.audio_lengths, # pytype: disable=attribute-error
230+
second_per_grids=processor_output.video_second_per_grid, # pytype: disable=attribute-error
231+
spatial_merge_size=config.spatial_merge_size_for_vit, # pytype: disable=attribute-error
232+
position_id_per_seconds=config.position_id_per_seconds,
233+
)
234+
208235
if true_length > max_prefill_predict_length:
209236
max_logging.log(
210237
f"Warning: Prompt length {true_length} exceeds max prefill length" f" {max_prefill_predict_length}. Truncating."
@@ -216,7 +243,18 @@ def main(config, local_args):
216243

217244
# Perform prefill
218245
prefill_result, first_token = engine.prefill(
219-
params=params, padded_tokens=tokens, images=processor_output.pixel_values, true_length=true_length
246+
params=params,
247+
padded_tokens=tokens,
248+
positions=position_ids,
249+
mrope_deltas=mrope_position_deltas,
250+
images=processor_output.pixel_values if config.use_multimodal else None,
251+
image_masks=getattr(processor_output, "pixel_mask", None)
252+
if config.use_multimodal and "llama4" in config.model_name
253+
else None,
254+
audio_values=getattr(processor_output, "audio_values", None) if config.use_audio else None,
255+
audio_masks=getattr(processor_output, "audio_mask", None) if config.use_audio else None,
256+
true_length=true_length,
257+
slot=0,
220258
)
221259
slot = 0
222260

@@ -243,8 +281,9 @@ def main(config, local_args):
243281
break
244282

245283
correct_answer = parsed_dataset_example.answer
284+
# If answer parsing fails, fall back to the raw output as the predicted answer for correctness checking
246285
if predicted_answer == utils_rl.FALLBACK_ANSWER:
247-
predicted_answer = utils_rl.extract_answer(output, tmvp_config)
286+
predicted_answer = output
248287

249288
exact_correct, _ = utils_rl.check_correctness(predicted_answer, [correct_answer], tmvp_config)
250289
is_correct = exact_correct

src/maxtext/input_pipeline/hf_data_processing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def vision_sft_preprocessing_pipeline(
118118
add_eos_token=False,
119119
legacy=False,
120120
token=config.hf_access_token,
121+
extra_special_tokens={},
121122
)
122123
pad_id = _get_pad_id(tokenizer)
123124

@@ -256,6 +257,7 @@ def preprocessing_pipeline(
256257
add_eos_token=add_eos if not use_sft else False,
257258
legacy=False,
258259
token=hf_access_token,
260+
extra_special_tokens={},
259261
)
260262

261263
dataset = dataset.select_columns(data_column_names)

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,9 @@ def _pad_image_and_mask(self, preprocessed_image: mm_utils.PreprocessorOutput) -
745745
if preprocessed_image.pixel_values is None:
746746
raise ValueError("Input preprocessed_image must have pixel_values to pad images.")
747747

748+
if self.config.model_name and self.config.model_name.startswith("qwen3-omni"):
749+
return preprocessed_image
750+
748751
# Determine the maximum number of images/masks allowed.
749752
image_offsets = mm_processor.get_image_offsets(self.config, preprocessed_image)
750753
single_image_offset = image_offsets // preprocessed_image.pixel_values.shape[0]

src/maxtext/multimodal/processor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def preprocess_image_for_training(image, model_name):
6868
from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4 # pylint: disable=import-outside-toplevel
6969

7070
return preprocess_mm_data_llama4(image)
71+
elif model_name in ["qwen3-omni-30b-a3b"]:
72+
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel
73+
74+
return preprocess_mm_data_qwen3_omni_for_training(image)
7175
else:
7276
raise ValueError(f"Model {model_name} not supported for image preprocessing.")
7377

src/maxtext/multimodal/processor_qwen3_omni.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def smart_resize(
122122
return h_bar, w_bar
123123

124124

125-
def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config):
125+
def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config, force_resize=None):
126126
"""Performs a bi-linear resize (with anti-aliasing) and normalizes the image."""
127127
patch_size = config.patch_size_for_vit
128128
merge_size = config.spatial_merge_size_for_vit
@@ -135,23 +135,27 @@ def pre_process_qwen3_image(image: np.ndarray | list[np.ndarray], config):
135135

136136
for img in images_in:
137137
pil_img = Image.fromarray(img)
138-
# Qwen3-Omni performs one resize during fetch_image and another resize before patchify.
139-
resized_height_1, resized_width_1 = smart_resize(
140-
height=img.shape[0],
141-
width=img.shape[1],
142-
factor=IMAGE_FACTOR,
143-
min_pixels=MIN_PIXELS,
144-
max_pixels=MAX_PIXELS,
145-
)
146-
pil_img = pil_img.resize((resized_width_1, resized_height_1))
147-
resized_height_2, resized_width_2 = smart_resize(
148-
height=resized_height_1,
149-
width=resized_width_1,
150-
factor=patch_size * merge_size,
151-
min_pixels=MIN_PIXELS,
152-
max_pixels=MAX_PIXELS,
153-
)
154-
resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
138+
if force_resize is not None:
139+
resized_height_2, resized_width_2 = force_resize
140+
resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
141+
else:
142+
# Qwen3-Omni performs one resize during fetch_image and another resize before patchify.
143+
resized_height_1, resized_width_1 = smart_resize(
144+
height=img.shape[0],
145+
width=img.shape[1],
146+
factor=IMAGE_FACTOR,
147+
min_pixels=MIN_PIXELS,
148+
max_pixels=MAX_PIXELS,
149+
)
150+
pil_img = pil_img.resize((resized_width_1, resized_height_1))
151+
resized_height_2, resized_width_2 = smart_resize(
152+
height=resized_height_1,
153+
width=resized_width_1,
154+
factor=patch_size * merge_size,
155+
min_pixels=MIN_PIXELS,
156+
max_pixels=MAX_PIXELS,
157+
)
158+
resized_img_pil = pil_img.resize((resized_width_2, resized_height_2), resample=resample_method)
155159
resized_img_np = np.array(resized_img_pil).astype(np.float32)
156160

157161
img_np = mm_utils.normalize_images(resized_img_np, mean=IMAGE_MEAN, std=IMAGE_STD)
@@ -474,6 +478,35 @@ def pre_process_audio_qwen3_omni(audio_array):
474478
return audio_features, audio_features_mask
475479

476480

481+
def preprocess_mm_data_qwen3_omni_for_training(images):
482+
"""Preprocesses image(s) for Qwen3-Omni SFT training using default model constants."""
483+
484+
class _DefaultConfig:
485+
patch_size_for_vit = 16
486+
spatial_merge_size_for_vit = 2
487+
temporal_patch_size_for_vit = QWEN3_TEMPORAL_PATCH_SIZE
488+
489+
images_in = [images] if isinstance(images, np.ndarray) else images
490+
pixel_values, pixel_grid_thw = pre_process_qwen3_image(
491+
images_in, _DefaultConfig(), force_resize=(QWEN3_OMNI_IMAGE_SIZE, QWEN3_OMNI_IMAGE_SIZE)
492+
)
493+
pixel_values = np.reshape(
494+
pixel_values,
495+
(
496+
len(images_in),
497+
3, # num_channels_for_vit
498+
_DefaultConfig.temporal_patch_size_for_vit * pixel_grid_thw[0, 0],
499+
_DefaultConfig.patch_size_for_vit * pixel_grid_thw[0, 1],
500+
_DefaultConfig.patch_size_for_vit * pixel_grid_thw[0, 2],
501+
),
502+
)
503+
return Qwen3OmniPreprocessorOutput(
504+
num_images=len(images_in),
505+
pixel_values=pixel_values,
506+
pixel_grid_thw=pixel_grid_thw,
507+
)
508+
509+
477510
def preprocess_mm_data_qwen3_omni(config):
478511
"""Placeholder for multimodal data preprocessing."""
479512
processor_outputs = Qwen3OmniPreprocessorOutput()

0 commit comments

Comments
 (0)