Skip to content

Commit 973d8e1

Browse files
Merge pull request #3214 from AI-Hypercomputer:qwen-deepstack
PiperOrigin-RevId: 874794710
2 parents f1fc688 + 7da6a17 commit 973d8e1

6 files changed

Lines changed: 134 additions & 5 deletions

File tree

src/maxtext/configs/models/qwen3-omni-30b-a3b.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ max_position_embeddings: 65536
3939

4040
# General Model Settings
4141
enable_dropout: False
42+
scan_layers: False # deepstack does not support scan_layers
4243

4344
# Vision Encoder Configuration
4445
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py

src/maxtext/configs/types.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1928,6 +1928,13 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
19281928
if self.steps == -1:
19291929
self.steps = self.learning_rate_schedule_steps
19301930

1931+
# Validate deepstack + scan_layers incompatibility
1932+
if self.deepstack_visual_indexes_for_vit and self.scan_layers:
1933+
raise ValueError(
1934+
"Deepstack visual embedding injection requires scan_layers=False. "
1935+
"Set scan_layers=False in your config to use deepstack features."
1936+
)
1937+
19311938
# Validate WSD learning rate schedule fractions
19321939
if self.lr_schedule_type == LearningRateScheduleType.WSD:
19331940
total_fraction = self.warmup_steps_fraction + self.wsd_decay_steps_fraction

