1515"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
1616
1717import unittest
18- import pytest
19-
18+ from absl .testing import parameterized
2019from flax import nnx
2120from flax .linen import partitioning as nn_partitioning
2221import jax
2322import jax .numpy as jnp
2423from jax .sharding import Mesh
25- import numpy as np
26-
27- from maxtext .configs import pyconfig
2824from maxtext .common .common_types import HyperConnectionType
25+ from maxtext .configs import pyconfig
2926from maxtext .layers import attention_mla , linears , mhc , moe
3027from maxtext .layers .initializers import nd_dense_init
3128from maxtext .layers .normalizations import RMSNorm
3229from maxtext .utils import maxtext_utils
30+ import numpy as np
31+ import pytest
3332from tests .utils .test_helpers import get_test_config_path
3433
3534
@@ -86,14 +85,15 @@ def test_doubly_stochastic_property(self):
8685 np .testing .assert_allclose (col_sums , jnp .ones_like (col_sums ), atol = 1e-3 )
8786
8887
89- class TestMHC (unittest .TestCase ):
88+ class TestMHC (parameterized .TestCase ):
9089 """Test for MHC module"""
9190
92- def setUp (self ):
91+ def _setup_mhc (self , rate , enable_mhc_lite = False ):
92+ """Sets up the common configurations and modules for MHC testing."""
9393 self .dim = 16
9494 self .config = pyconfig .initialize (
9595 [None , get_test_config_path ()],
96- run_name = "test_mhc " ,
96+ run_name = f"test_mhc_k { rate } " ,
9797 enable_checkpointing = False ,
9898 model_name = "deepseek-custom" ,
9999 per_device_batch_size = jax .device_count (),
@@ -105,7 +105,8 @@ def setUp(self):
105105 # override
106106 override_model_config = True ,
107107 base_emb_dim = self .dim ,
108- mhc_expansion_rate = 3 ,
108+ mhc_expansion_rate = rate ,
109+ enable_mhc_lite = enable_mhc_lite ,
109110 num_experts = 4 ,
110111 num_experts_per_tok = 2 ,
111112 engram_layers = [],
@@ -135,7 +136,10 @@ def setUp(self):
135136
136137 # Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
137138 @pytest .mark .tpu_only
138- def test_moe_layer_output_shape (self ):
139+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
140+ def test_moe_layer_output_shape (self , rate ):
141+ self ._setup_mhc (rate )
142+
139143 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
140144 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
141145 layer = moe .RoutedMoE (
@@ -154,12 +158,14 @@ def test_moe_layer_output_shape(self):
154158 b , s , k , d = self .x .shape
155159 output , metadata = module (self .pre_norm , layer , x = self .x , mhc_type = HyperConnectionType .MLP_MOE )
156160 # metadata includes load_balance_loss & moe_bias_updates
157- self .assertEqual ( len ( metadata ) , 2 )
161+ self .assertLen ( metadata , 2 )
158162 for key , value in metadata .items ():
159163 self .assertIsNotNone (value , f"Key '{ key } ' has a value of None" )
160164 self .assertEqual (output .shape , (b , s , k , d ))
161165
162- def test_dense_layer_output_shape (self ):
166+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
167+ def test_dense_layer_output_shape (self , rate ):
168+ self ._setup_mhc (rate )
163169 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
164170 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
165171 layer = linears .MlpBlock (
@@ -180,8 +186,14 @@ def test_dense_layer_output_shape(self):
180186 self .assertDictEqual (metadata , {})
181187 self .assertEqual (output .shape , (b , s , k , d ))
182188
183- def test_attention_layer_output_shape (self ):
184- inputs_shape = (self .config .per_device_batch_size , self .config .max_target_length , self .config .emb_dim )
189+ @parameterized .named_parameters (("Rate3" , 3 ), ("Rate4" , 4 ))
190+ def test_attention_layer_output_shape (self , rate ):
191+ self ._setup_mhc (rate )
192+ inputs_shape = (
193+ self .config .per_device_batch_size ,
194+ self .config .max_target_length ,
195+ self .config .emb_dim ,
196+ )
185197 with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
186198 module = mhc .ManifoldConstrainedHyperConnections (self .config , self .dim , self .mesh , self .rngs )
187199 layer = attention_mla .MLA (
@@ -219,6 +231,69 @@ def test_attention_layer_output_shape(self):
219231 self .assertDictEqual (metadata , {})
220232 self .assertEqual (output .shape , (b , s , k , d ))
221233
234+ def test_mhc_lite_doubly_stochastic (self ):
235+ """Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1)."""
236+ self ._setup_mhc (4 , enable_mhc_lite = True )
237+ with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
238+ module = mhc .ManifoldConstrainedHyperConnections (
239+ self .config , self .dim , self .mesh , self .rngs
240+ )
241+
242+ b , s , k , d = self .x .shape
243+
244+ # Generate random input X
245+ random_x = jax .random .normal (jax .random .PRNGKey (42 ), (b , s , k * d ))
246+ norm_x = module .mhc_norm (random_x )
247+
248+ # Output from mHC-lite mapping
249+ res_mapping_out = module .res_mapping (norm_x )
250+
251+ row_sums = jnp .sum (res_mapping_out , axis = - 1 )
252+ col_sums = jnp .sum (res_mapping_out , axis = - 2 )
253+
254+ # Check if sums are close to 1.0
255+ np .testing .assert_allclose (row_sums , jnp .ones_like (row_sums ), atol = 1e-2 )
256+ np .testing .assert_allclose (col_sums , jnp .ones_like (col_sums ), atol = 1e-2 )
257+
258+ def test_feature_flag_gates_lite (self ):
259+ """Verify that setting enable_mhc_lite=False falls back to Sinkhorn."""
260+ self .dim = 16
261+ self .config = pyconfig .initialize (
262+ [None , get_test_config_path ()],
263+ run_name = "test_mhc_lite_gated" ,
264+ enable_checkpointing = False ,
265+ model_name = "deepseek-custom" ,
266+ per_device_batch_size = 4 ,
267+ max_target_length = 7 ,
268+ max_prefill_predict_length = 7 ,
269+ attention = "dot_product" ,
270+ routed_bias_update_rate = 0.01 ,
271+ load_balance_loss_weight = 0.02 ,
272+ # override
273+ override_model_config = True ,
274+ base_emb_dim = self .dim ,
275+ mhc_expansion_rate = 4 ,
276+ enable_mhc_lite = False ,
277+ num_experts = 4 ,
278+ num_experts_per_tok = 2 ,
279+ engram_layers = [],
280+ )
281+ devices_array = maxtext_utils .create_device_mesh (self .config )
282+ self .mesh = Mesh (devices_array , self .config .mesh_axes )
283+ self .rngs = nnx .Rngs (params = jax .random .key (0 ), dropout = jax .random .key (42 ))
284+
285+ with nn_partitioning .axis_rules (self .config .logical_axis_rules ):
286+ module = mhc .ManifoldConstrainedHyperConnections (
287+ self .config , self .dim , self .mesh , self .rngs
288+ )
289+
290+ # Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24)
291+ self .assertEqual (module .res_alpha .shape , (64 , 16 ))
292+ # Shape of res_beta should be (4, 4) instead of (24,)
293+ self .assertEqual (module .res_beta .shape , (4 , 4 ))
294+ # Permutation matrices shouldn't be defined
295+ self .assertFalse (hasattr (module , "permutation_matrices" ))
296+
222297
223298if __name__ == "__main__" :
224299 unittest .main ()
0 commit comments