Skip to content

Commit 944741b

Browse files
committed
Optimize mHC for expansion rate 4 using convex combination of permutations and add enable_mhc_k4_shortcut feature gate
1 parent 7c8d658 commit 944741b

4 files changed

Lines changed: 149 additions & 21 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,12 @@ force_q_layout: false
12241224
mhc_expansion_rate: 1
12251225
# The number of iterations for the Sinkhorn-Knopp algorithm.
12261226
sinkhorn_iterations: 20
1227+
# mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
1228+
# - Controls whether to generate the MHC doubly stochastic matrix via
1229+
# permutation-based convex combination rather than Sinkhorn-Knopp.
1230+
# - Practical only for a small mhc_expansion_rate (e.g., k=4), as the number
1231+
# of permutation matrices scales factorially (k!).
1232+
enable_mhc_lite: False
12271233

12281234
################################## DeepSeek Engram ##################################
12291235
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,14 @@ class ManifoldConstrainedHyperConnections(BaseModel):
14071407

14081408
mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
14091409
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
1410+
enable_mhc_lite: bool = Field(
1411+
False,
1412+
description=(
1413+
"Whether to generate the MHC doubly stochastic matrix via "
1414+
"permutation-based convex combination rather than Sinkhorn-Knopp. "
1415+
"Practical only for a small mhc_expansion_rate (e.g., k=4)."
1416+
),
1417+
)
14101418

14111419

14121420
class DilocoParams(BaseModel):

src/maxtext/layers/mhc.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515
"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""
1616

17+
import itertools
18+
import math
1719
from typing import Callable
20+
1821
from flax import nnx
1922
import jax
2023
import jax.numpy as jnp
@@ -25,6 +28,17 @@
2528
from maxtext.layers.normalizations import RMSNorm
2629

2730

