From 2c88ab1ceb3817b8ce6adc21225bce79dd603d22 Mon Sep 17 00:00:00 2001 From: dandragona Date: Tue, 12 May 2026 21:31:45 +0000 Subject: [PATCH] Optimize mHC for expansion rate 4 using convex combination of permutations and add enable_mhc_k4_shortcut feature gate --- src/maxtext/configs/base.yml | 9 ++++ src/maxtext/configs/types.py | 8 +++ src/maxtext/layers/mhc.py | 62 ++++++++++++++++++---- tests/unit/mhc_test.py | 99 +++++++++++++++++++++++++++++++----- 4 files changed, 154 insertions(+), 24 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 1bb0f59429..d79823b6ab 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1224,6 +1224,15 @@ force_q_layout: false mhc_expansion_rate: 1 # The number of iterations for the Sinkhorn-Knopp algorithm. sinkhorn_iterations: 20 +# mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif +# - Controls whether to generate the MHC doubly stochastic matrix via +# permutation-based convex combination rather than Sinkhorn-Knopp. +# - Practical only for a small mhc_expansion_rate (e.g., k=4), as the number +# of permutation matrices scales factorially (k!). +# - Ideally, this should be used whenever expansion rate is small as it removes +# the expensive sinkhorn iterations, the downside to this approach is that +# it is factorial in k. +enable_mhc_lite: False ################################## DeepSeek Engram ################################## # Indices of transformer layers where Engram are integrated; leave empty [] to disable. diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index f6a92bbb8a..86a5147fa2 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1407,6 +1407,14 @@ class ManifoldConstrainedHyperConnections(BaseModel): mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.") sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.") + enable_mhc_lite: bool = Field( + False, + description=( + "Whether to generate the MHC doubly stochastic matrix via " + "permutation-based convex combination rather than Sinkhorn-Knopp. " + "Practical only for a small mhc_expansion_rate (e.g., k=4)." + ), + ) class DilocoParams(BaseModel): diff --git a/src/maxtext/layers/mhc.py b/src/maxtext/layers/mhc.py index ce700aafcd..03e0eb6eee 100644 --- a/src/maxtext/layers/mhc.py +++ b/src/maxtext/layers/mhc.py @@ -14,7 +14,10 @@ """DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer.""" +import itertools +import math from typing import Callable + from flax import nnx import jax import jax.numpy as jnp @@ -25,6 +28,17 @@ from maxtext.layers.normalizations import RMSNorm +def get_permutation_matrices(k: int) -> Array: + """Generates all permutation matrices of size k. + + Reference: mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif + Shape: (k!, k, k) + """ + perms = list(itertools.permutations(range(k))) + perms_array = jnp.array(perms) + return jnp.eye(k)[perms_array] + + def get_functions(expansion_rate: int): """Creates functions to broadcast a single feature stream into multiple @@ -118,6 +132,17 @@ def __init__( out_sharding=(None,), ) + if self.config.enable_mhc_lite: + num_perms = math.factorial(self.k) + res_out_dim = num_perms + res_beta_shape = (num_perms,) + res_beta_sharding = (None,) + self.permutation_matrices = get_permutation_matrices(self.k) + else: + res_out_dim = self.k * self.k + res_beta_shape = (self.k, self.k) + res_beta_sharding = (None, None) + # Weight matrices scale_init = nd_dense_init(1.0, "fan_in", "normal") in_axis = 0 @@ -126,7 +151,7 @@ def __init__( self.res_alpha = nnx.Param( scale_init( self.rngs.params(), - (self.k * self.dim, self.k * self.k), + (self.k * self.dim, res_out_dim), self.weight_dtype, in_axis=in_axis, out_axis=out_axis, @@ -156,8 +181,8 @@ def __init__( # Biases self.res_beta = nnx.Param( - default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype), - out_sharding=(None, None), + default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype), + out_sharding=res_beta_sharding, ) self.pre_beta = nnx.Param( default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), @@ -174,13 +199,30 @@ def res_mapping(self, x: Array): res_alpha = jnp.asarray(self.res_alpha[...], self.dtype) res_beta = jnp.asarray(self.res_beta[...], self.dtype) res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype) - # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) - h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) - b, s, _ = h_res.shape - h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) - intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :] - output = sinkhorn(intermediate, self.sinkhorn_iterations) - return output + + if self.config.enable_mhc_lite: + # Apply projection: (b, s, k*d) @ (k*d, k!) -> (b, s, k!) + h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) + intermediate = res_alpha_scale * h_res + res_beta[None, None, :] + # Use float32 for numerical stability during softmax + weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype) + # Sum the permutation matrices with the weights + permutation_matrices = self.permutation_matrices.astype(self.dtype) + output = jnp.einsum( + "bsn,nkm -> bskm", + weights, + permutation_matrices, + precision=self.matmul_precision, + ) + return output + else: + # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) + h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) + b, s, _ = h_res.shape + h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) + intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :] + output = sinkhorn(intermediate, self.sinkhorn_iterations) + return output def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): """Helper function for both pre and post mappings.""" diff --git a/tests/unit/mhc_test.py b/tests/unit/mhc_test.py index e5c80ebb88..81cd7b58b4 100644 --- a/tests/unit/mhc_test.py +++ b/tests/unit/mhc_test.py @@ -15,21 +15,20 @@ """Test for DeepSeek Manifold-Constrained Hyper Connections (mHC).""" import unittest -import pytest - +from absl.testing import parameterized from flax import nnx from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp from jax.sharding import Mesh -import numpy as np - -from maxtext.configs import pyconfig from maxtext.common.common_types import HyperConnectionType +from maxtext.configs import pyconfig from maxtext.layers import attention_mla, linears, mhc, moe from maxtext.layers.initializers import nd_dense_init from maxtext.layers.normalizations import RMSNorm from maxtext.utils import maxtext_utils +import numpy as np +import pytest from tests.utils.test_helpers import get_test_config_path @@ -86,14 +85,15 @@ def test_doubly_stochastic_property(self): np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-3) -class TestMHC(unittest.TestCase): +class TestMHC(parameterized.TestCase): """Test for MHC module""" - def setUp(self): + def _setup_mhc(self, rate, enable_mhc_lite=False): + """Sets up the common configurations and modules for MHC testing.""" self.dim = 16 self.config = pyconfig.initialize( [None, get_test_config_path()], - run_name="test_mhc", + run_name=f"test_mhc_k{rate}", enable_checkpointing=False, model_name="deepseek-custom", per_device_batch_size=jax.device_count(), @@ -105,7 +105,8 @@ def setUp(self): # override override_model_config=True, base_emb_dim=self.dim, - mhc_expansion_rate=3, + mhc_expansion_rate=rate, + enable_mhc_lite=enable_mhc_lite, num_experts=4, num_experts_per_tok=2, engram_layers=[], @@ -135,7 +136,10 @@ def setUp(self): # Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend @pytest.mark.tpu_only - def test_moe_layer_output_shape(self): + @parameterized.named_parameters(("Rate3", 3), ("Rate4", 4)) + def test_moe_layer_output_shape(self, rate): + self._setup_mhc(rate) + with nn_partitioning.axis_rules(self.config.logical_axis_rules): module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) layer = moe.RoutedMoE( @@ -154,12 +158,14 @@ def test_moe_layer_output_shape(self): b, s, k, d = self.x.shape output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE) # metadata includes load_balance_loss & moe_bias_updates - self.assertEqual(len(metadata), 2) + self.assertLen(metadata, 2) for key, value in metadata.items(): self.assertIsNotNone(value, f"Key '{key}' has a value of None") self.assertEqual(output.shape, (b, s, k, d)) - def test_dense_layer_output_shape(self): + @parameterized.named_parameters(("Rate3", 3), ("Rate4", 4)) + def test_dense_layer_output_shape(self, rate): + self._setup_mhc(rate) with nn_partitioning.axis_rules(self.config.logical_axis_rules): module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) layer = linears.MlpBlock( @@ -180,8 +186,14 @@ def test_dense_layer_output_shape(self): self.assertDictEqual(metadata, {}) self.assertEqual(output.shape, (b, s, k, d)) - def test_attention_layer_output_shape(self): - inputs_shape = (self.config.per_device_batch_size, self.config.max_target_length, self.config.emb_dim) + @parameterized.named_parameters(("Rate3", 3), ("Rate4", 4)) + def test_attention_layer_output_shape(self, rate): + self._setup_mhc(rate) + inputs_shape = ( + self.config.per_device_batch_size, + self.config.max_target_length, + self.config.emb_dim, + ) with nn_partitioning.axis_rules(self.config.logical_axis_rules): module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) layer = attention_mla.MLA( @@ -219,6 +231,65 @@ def test_attention_layer_output_shape(self): self.assertDictEqual(metadata, {}) self.assertEqual(output.shape, (b, s, k, d)) + def test_mhc_lite_doubly_stochastic(self): + """Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1).""" + self._setup_mhc(4, enable_mhc_lite=True) + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) + + b, s, k, d = self.x.shape + + # Generate random input X + random_x = jax.random.normal(jax.random.PRNGKey(42), (b, s, k * d)) + norm_x = module.mhc_norm(random_x) + + # Output from mHC-lite mapping + res_mapping_out = module.res_mapping(norm_x) + + row_sums = jnp.sum(res_mapping_out, axis=-1) + col_sums = jnp.sum(res_mapping_out, axis=-2) + + # Check if sums are close to 1.0 + np.testing.assert_allclose(row_sums, jnp.ones_like(row_sums), atol=1e-2) + np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-2) + + def test_feature_flag_gates_lite(self): + """Verify that setting enable_mhc_lite=False falls back to Sinkhorn.""" + self.dim = 16 + self.config = pyconfig.initialize( + [None, get_test_config_path()], + run_name="test_mhc_lite_gated", + enable_checkpointing=False, + model_name="deepseek-custom", + per_device_batch_size=4, + max_target_length=7, + max_prefill_predict_length=7, + attention="dot_product", + routed_bias_update_rate=0.01, + load_balance_loss_weight=0.02, + # override + override_model_config=True, + base_emb_dim=self.dim, + mhc_expansion_rate=4, + enable_mhc_lite=False, + num_experts=4, + num_experts_per_tok=2, + engram_layers=[], + ) + devices_array = maxtext_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + self.rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(42)) + + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) + + # Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24) + self.assertEqual(module.res_alpha.shape, (64, 16)) + # Shape of res_beta should be (4, 4) instead of (24,) + self.assertEqual(module.res_beta.shape, (4, 4)) + # Permutation matrices shouldn't be defined + self.assertFalse(hasattr(module, "permutation_matrices")) + if __name__ == "__main__": unittest.main()