|
26 | 26 | import jax.numpy as jnp |
27 | 27 | from jax.sharding import Mesh |
28 | 28 | from MaxText import common_types |
29 | | -from MaxText import maxengine |
30 | 29 | from MaxText import pyconfig |
31 | 30 | from MaxText.globals import MAXTEXT_REPO_ROOT |
| 31 | +from maxtext.inference.maxengine import maxengine |
32 | 32 | from maxtext.layers.attentions import Attention |
33 | 33 | from maxtext.layers.embeddings import ( |
34 | 34 | PositionalEmbedding, |
35 | 35 | Qwen3OmniMoeVisionPosEmbedInterpolate as JaxQwen3OmniMoeVisionPosEmbedInterpolate, |
36 | 36 | Qwen3OmniMoeVisionRotaryEmbedding as JaxQwen3OmniMoeVisionRotaryEmbedding, |
37 | 37 | ) |
| 38 | +from maxtext.layers.decoders import deepstack_process |
38 | 39 | from maxtext.layers.encoders import AudioEncoder |
39 | 40 | from maxtext.models.qwen3 import ( |
40 | 41 | Qwen3OmniAudioEncoder, |
@@ -579,6 +580,86 @@ def test_vision_encoder_single_image(self): |
579 | 580 | ) |
580 | 581 |
|
581 | 582 |
|
class TestDeepstackProcess(unittest.TestCase):
  """Tests for deepstack_process.

  deepstack_process adds deepstack visual embeddings into decoder hidden
  states at the positions where the bidirectional mask is True (the visual
  token positions), leaving every other position untouched.
  """

  def test_adds_only_at_visual_positions(self):
    """Visual embeddings should be added at True mask positions and nowhere else."""
    batch_size, seq, dim = 2, 8, 4
    visual_positions = [1, 3, 5]
    text_positions = [0, 2, 4, 6, 7]

    # Same three visual slots for every batch row (3 visual tokens each).
    mask_np = np.zeros((batch_size, seq), dtype=bool)
    mask_np[:, visual_positions] = True

    result = deepstack_process(
        jnp.zeros((batch_size, seq, dim)),
        jnp.array(mask_np),
        jnp.ones((batch_size, len(visual_positions), dim)),
    )

    for b in range(batch_size):
      for pos in visual_positions:
        np.testing.assert_allclose(np.array(result[b, pos]), np.ones(dim), err_msg=f"batch={b} pos={pos}")
      for pos in text_positions:
        np.testing.assert_allclose(np.array(result[b, pos]), np.zeros(dim), err_msg=f"batch={b} pos={pos}")

  def test_visual_tokens_mapped_in_order(self):
    """Each visual embed should be added to the corresponding visual position in cumsum order."""
    hidden = jnp.zeros((1, 6, 2))
    visual_mask = jnp.array([[False, True, False, True, False, False]])
    # Three distinct embeds; only the first two visual slots are consumed.
    embeds = jnp.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])

    out = deepstack_process(hidden, visual_mask, embeds)

    # 1st visual slot receives visual_embeds[0, 0], 2nd receives [0, 1].
    np.testing.assert_allclose(np.array(out[0, 1]), [1.0, 2.0])
    np.testing.assert_allclose(np.array(out[0, 3]), [3.0, 4.0])
    # Everything outside the mask stays zero.
    for untouched in (0, 2, 4, 5):
      np.testing.assert_allclose(np.array(out[0, untouched]), [0.0, 0.0])

  def test_matches_reference_scatter(self):
    """Output must match a reference numpy loop that scatters visual embeds by position."""
    batch_size, seq, dim, n_visual = 2, 10, 8, 4
    np.random.seed(0)

    hidden_np = np.random.randn(batch_size, seq, dim).astype(np.float32)
    mask_np = np.zeros((batch_size, seq), dtype=bool)
    mask_np[:, [1, 3, 5, 7]] = True  # 4 visual tokens per batch item
    visual_np = np.random.randn(batch_size, n_visual, dim).astype(np.float32)

    # Reference implementation: walk the sequence, scattering embeds in order.
    expected = hidden_np.copy()
    for b in range(batch_size):
      next_embed = 0
      for s in range(seq):
        if mask_np[b, s]:
          expected[b, s] += visual_np[b, next_embed]
          next_embed += 1

    actual = deepstack_process(jnp.array(hidden_np), jnp.array(mask_np), jnp.array(visual_np))
    np.testing.assert_allclose(np.array(actual), expected, rtol=1e-5, atol=1e-5)

  def test_hidden_states_unchanged_without_visual_tokens(self):
    """When mask is all-False, hidden states should be returned unchanged."""
    np.random.seed(1)
    hidden_np = np.random.randn(2, 6, 4).astype(np.float32)

    out = deepstack_process(
        jnp.array(hidden_np),
        jnp.zeros((2, 6), dtype=bool),
        jnp.ones((2, 1, 4)),
    )
    np.testing.assert_allclose(np.array(out), hidden_np, rtol=1e-6, atol=1e-6)
| 662 | + |
582 | 663 | class TestQwen3OmniPreprocessing(unittest.TestCase): |
583 | 664 | """Test MaxText Qwen3 Omni preprocessor against HuggingFace reference.""" |
584 | 665 |
|
|
0 commit comments