src/maxtext/layers/decoders.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -265,6 +265,31 @@ def __call__(
265265
return inputs
266266

267267

268+
def deepstack_process(hidden_states, bidirectional_mask, visual_embeds):
269+
"""Process deepstack visual embeddings by adding them to hidden states at visual token positions.
270+
271+
Args:
272+
hidden_states: [batch, seq_len, hidden_dim] decoder hidden states
273+
bidirectional_mask: [batch, seq_len] boolean mask marking visual token positions
274+
visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features from encoder layer
275+
276+
Returns:
277+
Updated hidden_states with visual features added at visual positions
278+
"""
279+
# Expand mask to [batch, seq_len, 1] for broadcasting
280+
mask_expanded = bidirectional_mask[:, :, jnp.newaxis]
281+
# Use cumsum to map each True position in mask to its index in visual_embeds
282+
visual_token_idx = jnp.cumsum(bidirectional_mask, axis=1) - 1 # [batch, seq_len], 0-indexed
283+
284+
# Gather visual tokens: for each position, get the corresponding visual token
285+
batch_idx = jnp.arange(hidden_states.shape[0])[:, jnp.newaxis] # [batch, 1]
286+
visual_embeds_scattered = visual_embeds[batch_idx, visual_token_idx, :] # [batch, seq_len, hidden]
287+
288+
# Only add where mask is True: hidden_states += visual_embeds * mask
289+
hidden_states = hidden_states + visual_embeds_scattered * mask_expanded
290+
return hidden_states
291+
292+
268293
class Decoder(nn.Module):
269294
"""A stack of decoder layers as a part of an encoder-decoder architecture."""
270295

@@ -722,6 +747,7 @@ def __call__(
722747
attention_metadata=None,
723748
audio_embeddings: None | jnp.ndarray = None,
724749
audio_masks: None | jnp.ndarray = None,
750+
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
725751
):
726752
cfg = self.config
727753
mesh = self.mesh
@@ -939,6 +965,12 @@ def __call__(
939965
if kv_caches is not None and kv_cache is not None:
940966
kv_caches[lyr] = kv_cache
941967

968+
if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds):
969+
visual_embeds = deepstack_visual_embeds[lyr]
970+
# Use bidirectional_mask to identify visual token positions
971+
if bidirectional_mask is not None and visual_embeds is not None:
972+
y = deepstack_process(y, bidirectional_mask, visual_embeds)
973+
942974
assert isinstance(y, jax.Array)
943975

944976
# After the final transformer layer, `y` holds the raw, un-normalized hidden state.

src/maxtext/layers/encoders.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,6 @@ def __call__(self, input_images, deterministic=False):
6565
# vision encoder output, frozen params in many cases
6666
encoder = getattr(self, self.encoder_name)
6767
encoder_output = encoder(input_images, deterministic=deterministic)
68-
6968
deep_feats = None
7069
if isinstance(encoder_output, tuple):
7170
embeddings = encoder_output[0]
@@ -75,6 +74,8 @@ def __call__(self, input_images, deterministic=False):
7574

7675
if self.config.freeze_vision_encoder_params:
7776
embeddings = jax.lax.stop_gradient(embeddings)
77+
if deep_feats is not None:
78+
deep_feats = [jax.lax.stop_gradient(feat) for feat in deep_feats]
7879

7980
# vision embedder / projection layer, not frozen in most cases, trained / finetuned together with main model
8081
projector = getattr(self, self.projector_name)

src/maxtext/models/models.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -151,10 +151,12 @@ def __call__(
151151
bidirectional_mask = None
152152
image_embeddings = None
153153
audio_embeddings = None
154+
deepstack_visual_embeds = None
154155

155156
if self.config.use_multimodal and encoder_images is not None:
156-
# qwen3-omni-30b-a3b returns deep features from the vision encoder.
157-
image_embeddings, _ = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
157+
image_embeddings, deepstack_visual_embeds = self.vision_encoder(
158+
input_images=encoder_images, deterministic=not enable_dropout
159+
)
158160
bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
159161

160162
if self.config.use_multimodal and encoder_audios is not None and self.audio_encoder is not None:
@@ -182,6 +184,7 @@ def __call__(
182184
audio_masks=audio_masks,
183185
kv_caches=kv_caches,
184186
attention_metadata=attention_metadata,
187+
deepstack_visual_embeds=deepstack_visual_embeds,
185188
)
186189

187190
# If we are initializing the model AND MTP is enabled, we must create
@@ -458,8 +461,11 @@ def __call__(
458461

459462
bidirectional_mask = None
460463
image_embeddings = None
464+
deepstack_visual_embeds = None
461465
if self.config.use_multimodal and encoder_images is not None:
462-
image_embeddings, _ = self.vision_encoder(input_images=encoder_images, deterministic=not enable_dropout)
466+
image_embeddings, deepstack_visual_embeds = self.vision_encoder(
467+
input_images=encoder_images, deterministic=not enable_dropout
468+
)
463469
bidirectional_mask = mm_processor.get_bidirectional_mask_vision(self.config, decoder_input_tokens)
464470

465471
audio_embeddings = None
@@ -488,6 +494,7 @@ def __call__(
488494
audio_masks=audio_masks,
489495
kv_caches=kv_caches,
490496
attention_metadata=attention_metadata,
497+
deepstack_visual_embeds=deepstack_visual_embeds,
491498
)
492499

493500
# Materialize hidden state when vocab tiling is enabled

tests/unit/qwen3_omni_layers_test.py

Lines changed: 82 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,15 +26,16 @@
2626
import jax.numpy as jnp
2727
from jax.sharding import Mesh
2828
from MaxText import common_types
29-
from MaxText import maxengine
3029
from MaxText import pyconfig
3130
from MaxText.globals import MAXTEXT_REPO_ROOT
31+
from maxtext.inference.maxengine import maxengine
3232
from maxtext.layers.attentions import Attention
3333
from maxtext.layers.embeddings import (
3434
PositionalEmbedding,
3535
Qwen3OmniMoeVisionPosEmbedInterpolate as JaxQwen3OmniMoeVisionPosEmbedInterpolate,
3636
Qwen3OmniMoeVisionRotaryEmbedding as JaxQwen3OmniMoeVisionRotaryEmbedding,
3737
)
38+
from maxtext.layers.decoders import deepstack_process
3839
from maxtext.layers.encoders import AudioEncoder
3940
from maxtext.models.qwen3 import (
4041
Qwen3OmniAudioEncoder,
@@ -579,6 +580,86 @@ def test_vision_encoder_single_image(self):
579580
)
580581

581582

583+
class TestDeepstackProcess(unittest.TestCase):
584+
"""Tests for deepstack_process.
585+
586+
Adds deepstack visual embeddings into decoder hidden states at the
587+
positions indicated by the bidirectional mask (visual token positions).
588+
"""
589+
590+
def test_adds_only_at_visual_positions(self):
591+
"""Visual embeddings should be added at True mask positions and nowhere else."""
592+
batch, seq_len, hidden_dim = 2, 8, 4
593+
hidden_states = jnp.zeros((batch, seq_len, hidden_dim))
594+
# positions 1, 3, 5 are visual for both batch items (3 visual tokens each)
595+
mask = jnp.array(
596+
[
597+
[False, True, False, True, False, True, False, False],
598+
[False, True, False, True, False, True, False, False],
599+
]
600+
)
601+
visual_embeds = jnp.ones((batch, 3, hidden_dim))
602+
603+
result = deepstack_process(hidden_states, mask, visual_embeds)
604+
605+
for b in range(batch):
606+
for pos in [1, 3, 5]:
607+
np.testing.assert_allclose(np.array(result[b, pos]), np.ones(hidden_dim), err_msg=f"batch={b} pos={pos}")
608+
for pos in [0, 2, 4, 6, 7]:
609+
np.testing.assert_allclose(np.array(result[b, pos]), np.zeros(hidden_dim), err_msg=f"batch={b} pos={pos}")
610+
611+
def test_visual_tokens_mapped_in_order(self):
612+
"""Each visual embed should be added to the corresponding visual position in cumsum order."""
613+
batch, seq_len, hidden_dim = 1, 6, 2
614+
hidden_states = jnp.zeros((batch, seq_len, hidden_dim))
615+
mask = jnp.array([[False, True, False, True, False, False]])
616+
# two distinct visual tokens, a third token that won't be used
617+
visual_embeds = jnp.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
618+
619+
result = deepstack_process(hidden_states, mask, visual_embeds)
620+
621+
# 1st visual position → visual_embeds[0, 0]
622+
np.testing.assert_allclose(np.array(result[0, 1]), [1.0, 2.0])
623+
# 2nd visual position → visual_embeds[0, 1]
624+
np.testing.assert_allclose(np.array(result[0, 3]), [3.0, 4.0])
625+
# non-visual positions untouched
626+
for pos in [0, 2, 4, 5]:
627+
np.testing.assert_allclose(np.array(result[0, pos]), [0.0, 0.0])
628+
629+
def test_matches_reference_scatter(self):
630+
"""Output must match a reference numpy loop that scatters visual embeds by position."""
631+
batch, seq_len, hidden_dim, num_visual = 2, 10, 8, 4
632+
np.random.seed(0)
633+
634+
hidden_np = np.random.randn(batch, seq_len, hidden_dim).astype(np.float32)
635+
mask_np = np.zeros((batch, seq_len), dtype=bool)
636+
mask_np[:, [1, 3, 5, 7]] = True # 4 visual tokens per batch item
637+
visual_np = np.random.randn(batch, num_visual, hidden_dim).astype(np.float32)
638+
639+
# Reference: per-batch scatter
640+
expected = hidden_np.copy()
641+
for b in range(batch):
642+
vi = 0
643+
for s in range(seq_len):
644+
if mask_np[b, s]:
645+
expected[b, s] += visual_np[b, vi]
646+
vi += 1
647+
648+
result = deepstack_process(jnp.array(hidden_np), jnp.array(mask_np), jnp.array(visual_np))
649+
np.testing.assert_allclose(np.array(result), expected, rtol=1e-5, atol=1e-5)
650+
651+
def test_hidden_states_unchanged_without_visual_tokens(self):
652+
"""When mask is all-False, hidden states should be returned unchanged."""
653+
batch, seq_len, hidden_dim = 2, 6, 4
654+
np.random.seed(1)
655+
hidden_np = np.random.randn(batch, seq_len, hidden_dim).astype(np.float32)
656+
mask = jnp.zeros((batch, seq_len), dtype=bool)
657+
visual_embeds = jnp.ones((batch, 1, hidden_dim))
658+
659+
result = deepstack_process(jnp.array(hidden_np), mask, visual_embeds)
660+
np.testing.assert_allclose(np.array(result), hidden_np, rtol=1e-6, atol=1e-6)
661+
662+
582663
class TestQwen3OmniPreprocessing(unittest.TestCase):
583664
"""Test MaxText Qwen3 Omni preprocessor against HuggingFace reference."""
584665

0 commit comments

Comments (0)