@@ -41,6 +41,7 @@ def __call__(self, x, model_mode):
4141 return jnp .ones ((x .shape [0 ], x .shape [1 ], self .emb_dim ))
4242
4343
44+ @pytest .mark .integration_test
4445class TestDeepSeekScanEngram (unittest .TestCase ):
4546 """Test DeepSeek decoder block with scan_layers=True and engram_layers."""
4647
@@ -53,16 +54,16 @@ class TestDeepSeekScanEngram(unittest.TestCase):
5354 "first_num_dense_layers=5" ,
5455 "base_num_decoder_layers=10" ,
5556 "num_decoder_layers=10" ,
56- "base_emb_dim=64 " ,
57- "base_mlp_dim=64 " ,
58- "base_moe_mlp_dim=64 " ,
59- "base_num_query_heads=2 " ,
60- "base_num_kv_heads=2 " ,
61- "head_dim=32 " ,
62- "indexer_head_dim=32 " ,
63- "qk_nope_head_dim=32 " ,
64- "qk_rope_head_dim=16 " ,
65- "v_head_dim=32 " ,
57+ "base_emb_dim=8 " ,
58+ "base_mlp_dim=8 " ,
59+ "base_moe_mlp_dim=8 " ,
60+ "base_num_query_heads=1 " ,
61+ "base_num_kv_heads=1 " ,
62+ "head_dim=4 " ,
63+ "indexer_head_dim=4 " ,
64+ "qk_nope_head_dim=4 " ,
65+ "qk_rope_head_dim=4 " ,
66+ "v_head_dim=4 " ,
6667 "vocab_size=128" ,
6768 "mhc_expansion_rate=4" ,
6869 "attention=dot_product" ,
@@ -71,15 +72,24 @@ class TestDeepSeekScanEngram(unittest.TestCase):
7172 "max_prefill_predict_length=8" ,
7273 "enable_checkpointing=False" ,
7374 "engram_num_heads=1" ,
74- "engram_head_dim=8 " ,
75+ "engram_head_dim=4 " ,
7576 "engram_vocab_bases=[128,128]" ,
7677 "engram_max_ngram_size=3" ,
7778 "engram_kernel_size=4" ,
79+ "num_experts=2" ,
80+ "num_experts_per_tok=1" ,
7881 "hf_access_token=dummy" ,
7982 "tokenizer_path=dummy" ,
8083 ]
8184
82- def _test_engram_pattern (self , mock_from_pretrained , engram_layers_str , expected_keys ):
85+ def _test_engram_pattern (
86+ self ,
87+ mock_from_pretrained ,
88+ engram_layers_str ,
89+ expected_keys ,
90+ first_num_dense_layers = 5 ,
91+ base_num_decoder_layers = 10 ,
92+ ):
8393 """Helper method to test different engram layer patterns."""
8494
8595 # Setup mock tokenizer
@@ -106,7 +116,16 @@ def batch_decode(self, token_ids, *args, **kwargs):
106116 mock_from_pretrained .return_value = MockTokenizer ()
107117
108118 config_path = os .path .join (MAXTEXT_PKG_DIR , "configs" , "base.yml" )
109- config = pyconfig .initialize ([None , config_path ] + self ._COMMON_CONFIG + [f"engram_layers=[{ engram_layers_str } ]" ])
119+ config = pyconfig .initialize (
120+ [None , config_path ]
121+ + self ._COMMON_CONFIG
122+ + [
123+ f"engram_layers=[{ engram_layers_str } ]" ,
124+ f"first_num_dense_layers={ first_num_dense_layers } " ,
125+ f"base_num_decoder_layers={ base_num_decoder_layers } " ,
126+ f"num_decoder_layers={ base_num_decoder_layers } " ,
127+ ]
128+ )
110129
111130 devices_array = maxtext_utils .create_device_mesh (config )
112131 mesh = Mesh (devices_array , config .mesh_axes )
@@ -126,7 +145,7 @@ def batch_decode(self, token_ids, *args, **kwargs):
126145
127146 shared_embedding = DummyEmbedding (emb_dim = config .emb_dim )
128147
129- with mesh :
148+ with mesh , jax . disable_jit () :
130149 variables = decoder .init (
131150 {"params" : jax .random .PRNGKey (0 ), "dropout" : jax .random .PRNGKey (1 ), "aqt" : jax .random .PRNGKey (2 )},
132151 shared_embedding = shared_embedding ,
@@ -154,15 +173,16 @@ def test_decoder_init_engram_2_8(self, mock_from_pretrained):
154173 """Test engram layers at indices 1 and 4 (3 dense + 2 MoE layers)."""
155174 self ._test_engram_pattern (
156175 mock_from_pretrained ,
157- "2,8 " ,
176+ "1,4 " ,
158177 [
159- "dense_layers_0_1" ,
160- "dense_layers_engram_2" ,
161- "dense_layers_3_4" ,
162- "moe_layers_5_7" ,
163- "moe_layers_engram_8" ,
164- "moe_layers_9_9" ,
178+ "dense_layers_0_0" ,
179+ "dense_layers_engram_1" ,
180+ "dense_layers_2_2" ,
181+ "moe_layers_3_3" ,
182+ "moe_layers_engram_4" ,
165183 ],
184+ first_num_dense_layers = 3 ,
185+ base_num_decoder_layers = 5 ,
166186 )
167187
168188 @pytest .mark .tpu_only
@@ -171,8 +191,10 @@ def test_decoder_init_engram_0_5(self, mock_from_pretrained):
171191 """Test engram layers at indices 0 and 1 - first engram layer of block."""
172192 self ._test_engram_pattern (
173193 mock_from_pretrained ,
174- "0,5" ,
175- ["dense_layers_engram_0" , "dense_layers_1_4" , "moe_layers_engram_5" , "moe_layers_6_9" ],
194+ "0,1" ,
195+ ["dense_layers_engram_0" , "moe_layers_engram_1" ],
196+ first_num_dense_layers = 1 ,
197+ base_num_decoder_layers = 2 ,
176198 )
177199
178200 @pytest .mark .tpu_only
@@ -181,6 +203,8 @@ def test_decoder_init_engram_4_9(self, mock_from_pretrained):
181203 """Test engram layers at indices 1 and 3 - last engram layer of block."""
182204 self ._test_engram_pattern (
183205 mock_from_pretrained ,
184- "4,9" ,
185- ["dense_layers_0_3" , "dense_layers_engram_4" , "moe_layers_5_8" , "moe_layers_engram_9" ],
206+ "1,3" ,
207+ ["dense_layers_0_0" , "dense_layers_engram_1" , "moe_layers_2_2" , "moe_layers_engram_3" ],
208+ first_num_dense_layers = 2 ,
209+ base_num_decoder_layers = 4 ,
186210 )
0 commit comments