1414
1515"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
1616
17- import unittest
18- import pytest
19-
17+ from absl .testing import absltest
18+ from absl .testing import parameterized
2019from flax import nnx
2120from flax .linen import partitioning as nn_partitioning
2221import jax
2322import jax .numpy as jnp
2423from jax .sharding import Mesh
2524import numpy as np
25+ import pytest
2626
2727from maxtext .configs import pyconfig
2828from maxtext .common .common_types import HyperConnectionType
3333from tests .utils .test_helpers import get_test_config_path
3434
3535
36- class TestExpandReduce (unittest .TestCase ):
36+ class TestExpandReduce (absltest .TestCase ):
3737 """Unit tests for MHC dimension expansion and reduction operations."""
3838
3939 def setUp (self ):
@@ -65,7 +65,7 @@ def test_value_identity(self):
6565 np .testing .assert_allclose (out , expected , rtol = 1e-5 )
6666
6767
68- class TestSinkhorn (unittest .TestCase ):
68+ class TestSinkhorn (absltest .TestCase ):
6969 """Unit tests for MHC Sinkhorn Algorithm."""
7070
7171 def setUp (self ):
@@ -86,17 +86,19 @@ def test_doubly_stochastic_property(self):
8686 np .testing .assert_allclose (col_sums , jnp .ones_like (col_sums ), atol = 1e-3 )
8787
8888
89- class TestMHC (unittest .TestCase ):
89+ class TestMHC (parameterized .TestCase ):
9090 """Test for MHC module"""
9191
92- def setUp (self ):
92+ def _setup_mhc (self , rate , enable_mhc_lite = False ):
93+ """Sets up the common configurations and modules for MHC testing."""
9394 self .dim = 16
9495 self .config = pyconfig .initialize (
9596 [None , get_test_config_path ()],
96- run_name = "test_mhc" ,
97+ skip_jax_distributed_system = True ,
98+ run_name = f"test_mhc_k{ rate } " ,
9799 enable_checkpointing = False ,
98100 model_name = "deepseek-custom" ,
99- per_device_batch_size = jax .device_count (),
101+ per_device_batch_size = max ( 4 , jax .device_count () ),
100102 max_target_length = 7 ,
101103 max_prefill_predict_length = 7 ,
102104 attention = "dot_product" ,
@@ -105,7 +107,8 @@ def setUp(self):
105107 # override
106108 override_model_config = True ,
107109 base_emb_dim = self .dim ,
108- mhc_expansion_rate = 3 ,
110+ mhc_expansion_rate = rate ,
111+ enable_mhc_lite = enable_mhc_lite ,
109112 num_experts = 4 ,
110113 num_experts_per_tok = 2 ,
111114 engram_layers = [],
@@ -135,7 +138,10 @@ def setUp(self):
135138
136139 # Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
137140 @pytest .mark .tpu_only
138- def test_moe_layer_output_shape (self ):
141+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
142+ def test_moe_layer_output_shape (self , rate ):
143+ self ._setup_mhc (rate )
144+
139145 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
140146 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
141147 layer = moe .RoutedMoE (
@@ -154,12 +160,14 @@ def test_moe_layer_output_shape(self):
154160 b , s , k , d = self .x .shape
155161 output , metadata = module (self .pre_norm , layer , x = self .x , mhc_type = HyperConnectionType .MLP_MOE )
156162 # metadata includes load_balance_loss & moe_bias_updates
157- self .assertEqual ( len ( metadata ) , 2 )
163+ self .assertLen ( metadata , 2 )
158164 for key , value in metadata .items ():
159165 self .assertIsNotNone (value , f"Key '{ key } ' has a value of None" )
160166 self .assertEqual (output .shape , (b , s , k , d ))
161167
162- def test_dense_layer_output_shape (self ):
168+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
169+ def test_dense_layer_output_shape (self , rate ):
170+ self ._setup_mhc (rate )
163171 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
164172 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
165173 layer = linears .MlpBlock (
@@ -180,8 +188,14 @@ def test_dense_layer_output_shape(self):
180188 self .assertDictEqual (metadata , {})
181189 self .assertEqual (output .shape , (b , s , k , d ))
182190
183- def test_attention_layer_output_shape (self ):
184- inputs_shape = (self .config .per_device_batch_size , self .config .max_target_length , self .config .emb_dim )
191+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
192+ def test_attention_layer_output_shape (self , rate ):
193+ self ._setup_mhc (rate )
194+ inputs_shape = (
195+ self .config .per_device_batch_size ,
196+ self .config .max_target_length ,
197+ self .config .emb_dim ,
198+ )
185199 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
186200 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
187201 layer = attention_mla .MLA (
@@ -219,6 +233,66 @@ def test_attention_layer_output_shape(self):
219233 self .assertDictEqual (metadata , {})
220234 self .assertEqual (output .shape , (b , s , k , d ))
221235
236+ def test_mhc_lite_doubly_stochastic (self ):
237+ """Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1)."""
238+ self ._setup_mhc (4 , enable_mhc_lite = True )
239+ with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
240+ module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
241+
242+ b , s , k , d = self .x .shape
243+
244+ # Generate random input X
245+ random_x = jax .random .normal (jax .random .PRNGKey (42 ), (b , s , k * d ))
246+ norm_x = module .mhc_norm (random_x )
247+
248+ # Output from mHC-lite mapping
249+ res_mapping_out = module .res_mapping (norm_x )
250+
251+ row_sums = jnp .sum (res_mapping_out , axis = - 1 )
252+ col_sums = jnp .sum (res_mapping_out , axis = - 2 )
253+
254+ # Check if sums are close to 1.0
255+ np .testing .assert_allclose (row_sums , jnp .ones_like (row_sums ), atol = 1e-2 )
256+ np .testing .assert_allclose (col_sums , jnp .ones_like (col_sums ), atol = 1e-2 )
257+
258+ def test_feature_flag_gates_lite (self ):
259+ """Verify that setting enable_mhc_lite=False falls back to Sinkhorn."""
260+ self .dim = 16
261+ self .config = pyconfig .initialize (
262+ [None , get_test_config_path ()],
263+ skip_jax_distributed_system = True ,
264+ run_name = "test_mhc_lite_gated" ,
265+ enable_checkpointing = False ,
266+ model_name = "deepseek-custom" ,
267+ per_device_batch_size = 4 ,
268+ max_target_length = 7 ,
269+ max_prefill_predict_length = 7 ,
270+ attention = "dot_product" ,
271+ routed_bias_update_rate = 0.01 ,
272+ load_balance_loss_weight = 0.02 ,
273+ # override
274+ override_model_config = True ,
275+ base_emb_dim = self .dim ,
276+ mhc_expansion_rate = 4 ,
277+ enable_mhc_lite = False ,
278+ num_experts = 4 ,
279+ num_experts_per_tok = 2 ,
280+ engram_layers = [],
281+ )
282+ devices_array = maxtext_utils .create_device_mesh (self .config )
283+ self .mesh = Mesh (devices_array , self .config .mesh_axes )
284+ self .rngs = nnx .Rngs (params = jax .random .key (0 ), dropout = jax .random .key (42 ))
285+
286+ with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
287+ module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
288+
289+ # Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24)
290+ self .assertEqual (module .res_alpha .shape , (64 , 16 ))
291+ # Shape of res_beta should be (4, 4) instead of (24,)
292+ self .assertEqual (module .res_beta .shape , (4 , 4 ))
293+ # Permutation matrices shouldn't be defined
294+ self .assertFalse (hasattr (module , "permutation_matrices" ))
295+
222296
223297if __name__ == "__main__" :
224- unittest .main ()
298+ absltest .main ()
0 commit comments