Skip to content

Commit 8c4dced

Browse files
Merge pull request #3995 from AI-Hypercomputer:hengtaoguo-va
PiperOrigin-RevId: 923605138
2 parents af334f1 + 7598c46 commit 8c4dced

8 files changed

Lines changed: 115 additions & 24 deletions

File tree

src/maxtext/common/common_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,12 @@ class MultimodalInput:
8080

8181
image_embeddings: Array | None = None
8282
image_masks: Array | None = None
83+
video_embeddings: Array | None = None
84+
video_masks: Array | None = None
8385
audio_embeddings: Array | None = None
8486
audio_masks: Array | None = None
8587
bidirectional_mask: Array | None = None
88+
bidirectional_mask_video: Array | None = None
8689

8790

8891
class DecoderBlockType(enum.Enum):

src/maxtext/inference/decode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def main(argv: Sequence[str]) -> None:
184184
mrope_deltas=mrope_position_deltas,
185185
images=processor_outputs.pixel_values if config.use_multimodal else None,
186186
image_masks=processor_outputs.pixel_mask if config.use_multimodal and "llama4" in config.model_name else None,
187+
videos=getattr(processor_outputs, "video_values", None) if config.use_multimodal else None,
187188
audio_values=processor_outputs.audio_values if config.use_audio else None,
188189
audio_masks=processor_outputs.audio_mask if config.use_audio else None,
189190
true_length=true_length,

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,8 @@ def _prefill_jit(
408408
mrope_deltas: jax.Array | None = None,
409409
images: jax.Array | None = None,
410410
image_masks: jax.Array | None = None,
411+
videos: jax.Array | None = None,
412+
video_masks: jax.Array | None = None,
411413
audio_values: jax.Array | None = None,
412414
audio_masks: jax.Array | None = None,
413415
true_length: int,
@@ -504,6 +506,8 @@ def _prefill_jit(
504506
positions,
505507
encoder_images=images,
506508
encoder_image_masks=image_masks,
509+
encoder_videos=videos,
510+
encoder_video_masks=video_masks,
507511
encoder_audios=audio_values,
508512
decoder_segment_ids=sequence_indicator,
509513
enable_dropout=False,
@@ -586,6 +590,8 @@ def prefill(
586590
mrope_deltas: jax.Array | None = None,
587591
images: jax.Array | None = None,
588592
image_masks: jax.Array | None = None,
593+
videos: jax.Array | None = None,
594+
video_masks: jax.Array | None = None,
589595
audio_values: jax.Array | None = None,
590596
audio_masks: jax.Array | None = None,
591597
true_length: int,
@@ -617,6 +623,8 @@ def prefill(
617623
mrope_deltas=mrope_deltas,
618624
images=images,
619625
image_masks=image_masks,
626+
videos=videos,
627+
video_masks=video_masks,
620628
audio_values=audio_values,
621629
audio_masks=audio_masks,
622630
sampler=sampler,

src/maxtext/layers/decoders.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,9 @@ def _apply_embedding(
642642
image_embeddings = multimodal_input.image_embeddings
643643
bidirectional_mask = multimodal_input.bidirectional_mask
644644
image_masks = multimodal_input.image_masks
645+
video_embeddings = getattr(multimodal_input, "video_embeddings", None)
646+
video_masks = getattr(multimodal_input, "video_masks", None)
647+
bidirectional_mask_video = getattr(multimodal_input, "bidirectional_mask_video", None)
645648
audio_embeddings = multimodal_input.audio_embeddings
646649
audio_masks = multimodal_input.audio_masks
647650

@@ -669,6 +672,17 @@ def _apply_embedding(
669672
else:
670673
raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")
671674

675+
if video_embeddings is not None and cfg.use_multimodal:
676+
if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
677+
y = mm_utils.merge_mm_embeddings(
678+
text_embeddings=y,
679+
multimodal_embeddings=video_embeddings,
680+
mask=bidirectional_mask_video,
681+
token_masks=video_masks,
682+
)
683+
else:
684+
raise ValueError(f"Unsupported model_name for video: {cfg.model_name}")
685+
672686
if audio_embeddings is not None and cfg.use_audio:
673687
if cfg.model_name in ["qwen3-omni-30b-a3b"]:
674688
y = mm_utils.merge_mm_embeddings(

src/maxtext/models/models.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def __call__(
127127
decoder_segment_ids=None,
128128
encoder_images: None | jnp.ndarray = None,
129129
encoder_image_masks: None | jnp.ndarray = None,
130+
encoder_videos: None | jnp.ndarray = None,
131+
encoder_video_masks: None | jnp.ndarray = None,
130132
encoder_audios: None | jnp.ndarray = None,
131133
enable_dropout=True,
132134
model_mode=MODEL_MODE_TRAIN,
@@ -153,17 +155,28 @@ def __call__(
153155
f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
154156
)
155157

156-
bidirectional_mask = None
158+
bidirectional_mask_image = None
159+
bidirectional_mask_video = None
157160
image_embeddings = None
161+
video_embeddings = None
158162
audio_embeddings = None
159163
deepstack_visual_embeds = None
160164

161165
if self.config.use_multimodal and encoder_images is not None:
162166
image_embeddings, deepstack_visual_embeds = self.vision_encoder(
163167
input_images=encoder_images, deterministic=not enable_dropout
164168
)
169+
bidirectional_mask_image = mm_processor.get_bidirectional_mask_vision(
170+
self.config, decoder_input_tokens, is_video=False
171+
)
165172

166-
bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
173+
if self.config.use_multimodal and encoder_videos is not None:
174+
video_embeddings, deepstack_visual_embeds = self.vision_encoder(
175+
input_images=encoder_videos, deterministic=not enable_dropout
176+
)
177+
bidirectional_mask_video = mm_processor.get_bidirectional_mask_vision(
178+
self.config, decoder_input_tokens, is_video=True
179+
)
167180

168181
if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
169182
audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)
@@ -174,13 +187,16 @@ def __call__(
174187
audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)
175188

176189
multimodal_input = None
177-
if image_embeddings is not None or audio_embeddings is not None:
190+
if image_embeddings is not None or video_embeddings is not None or audio_embeddings is not None:
178191
multimodal_input = MultimodalInput(
179192
image_embeddings=image_embeddings,
180193
image_masks=encoder_image_masks,
194+
video_embeddings=video_embeddings,
195+
video_masks=encoder_video_masks,
181196
audio_embeddings=audio_embeddings,
182197
audio_masks=audio_masks,
183-
bidirectional_mask=bidirectional_mask,
198+
bidirectional_mask=bidirectional_mask_image,
199+
bidirectional_mask_video=bidirectional_mask_video,
184200
)
185201

186202
logits, hidden_state, kv_caches = self.decoder(
@@ -425,6 +441,8 @@ def __call__(
425441
cache=None,
426442
encoder_images: jax.Array | None = None,
427443
encoder_image_masks: jax.Array | None = None,
444+
encoder_videos: jax.Array | None = None,
445+
encoder_video_masks: jax.Array | None = None,
428446
encoder_audios: jax.Array | None = None,
429447
enable_dropout=True,
430448
model_mode=MODEL_MODE_TRAIN,
@@ -466,16 +484,28 @@ def __call__(
466484
f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}."
467485
)
468486

469-
bidirectional_mask = None
487+
bidirectional_mask_image = None
488+
bidirectional_mask_video = None
470489
image_embeddings = None
490+
video_embeddings = None
491+
audio_embeddings = None
471492
deepstack_visual_embeds = None
472493
if self.config.use_multimodal and encoder_images is not None:
473494
image_embeddings, deepstack_visual_embeds = self.vision_encoder(
474495
input_images=encoder_images, deterministic=not enable_dropout
475496
)
476-
bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
497+
bidirectional_mask_image = mm_processor.get_bidirectional_mask_vision(
498+
self.config, decoder_input_tokens, is_video=False
499+
)
500+
501+
if self.config.use_multimodal and encoder_videos is not None:
502+
video_embeddings, deepstack_visual_embeds = self.vision_encoder(
503+
input_images=encoder_videos, deterministic=not enable_dropout
504+
)
505+
bidirectional_mask_video = mm_processor.get_bidirectional_mask_vision(
506+
self.config, decoder_input_tokens, is_video=True
507+
)
477508

478-
audio_embeddings = None
479509
if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
480510
audio_embeddings = self.audio_encoder(input_audio=encoder_audios, deterministic=not enable_dropout)
481511

@@ -485,13 +515,16 @@ def __call__(
485515
audio_masks = mm_processor.get_bidirectional_mask_audio(self.config, decoder_input_tokens)
486516

487517
multimodal_input = None
488-
if image_embeddings is not None or audio_embeddings is not None:
518+
if image_embeddings is not None or video_embeddings is not None or audio_embeddings is not None:
489519
multimodal_input = MultimodalInput(
490520
image_embeddings=image_embeddings,
491521
image_masks=encoder_image_masks,
522+
video_embeddings=video_embeddings,
523+
video_masks=encoder_video_masks,
492524
audio_embeddings=audio_embeddings,
493525
audio_masks=audio_masks,
494-
bidirectional_mask=bidirectional_mask,
526+
bidirectional_mask=bidirectional_mask_image,
527+
bidirectional_mask_video=bidirectional_mask_video,
495528
)
496529

497530
mutable_collections = []

src/maxtext/multimodal/processor.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def get_dummy_audio_shape_for_init(config):
207207
return audio_shape
208208

209209

210-
def get_bidirectional_mask_vision(config, decoder_input_tokens):
210+
def get_bidirectional_mask_vision(config, decoder_input_tokens, is_video: bool = False):
211211
"""Get the bidirectional mask for specific models."""
212212
bidirectional_mask_vision = None
213213
if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
@@ -225,11 +225,10 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens):
225225
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
226226
from maxtext.multimodal.processor_qwen3_omni import QWEN3_OMNI_IMAGE_TOKEN, QWEN3_OMNI_VIDEO_TOKEN # pylint: disable=import-outside-toplevel
227227

228-
# Create bidirectional_mask for vision/video token merging
229-
bidirectional_mask_vision = (decoder_input_tokens == QWEN3_OMNI_IMAGE_TOKEN) | (
230-
decoder_input_tokens == QWEN3_OMNI_VIDEO_TOKEN
231-
)
232-
# Create image/video mask for deepstack visual embedding injection
228+
if is_video:
229+
bidirectional_mask_vision = decoder_input_tokens == QWEN3_OMNI_VIDEO_TOKEN
230+
else:
231+
bidirectional_mask_vision = decoder_input_tokens == QWEN3_OMNI_IMAGE_TOKEN
233232
return bidirectional_mask_vision
234233

235234

src/maxtext/multimodal/processor_qwen3_omni.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,12 @@ def _np_extract_fbank_features(waveform_batch: np.ndarray) -> np.ndarray:
472472

473473
def pre_process_audio_qwen3_omni(audio_array):
474474
"""Preprocess audio for Qwen3-Omni model."""
475+
chunk_samples = 16000 # hop_length (160) * chunk_size (100)
476+
remainder = len(audio_array) % chunk_samples
477+
if remainder > 0:
478+
padding_size = chunk_samples - remainder
479+
audio_array = np.pad(audio_array, (0, padding_size), mode="constant")
480+
475481
audio_features = np.expand_dims(audio_array, axis=0) # Add batch dimension
476482
audio_features = _np_extract_fbank_features(audio_features)
477483
audio_features_mask = np.ones((audio_features.shape[0], audio_features.shape[2]), dtype=np.int32)
@@ -532,7 +538,17 @@ def preprocess_mm_data_qwen3_omni(config):
532538
if config.video_path:
533539
video_array, _ = _read_video_decord(config.video_path)
534540
video_processed, video_grid_thw = preprocess_video(video_array, config)
535-
processor_outputs.video_values = video_processed
541+
video_values = np.reshape(
542+
video_processed,
543+
(
544+
1,
545+
config.num_channels_for_vit,
546+
config.temporal_patch_size_for_vit * video_grid_thw[0, 0],
547+
config.patch_size_for_vit * video_grid_thw[0, 1],
548+
config.patch_size_for_vit * video_grid_thw[0, 2],
549+
),
550+
)
551+
processor_outputs.video_values = video_values
536552
processor_outputs.video_grid_thw = video_grid_thw
537553
processor_outputs.video_second_per_grid = np.asarray([config.temporal_patch_size_for_vit], dtype=np.float32)
538554
processor_outputs.num_videos = 1 # Only one video for now.
@@ -1143,6 +1159,9 @@ def get_mm_offsets_qwen3_omni(config, processor_output):
11431159
if processor_output.audio_lengths is not None:
11441160
audio_lengths = processor_output.audio_lengths
11451161
for audio_len in audio_lengths:
1146-
total_offset += int(audio_len) - 1 # -1 for the original <|audio_pad|> token
1162+
if getattr(config, "use_audio_in_video", False):
1163+
total_offset += int(audio_len) + 2 # +2 for <|audio_start|> and <|audio_end|>, no <|audio_pad|> to remove
1164+
else:
1165+
total_offset += int(audio_len) - 1 # -1 for the original <|audio_pad|> token
11471166

11481167
return total_offset

tests/unit/qwen3_omni_layers_test.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,9 @@ def test_vision_encoder_single_image(self):
551551
grid_thw = np.array([[1, h, w]], dtype=np.int64)
552552
grid_thw_torch = torch.from_numpy(grid_thw)
553553

554-
torch_output, torch_deep_feats = torch_encoder(torch_hidden_states, grid_thw_torch)
554+
torch_encoder_output = torch_encoder(torch_hidden_states, grid_thw_torch)
555+
torch_output = torch_encoder_output.pooler_output
556+
torch_deep_feats = torch_encoder_output.deepstack_features
555557
jax_encoder_output, jax_deep_feats = jax_encoder(jax_hidden_states)
556558
jax_output = jax_projector(jax_encoder_output)
557559

@@ -561,8 +563,8 @@ def test_vision_encoder_single_image(self):
561563
assert_all_close_jax_torch(
562564
jax_output,
563565
torch_output,
564-
rtol=1e-2,
565-
atol=1e-2,
566+
rtol=1.5e-2,
567+
atol=1.5e-2,
566568
error_msg="Vision encoder final output differs",
567569
)
568570

@@ -576,8 +578,8 @@ def test_vision_encoder_single_image(self):
576578
assert_all_close_jax_torch(
577579
jax_feat,
578580
torch_feat,
579-
rtol=1e-2,
580-
atol=1e-2,
581+
rtol=1.5e-2,
582+
atol=1.5e-2,
581583
error_msg=f"Deep feature {i} differs",
582584
)
583585

@@ -722,6 +724,16 @@ def test_preprocess_mm_data(self):
722724
USE_AUDIO_IN_VIDEO = True
723725
hf_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
724726
audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
727+
if audios is not None:
728+
padded_audios = []
729+
for audio in audios:
730+
chunk_samples = 16000
731+
remainder = len(audio) % chunk_samples
732+
if remainder > 0:
733+
padding_size = chunk_samples - remainder
734+
audio = np.pad(audio, (0, padding_size), mode="constant")
735+
padded_audios.append(audio)
736+
audios = padded_audios
725737
hf_processor_outputs = processor(
726738
text=hf_prompt,
727739
audio=audios,
@@ -749,9 +761,11 @@ def test_preprocess_mm_data(self):
749761
rtol=1e-2,
750762
atol=1e-2,
751763
)
764+
hf_pixel_values_videos = np.array(hf_processor_outputs["pixel_values_videos"]).astype(np.float32)
765+
mt_video_values = np.array(mt_processor_outputs.video_values).reshape(hf_pixel_values_videos.shape)
752766
assert np.allclose(
753-
mt_processor_outputs.video_values,
754-
np.array(hf_processor_outputs["pixel_values_videos"]).astype(np.float32),
767+
mt_video_values,
768+
hf_pixel_values_videos,
755769
rtol=5e-2,
756770
atol=5e-2,
757771
)

0 commit comments

Comments
 (0)