Skip to content

Commit 2c88ab1

Browse files
committed
Optimize mHC for expansion rate 4 using convex combination of permutations and add enable_mhc_k4_shortcut feature gate
1 parent 7c8d658 commit 2c88ab1

4 files changed

Lines changed: 154 additions & 24 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,15 @@ force_q_layout: false
12241224
mhc_expansion_rate: 1
12251225
# The number of iterations for the Sinkhorn-Knopp algorithm.
12261226
sinkhorn_iterations: 20
1227+
# mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
1228+
# - Controls whether to generate the MHC doubly stochastic matrix via
1229+
# permutation-based convex combination rather than Sinkhorn-Knopp.
1230+
# - Practical only for a small mhc_expansion_rate (e.g., k=4), as the number
1231+
# of permutation matrices scales factorially (k!).
1232+
# - Ideally, this should be used whenever expansion rate is small as it removes
1233+
# the expensive sinkhorn iterations, the downside to this approach is that
1234+
# it is factorial in k.
1235+
enable_mhc_lite: False
12271236

12281237
################################## DeepSeek Engram ##################################
12291238
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,14 @@ class ManifoldConstrainedHyperConnections(BaseModel):
14071407

14081408
mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
14091409
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
1410+
enable_mhc_lite: bool = Field(
1411+
False,
1412+
description=(
1413+
"Whether to generate the MHC doubly stochastic matrix via "
1414+
"permutation-based convex combination rather than Sinkhorn-Knopp. "
1415+
"Practical only for a small mhc_expansion_rate (e.g., k=4)."
1416+
),
1417+
)
14101418

14111419

14121420
class DilocoParams(BaseModel):

src/maxtext/layers/mhc.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""
1616

17+
import itertools
18+
import math
1719
from typing import Callable
20+
1821
from flax import nnx
1922
import jax
2023
import jax.numpy as jnp
@@ -25,6 +28,17 @@
2528
from maxtext.layers.normalizations import RMSNorm
2629

2730

