1515"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
1616
1717import unittest
18- import pytest
19-
18+ from absl .testing import parameterized
2019from flax import nnx
2120from flax .linen import partitioning as nn_partitioning
2221import jax
2322import jax .numpy as jnp
2423from jax .sharding import Mesh
2524import numpy as np
25+ import pytest
2626
2727from maxtext .configs import pyconfig
2828from maxtext .common .common_types import HyperConnectionType
@@ -86,14 +86,15 @@ def test_doubly_stochastic_property(self):
8686 np .testing .assert_allclose (col_sums , jnp .ones_like (col_sums ), atol = 1e-3 )
8787
8888
89- class TestMHC (unittest .TestCase ):
89+ class TestMHC (parameterized .TestCase ):
9090 """Test for MHC module"""
9191
92- def setUp (self ):
92+ def _setup_mhc (self , rate , enable_mhc_lite = False ):
93+ """Sets up the common configurations and modules for MHC testing."""
9394 self .dim = 16
9495 self .config = pyconfig .initialize (
9596 [None , get_test_config_path ()],
96- run_name = "test_mhc " ,
97+ run_name = f"test_mhc_k { rate } " ,
9798 enable_checkpointing = False ,
9899 model_name = "deepseek-custom" ,
99100 per_device_batch_size = jax .device_count (),
@@ -105,7 +106,8 @@ def setUp(self):
105106 # override
106107 override_model_config = True ,
107108 base_emb_dim = self .dim ,
108- mhc_expansion_rate = 3 ,
109+ mhc_expansion_rate = rate ,
110+ enable_mhc_lite = enable_mhc_lite ,
109111 num_experts = 4 ,
110112 num_experts_per_tok = 2 ,
111113 engram_layers = [],
@@ -135,7 +137,10 @@ def setUp(self):
135137
136138 # Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
137139 @pytest .mark .tpu_only
138- def test_moe_layer_output_shape (self ):
140+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
141+ def test_moe_layer_output_shape (self , rate ):
142+ self ._setup_mhc (rate )
143+
139144 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
140145 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
141146 layer = moe .RoutedMoE (
@@ -154,12 +159,14 @@ def test_moe_layer_output_shape(self):
154159 b , s , k , d = self .x .shape
155160 output , metadata = module (self .pre_norm , layer , x = self .x , mhc_type = HyperConnectionType .MLP_MOE )
156161 # metadata includes load_balance_loss & moe_bias_updates
157- self .assertEqual ( len ( metadata ) , 2 )
162+ self .assertLen ( metadata , 2 )
158163 for key , value in metadata .items ():
159164 self .assertIsNotNone (value , f"Key '{ key } ' has a value of None" )
160165 self .assertEqual (output .shape , (b , s , k , d ))
161166
162- def test_dense_layer_output_shape (self ):
167+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
168+ def test_dense_layer_output_shape (self , rate ):
169+ self ._setup_mhc (rate )
163170 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
164171 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
165172 layer = linears .MlpBlock (
@@ -180,8 +187,14 @@ def test_dense_layer_output_shape(self):
180187 self .assertDictEqual (metadata , {})
181188 self .assertEqual (output .shape , (b , s , k , d ))
182189
183- def test_attention_layer_output_shape (self ):
184- inputs_shape = (self .config .per_device_batch_size , self .config .max_target_length , self .config .emb_dim )
190+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
191+ def test_attention_layer_output_shape (self , rate ):
192+ self ._setup_mhc (rate )
193+ inputs_shape = (
194+ self .config .per_device_batch_size ,
195+ self .config .max_target_length ,
196+ self .config .emb_dim ,
197+ )
185198 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
186199 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
187200 layer = attention_mla .MLA (
@@ -219,6 +232,65 @@ def test_attention_layer_output_shape(self):
219232 self .assertDictEqual (metadata , {})
220233 self .assertEqual (output .shape , (b , s , k , d ))
221234
235+ def test_mhc_lite_doubly_stochastic (self ):
236+ """Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1)."""
237+ self ._setup_mhc (4 , enable_mhc_lite = True )
238+ with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
239+ module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
240+
241+ b , s , k , d = self .x .shape
242+
243+ # Generate random input X
244+ random_x = jax .random .normal (jax .random .PRNGKey (42 ), (b , s , k * d ))
245+ norm_x = module .mhc_norm (random_x )
246+
247+ # Output from mHC-lite mapping
248+ res_mapping_out = module .res_mapping (norm_x )
249+
250+ row_sums = jnp .sum (res_mapping_out , axis = - 1 )
251+ col_sums = jnp .sum (res_mapping_out , axis = - 2 )
252+
253+ # Check if sums are close to 1.0
254+ np .testing .assert_allclose (row_sums , jnp .ones_like (row_sums ), atol = 1e-2 )
255+ np .testing .assert_allclose (col_sums , jnp .ones_like (col_sums ), atol = 1e-2 )
256+
257+ def test_feature_flag_gates_lite (self ):
258+ """Verify that setting enable_mhc_lite=False falls back to Sinkhorn."""
259+ self .dim = 16
260+ self .config = pyconfig .initialize (
261+ [None , get_test_config_path ()],
262+ run_name = "test_mhc_lite_gated" ,
263+ enable_checkpointing = False ,
264+ model_name = "deepseek-custom" ,
265+ per_device_batch_size = 4 ,
266+ max_target_length = 7 ,
267+ max_prefill_predict_length = 7 ,
268+ attention = "dot_product" ,
269+ routed_bias_update_rate = 0.01 ,
270+ load_balance_loss_weight = 0.02 ,
271+ # override
272+ override_model_config = True ,
273+ base_emb_dim = self .dim ,
274+ mhc_expansion_rate = 4 ,
275+ enable_mhc_lite = False ,
276+ num_experts = 4 ,
277+ num_experts_per_tok = 2 ,
278+ engram_layers = [],
279+ )
280+ devices_array = maxtext_utils .create_device_mesh (self .config )
281+ self .mesh = Mesh (devices_array , self .config .mesh_axes )
282+ self .rngs = nnx .Rngs (params = jax .random .key (0 ), dropout = jax .random .key (42 ))
283+
284+ with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
285+ module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
286+
287+ # Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24)
288+ self .assertEqual (module .res_alpha .shape , (64 , 16 ))
289+ # Shape of res_beta should be (4, 4) instead of (24,)
290+ self .assertEqual (module .res_beta .shape , (4 , 4 ))
291+ # Permutation matrices shouldn't be defined
292+ self .assertFalse (hasattr (module , "permutation_matrices" ))
293+
222294
223295if __name__ == "__main__" :
224296 unittest .main ()
0 commit comments