31+
def get_permutation_matrices(k: int) -> Array:
32+
"""Generates all permutation matrices of size k.
33+
34+
Reference: mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
35+
Shape: (k!, k, k)
36+
"""
37+
perms = list(itertools.permutations(range(k)))
38+
perms_array = jnp.array(perms)
39+
return jnp.eye(k)[perms_array]
40+
41+
2842
def get_functions(expansion_rate: int):
2943
"""Creates functions to broadcast a single feature stream into multiple
3044
@@ -118,6 +132,17 @@ def __init__(
118132
out_sharding=(None,),
119133
)
120134

135+
if self.config.enable_mhc_lite:
136+
num_perms = math.factorial(self.k)
137+
res_out_dim = num_perms
138+
res_beta_shape = (num_perms,)
139+
res_beta_sharding = (None,)
140+
self.permutation_matrices = get_permutation_matrices(self.k)
141+
else:
142+
res_out_dim = self.k * self.k
143+
res_beta_shape = (self.k, self.k)
144+
res_beta_sharding = (None, None)
145+
121146
# Weight matrices
122147
scale_init = nd_dense_init(1.0, "fan_in", "normal")
123148
in_axis = 0
@@ -126,7 +151,7 @@ def __init__(
126151
self.res_alpha = nnx.Param(
127152
scale_init(
128153
self.rngs.params(),
129-
(self.k * self.dim, self.k * self.k),
154+
(self.k * self.dim, res_out_dim),
130155
self.weight_dtype,
131156
in_axis=in_axis,
132157
out_axis=out_axis,
@@ -156,8 +181,8 @@ def __init__(
156181

157182
# Biases
158183
self.res_beta = nnx.Param(
159-
default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
160-
out_sharding=(None, None),
184+
default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype),
185+
out_sharding=res_beta_sharding,
161186
)
162187
self.pre_beta = nnx.Param(
163188
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
@@ -174,13 +199,30 @@ def res_mapping(self, x: Array):
174199
res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
175200
res_beta = jnp.asarray(self.res_beta[...], self.dtype)
176201
res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
177-
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
178-
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
179-
b, s, _ = h_res.shape
180-
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
181-
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
182-
output = sinkhorn(intermediate, self.sinkhorn_iterations)
183-
return output
202+
203+
if self.config.enable_mhc_lite:
204+
# Apply projection: (b, s, k*d) @ (k*d, k!) -> (b, s, k!)
205+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
206+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :]
207+
# Use float32 for numerical stability during softmax
208+
weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype)
209+
# Sum the permutation matrices with the weights
210+
permutation_matrices = self.permutation_matrices.astype(self.dtype)
211+
output = jnp.einsum(
212+
"bsn,nkm -> bskm",
213+
weights,
214+
permutation_matrices,
215+
precision=self.matmul_precision,
216+
)
217+
return output
218+
else:
219+
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
220+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
221+
b, s, _ = h_res.shape
222+
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
223+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
224+
output = sinkhorn(intermediate, self.sinkhorn_iterations)
225+
return output
184226

185227
def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
186228
"""Helper function for both pre and post mappings."""

tests/unit/mhc_test.py

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
1616

1717
import unittest
18-
import pytest
19-
18+
from absl.testing import parameterized
2019
from flax import nnx
2120
from flax.linen import partitioning as nn_partitioning
2221
import jax
2322
import jax.numpy as jnp
2423
from jax.sharding import Mesh
2524
import numpy as np
25+
import pytest
2626

2727
from maxtext.configs import pyconfig
2828
from maxtext.common.common_types import HyperConnectionType
@@ -86,14 +86,15 @@ def test_doubly_stochastic_property(self):
8686
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-3)
8787

8888

89-
class TestMHC(unittest.TestCase):
89+
class TestMHC(parameterized.TestCase):
9090
"""Test for MHC module"""
9191

92-
def setUp(self):
92+
def _setup_mhc(self, rate, enable_mhc_lite=False):
93+
"""Sets up the common configurations and modules for MHC testing."""
9394
self.dim = 16
9495
self.config = pyconfig.initialize(
9596
[None, get_test_config_path()],
96-
run_name="test_mhc",
97+
run_name=f"test_mhc_k{rate}",
9798
enable_checkpointing=False,
9899
model_name="deepseek-custom",
99100
per_device_batch_size=jax.device_count(),
@@ -105,7 +106,8 @@ def setUp(self):
105106
# override
106107
override_model_config=True,
107108
base_emb_dim=self.dim,
108-
mhc_expansion_rate=3,
109+
mhc_expansion_rate=rate,
110+
enable_mhc_lite=enable_mhc_lite,
109111
num_experts=4,
110112
num_experts_per_tok=2,
111113
engram_layers=[],
@@ -135,7 +137,10 @@ def setUp(self):
135137

136138
# Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
137139
@pytest.mark.tpu_only
138-
def test_moe_layer_output_shape(self):
140+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
141+
def test_moe_layer_output_shape(self, rate):
142+
self._setup_mhc(rate)
143+
139144
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
140145
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
141146
layer = moe.RoutedMoE(
@@ -154,12 +159,14 @@ def test_moe_layer_output_shape(self):
154159
b, s, k, d = self.x.shape
155160
output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE)
156161
# metadata includes load_balance_loss & moe_bias_updates
157-
self.assertEqual(len(metadata), 2)
162+
self.assertLen(metadata, 2)
158163
for key, value in metadata.items():
159164
self.assertIsNotNone(value, f"Key '{key}' has a value of None")
160165
self.assertEqual(output.shape, (b, s, k, d))
161166

162-
def test_dense_layer_output_shape(self):
167+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
168+
def test_dense_layer_output_shape(self, rate):
169+
self._setup_mhc(rate)
163170
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
164171
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
165172
layer = linears.MlpBlock(
@@ -180,8 +187,14 @@ def test_dense_layer_output_shape(self):
180187
self.assertDictEqual(metadata, {})
181188
self.assertEqual(output.shape, (b, s, k, d))
182189

183-
def test_attention_layer_output_shape(self):
184-
inputs_shape = (self.config.per_device_batch_size, self.config.max_target_length, self.config.emb_dim)
190+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
191+
def test_attention_layer_output_shape(self, rate):
192+
self._setup_mhc(rate)
193+
inputs_shape = (
194+
self.config.per_device_batch_size,
195+
self.config.max_target_length,
196+
self.config.emb_dim,
197+
)
185198
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
186199
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
187200
layer = attention_mla.MLA(
@@ -219,6 +232,65 @@ def test_attention_layer_output_shape(self):
219232
self.assertDictEqual(metadata, {})
220233
self.assertEqual(output.shape, (b, s, k, d))
221234

235+
def test_mhc_lite_doubly_stochastic(self):
236+
"""Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1)."""
237+
self._setup_mhc(4, enable_mhc_lite=True)
238+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
239+
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
240+
241+
b, s, k, d = self.x.shape
242+
243+
# Generate random input X
244+
random_x = jax.random.normal(jax.random.PRNGKey(42), (b, s, k * d))
245+
norm_x = module.mhc_norm(random_x)
246+
247+
# Output from mHC-lite mapping
248+
res_mapping_out = module.res_mapping(norm_x)
249+
250+
row_sums = jnp.sum(res_mapping_out, axis=-1)
251+
col_sums = jnp.sum(res_mapping_out, axis=-2)
252+
253+
# Check if sums are close to 1.0
254+
np.testing.assert_allclose(row_sums, jnp.ones_like(row_sums), atol=1e-2)
255+
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-2)
256+
257+
def test_feature_flag_gates_lite(self):
258+
"""Verify that setting enable_mhc_lite=False falls back to Sinkhorn."""
259+
self.dim = 16
260+
self.config = pyconfig.initialize(
261+
[None, get_test_config_path()],
262+
run_name="test_mhc_lite_gated",
263+
enable_checkpointing=False,
264+
model_name="deepseek-custom",
265+
per_device_batch_size=4,
266+
max_target_length=7,
267+
max_prefill_predict_length=7,
268+
attention="dot_product",
269+
routed_bias_update_rate=0.01,
270+
load_balance_loss_weight=0.02,
271+
# override
272+
override_model_config=True,
273+
base_emb_dim=self.dim,
274+
mhc_expansion_rate=4,
275+
enable_mhc_lite=False,
276+
num_experts=4,
277+
num_experts_per_tok=2,
278+
engram_layers=[],
279+
)
280+
devices_array = maxtext_utils.create_device_mesh(self.config)
281+
self.mesh = Mesh(devices_array, self.config.mesh_axes)
282+
self.rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(42))
283+
284+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
285+
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
286+
287+
# Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24)
288+
self.assertEqual(module.res_alpha.shape, (64, 16))
289+
# Shape of res_beta should be (4, 4) instead of (24,)
290+
self.assertEqual(module.res_beta.shape, (4, 4))
291+
# Permutation matrices shouldn't be defined
292+
self.assertFalse(hasattr(module, "permutation_matrices"))
293+
222294

223295
if __name__ == "__main__":
224296
unittest.main()

0 commit comments

Comments
 (0)