31+
def get_permutation_matrices(k: int) -> Array:
32+
"""Generates all permutation matrices of size k.
33+
34+
Reference: mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
35+
Shape: (k!, k, k)
36+
"""
37+
perms = list(itertools.permutations(range(k)))
38+
perms_array = jnp.array(perms)
39+
return jnp.eye(k)[perms_array]
40+
41+
2842
def get_functions(expansion_rate: int):
2943
"""Creates functions to broadcast a single feature stream into multiple
3044
@@ -118,6 +132,17 @@ def __init__(
118132
out_sharding=(None,),
119133
)
120134

135+
if self.config.enable_mhc_lite:
136+
num_perms = math.factorial(self.k)
137+
res_out_dim = num_perms
138+
res_beta_shape = (num_perms,)
139+
res_beta_sharding = (None,)
140+
self.permutation_matrices = get_permutation_matrices(self.k)
141+
else:
142+
res_out_dim = self.k * self.k
143+
res_beta_shape = (self.k, self.k)
144+
res_beta_sharding = (None, None)
145+
121146
# Weight matrices
122147
scale_init = nd_dense_init(1.0, "fan_in", "normal")
123148
in_axis = 0
@@ -126,7 +151,7 @@ def __init__(
126151
self.res_alpha = nnx.Param(
127152
scale_init(
128153
self.rngs.params(),
129-
(self.k * self.dim, self.k * self.k),
154+
(self.k * self.dim, res_out_dim),
130155
self.weight_dtype,
131156
in_axis=in_axis,
132157
out_axis=out_axis,
@@ -156,8 +181,8 @@ def __init__(
156181

157182
# Biases
158183
self.res_beta = nnx.Param(
159-
default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
160-
out_sharding=(None, None),
184+
default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype),
185+
out_sharding=res_beta_sharding,
161186
)
162187
self.pre_beta = nnx.Param(
163188
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
@@ -174,13 +199,30 @@ def res_mapping(self, x: Array):
174199
res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
175200
res_beta = jnp.asarray(self.res_beta[...], self.dtype)
176201
res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
177-
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
178-
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
179-
b, s, _ = h_res.shape
180-
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
181-
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
182-
output = sinkhorn(intermediate, self.sinkhorn_iterations)
183-
return output
202+
203+
if self.config.enable_mhc_lite:
204+
# Apply projection: (b, s, k*d) @ (k*d, k!) -> (b, s, k!)
205+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
206+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :]
207+
# Use float32 for numerical stability during softmax
208+
weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype)
209+
# Sum the permutation matrices with the weights
210+
permutation_matrices = self.permutation_matrices.astype(self.dtype)
211+
output = jnp.einsum(
212+
"bsn,nkm -> bskm",
213+
weights,
214+
permutation_matrices,
215+
precision=self.matmul_precision,
216+
)
217+
return output
218+
else:
219+
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
220+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
221+
b, s, _ = h_res.shape
222+
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
223+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
224+
output = sinkhorn(intermediate, self.sinkhorn_iterations)
225+
return output
184226

185227
def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
186228
"""Helper function for both pre and post mappings."""

tests/unit/mhc_test.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,20 @@
1515
"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
1616

1717
import unittest
18-
import pytest
19-
18+
from absl.testing import parameterized
2019
from flax import nnx
2120
from flax.linen import partitioning as nn_partitioning
2221
import jax
2322
import jax.numpy as jnp
2423
from jax.sharding import Mesh
25-
import numpy as np
26-
27-
from maxtext.configs import pyconfig
2824
from maxtext.common.common_types import HyperConnectionType
25+
from maxtext.configs import pyconfig
2926
from maxtext.layers import attention_mla, linears, mhc, moe
3027
from maxtext.layers.initializers import nd_dense_init
3128
from maxtext.layers.normalizations import RMSNorm
3229
from maxtext.utils import maxtext_utils
30+
import numpy as np
31+
import pytest
3332
from tests.utils.test_helpers import get_test_config_path
3433

3534

@@ -86,14 +85,15 @@ def test_doubly_stochastic_property(self):
8685
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-3)
8786

8887

89-
class TestMHC(unittest.TestCase):
88+
class TestMHC(parameterized.TestCase):
9089
"""Test for MHC module"""
9190

92-
def setUp(self):
91+
def _setup_mhc(self, rate, enable_mhc_lite=False):
92+
"""Sets up the common configurations and modules for MHC testing."""
9393
self.dim = 16
9494
self.config = pyconfig.initialize(
9595
[None, get_test_config_path()],
96-
run_name="test_mhc",
96+
run_name=f"test_mhc_k{rate}",
9797
enable_checkpointing=False,
9898
model_name="deepseek-custom",
9999
per_device_batch_size=jax.device_count(),
@@ -105,7 +105,8 @@ def setUp(self):
105105
# override
106106
override_model_config=True,
107107
base_emb_dim=self.dim,
108-
mhc_expansion_rate=3,
108+
mhc_expansion_rate=rate,
109+
enable_mhc_lite=enable_mhc_lite,
109110
num_experts=4,
110111
num_experts_per_tok=2,
111112
engram_layers=[],
@@ -135,7 +136,10 @@ def setUp(self):
135136

136137
# Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
137138
@pytest.mark.tpu_only
138-
def test_moe_layer_output_shape(self):
139+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
140+
def test_moe_layer_output_shape(self, rate):
141+
self._setup_mhc(rate)
142+
139143
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
140144
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
141145
layer = moe.RoutedMoE(
@@ -154,12 +158,14 @@ def test_moe_layer_output_shape(self):
154158
b, s, k, d = self.x.shape
155159
output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE)
156160
# metadata includes load_balance_loss & moe_bias_updates
157-
self.assertEqual(len(metadata), 2)
161+
self.assertLen(metadata, 2)
158162
for key, value in metadata.items():
159163
self.assertIsNotNone(value, f"Key '{key}' has a value of None")
160164
self.assertEqual(output.shape, (b, s, k, d))
161165

162-
def test_dense_layer_output_shape(self):
166+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
167+
def test_dense_layer_output_shape(self, rate):
168+
self._setup_mhc(rate)
163169
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
164170
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
165171
layer = linears.MlpBlock(
@@ -180,8 +186,14 @@ def test_dense_layer_output_shape(self):
180186
self.assertDictEqual(metadata, {})
181187
self.assertEqual(output.shape, (b, s, k, d))
182188

183-
def test_attention_layer_output_shape(self):
184-
inputs_shape = (self.config.per_device_batch_size, self.config.max_target_length, self.config.emb_dim)
189+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
190+
def test_attention_layer_output_shape(self, rate):
191+
self._setup_mhc(rate)
192+
inputs_shape = (
193+
self.config.per_device_batch_size,
194+
self.config.max_target_length,
195+
self.config.emb_dim,
196+
)
185197
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
186198
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
187199
layer = attention_mla.MLA(
@@ -219,6 +231,65 @@ def test_attention_layer_output_shape(self):
219231
self.assertDictEqual(metadata, {})
220232
self.assertEqual(output.shape, (b, s, k, d))
221233

234+
def test_mhc_lite_doubly_stochastic(self):
235+
"""Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1)."""
236+
self._setup_mhc(4, enable_mhc_lite=True)
237+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
238+
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
239+
240+
b, s, k, d = self.x.shape
241+
242+
# Generate random input X
243+
random_x = jax.random.normal(jax.random.PRNGKey(42), (b, s, k * d))
244+
norm_x = module.mhc_norm(random_x)
245+
246+
# Output from mHC-lite mapping
247+
res_mapping_out = module.res_mapping(norm_x)
248+
249+
row_sums = jnp.sum(res_mapping_out, axis=-1)
250+
col_sums = jnp.sum(res_mapping_out, axis=-2)
251+
252+
# Check if sums are close to 1.0
253+
np.testing.assert_allclose(row_sums, jnp.ones_like(row_sums), atol=1e-2)
254+
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-2)
255+
256+
def test_feature_flag_gates_lite(self):
257+
"""Verify that setting enable_mhc_lite=False falls back to Sinkhorn."""
258+
self.dim = 16
259+
self.config = pyconfig.initialize(
260+
[None, get_test_config_path()],
261+
run_name="test_mhc_lite_gated",
262+
enable_checkpointing=False,
263+
model_name="deepseek-custom",
264+
per_device_batch_size=4,
265+
max_target_length=7,
266+
max_prefill_predict_length=7,
267+
attention="dot_product",
268+
routed_bias_update_rate=0.01,
269+
load_balance_loss_weight=0.02,
270+
# override
271+
override_model_config=True,
272+
base_emb_dim=self.dim,
273+
mhc_expansion_rate=4,
274+
enable_mhc_lite=False,
275+
num_experts=4,
276+
num_experts_per_tok=2,
277+
engram_layers=[],
278+
)
279+
devices_array = maxtext_utils.create_device_mesh(self.config)
280+
self.mesh = Mesh(devices_array, self.config.mesh_axes)
281+
self.rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(42))
282+
283+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
284+
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
285+
286+
# Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24)
287+
self.assertEqual(module.res_alpha.shape, (64, 16))
288+
# Shape of res_beta should be (4, 4) instead of (24,)
289+
self.assertEqual(module.res_beta.shape, (4, 4))
290+
# Permutation matrices shouldn't be defined
291+
self.assertFalse(hasattr(module, "permutation_matrices"))
292+
222293

223294
if __name__ == "__main__":
224295
unittest.main()

0 commit comments

Comments
 (0)