Skip to content

Commit 8e83338

Browse files
committed
Optimize mHC for expansion rate 4 using convex combination of permutations and add enable_mhc_k4_shortcut feature gate
1 parent 7c8d658 commit 8e83338

5 files changed

Lines changed: 163 additions & 27 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,4 @@ gha-creds-*.json
155155

156156
# vscode workspace
157157
maxtext.code-workspace
158+
.jetskicli/

src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,12 @@ force_q_layout: false
12241224
mhc_expansion_rate: 1
12251225
# The number of iterations for the Sinkhorn-Knopp algorithm.
12261226
sinkhorn_iterations: 20
1227+
# mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
1228+
# - Controls whether to generate the MHC doubly stochastic matrix via
1229+
# permutation-based convex combination rather than Sinkhorn-Knopp.
1230+
# - Practical only for a small mhc_expansion_rate (e.g., k=4), as the number
1231+
# of permutation matrices scales factorially (k!).
1232+
enable_mhc_lite: False
12271233

12281234
################################## DeepSeek Engram ##################################
12291235
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.

src/maxtext/configs/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,15 @@ class ManifoldConstrainedHyperConnections(BaseModel):
14071407

14081408
mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
14091409
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
1410+
enable_mhc_lite: bool = Field(
1411+
False,
1412+
description=(
1413+
"Whether to generate the MHC doubly stochastic matrix via "
1414+
"permutation-based convex combination rather than Sinkhorn-Knopp. "
1415+
"Practical only for a small mhc_expansion_rate (e.g., k=4)."
1416+
),
1417+
)
1418+
14101419

14111420

14121421
class DilocoParams(BaseModel):

src/maxtext/layers/mhc.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,35 @@
1414

1515
"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""
1616

17+
import itertools
18+
import math
1719
from typing import Callable
20+
1821
from flax import nnx
1922
import jax
2023
import jax.numpy as jnp
2124
from jax.sharding import Mesh
2225
from maxtext.common.common_types import Array, Config
2326
from maxtext.common.common_types import HyperConnectionType
24-
from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init
27+
from maxtext.layers.initializers import (
28+
default_bias_init,
29+
default_scalar_init,
30+
nd_dense_init,
31+
)
2532
from maxtext.layers.normalizations import RMSNorm
2633

2734

35+
def get_permutation_matrices(k: int) -> Array:
36+
"""Generates all permutation matrices of size k.
37+
38+
Reference: mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
39+
Shape: (k!, k, k)
40+
"""
41+
perms = list(itertools.permutations(range(k)))
42+
perms_array = jnp.array(perms)
43+
return jnp.eye(k)[perms_array]
44+
45+
2846
def get_functions(expansion_rate: int):
2947
"""Creates functions to broadcast a single feature stream into multiple
3048
@@ -118,6 +136,17 @@ def __init__(
118136
out_sharding=(None,),
119137
)
120138

139+
if self.config.enable_mhc_lite:
140+
num_perms = math.factorial(self.k)
141+
res_out_dim = num_perms
142+
res_beta_shape = (num_perms,)
143+
res_beta_sharding = (None,)
144+
self.permutation_matrices = get_permutation_matrices(self.k)
145+
else:
146+
res_out_dim = self.k * self.k
147+
res_beta_shape = (self.k, self.k)
148+
res_beta_sharding = (None, None)
149+
121150
# Weight matrices
122151
scale_init = nd_dense_init(1.0, "fan_in", "normal")
123152
in_axis = 0
@@ -126,7 +155,7 @@ def __init__(
126155
self.res_alpha = nnx.Param(
127156
scale_init(
128157
self.rngs.params(),
129-
(self.k * self.dim, self.k * self.k),
158+
(self.k * self.dim, res_out_dim),
130159
self.weight_dtype,
131160
in_axis=in_axis,
132161
out_axis=out_axis,
@@ -156,8 +185,8 @@ def __init__(
156185

157186
# Biases
158187
self.res_beta = nnx.Param(
159-
default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
160-
out_sharding=(None, None),
188+
default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype),
189+
out_sharding=res_beta_sharding,
161190
)
162191
self.pre_beta = nnx.Param(
163192
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
@@ -174,13 +203,30 @@ def res_mapping(self, x: Array):
174203
res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
175204
res_beta = jnp.asarray(self.res_beta[...], self.dtype)
176205
res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
177-
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
178-
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
179-
b, s, _ = h_res.shape
180-
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
181-
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
182-
output = sinkhorn(intermediate, self.sinkhorn_iterations)
183-
return output
206+
207+
if self.config.enable_mhc_lite:
208+
# Apply projection: (b, s, k*d) @ (k*d, k!) -> (b, s, k!)
209+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
210+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :]
211+
# Use float32 for numerical stability during softmax
212+
weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype)
213+
# Sum the permutation matrices with the weights
214+
permutation_matrices = self.permutation_matrices.astype(self.dtype)
215+
output = jnp.einsum(
216+
"bsn,nkm -> bskm",
217+
weights,
218+
permutation_matrices,
219+
precision=self.matmul_precision,
220+
)
221+
return output
222+
else:
223+
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
224+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
225+
b, s, _ = h_res.shape
226+
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
227+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
228+
output = sinkhorn(intermediate, self.sinkhorn_iterations)
229+
return output
184230

185231
def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
186232
"""Helper function for both pre and post mappings."""

tests/unit/mhc_test.py

Lines changed: 90 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
1616

17-
import unittest
18-
import pytest
19-
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
2019
from flax import nnx
2120
from flax.linen import partitioning as nn_partitioning
2221
import jax
2322
import jax.numpy as jnp
2423
from jax.sharding import Mesh
2524
import numpy as np
25+
import pytest
2626

2727
from maxtext.configs import pyconfig
2828
from maxtext.common.common_types import HyperConnectionType
@@ -33,7 +33,7 @@
3333
from tests.utils.test_helpers import get_test_config_path
3434

3535

36-
class TestExpandReduce(unittest.TestCase):
36+
class TestExpandReduce(absltest.TestCase):
3737
"""Unit tests for MHC dimension expansion and reduction operations."""
3838

3939
def setUp(self):
@@ -65,7 +65,7 @@ def test_value_identity(self):
6565
np.testing.assert_allclose(out, expected, rtol=1e-5)
6666

6767

68-
class TestSinkhorn(unittest.TestCase):
68+
class TestSinkhorn(absltest.TestCase):
6969
"""Unit tests for MHC Sinkhorn Algorithm."""
7070

7171
def setUp(self):
@@ -86,17 +86,19 @@ def test_doubly_stochastic_property(self):
8686
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-3)
8787

8888

89-
class TestMHC(unittest.TestCase):
89+
class TestMHC(parameterized.TestCase):
9090
"""Test for MHC module"""
9191

92-
def setUp(self):
92+
def _setup_mhc(self, rate, enable_mhc_lite=False):
93+
"""Sets up the common configurations and modules for MHC testing."""
9394
self.dim = 16
9495
self.config = pyconfig.initialize(
9596
[None, get_test_config_path()],
96-
run_name="test_mhc",
97+
skip_jax_distributed_system=True,
98+
run_name=f"test_mhc_k{rate}",
9799
enable_checkpointing=False,
98100
model_name="deepseek-custom",
99-
per_device_batch_size=jax.device_count(),
101+
per_device_batch_size=max(4, jax.device_count()),
100102
max_target_length=7,
101103
max_prefill_predict_length=7,
102104
attention="dot_product",
@@ -105,7 +107,8 @@ def setUp(self):
105107
# override
106108
override_model_config=True,
107109
base_emb_dim=self.dim,
108-
mhc_expansion_rate=3,
110+
mhc_expansion_rate=rate,
111+
enable_mhc_lite=enable_mhc_lite,
109112
num_experts=4,
110113
num_experts_per_tok=2,
111114
engram_layers=[],
@@ -135,7 +138,10 @@ def setUp(self):
135138

136139
# Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
137140
@pytest.mark.tpu_only
138-
def test_moe_layer_output_shape(self):
141+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
142+
def test_moe_layer_output_shape(self, rate):
143+
self._setup_mhc(rate)
144+
139145
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
140146
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
141147
layer = moe.RoutedMoE(
@@ -154,12 +160,14 @@ def test_moe_layer_output_shape(self):
154160
b, s, k, d = self.x.shape
155161
output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE)
156162
# metadata includes load_balance_loss & moe_bias_updates
157-
self.assertEqual(len(metadata), 2)
163+
self.assertLen(metadata, 2)
158164
for key, value in metadata.items():
159165
self.assertIsNotNone(value, f"Key '{key}' has a value of None")
160166
self.assertEqual(output.shape, (b, s, k, d))
161167

162-
def test_dense_layer_output_shape(self):
168+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
169+
def test_dense_layer_output_shape(self, rate):
170+
self._setup_mhc(rate)
163171
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
164172
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
165173
layer = linears.MlpBlock(
@@ -180,8 +188,14 @@ def test_dense_layer_output_shape(self):
180188
self.assertDictEqual(metadata, {})
181189
self.assertEqual(output.shape, (b, s, k, d))
182190

183-
def test_attention_layer_output_shape(self):
184-
inputs_shape = (self.config.per_device_batch_size, self.config.max_target_length, self.config.emb_dim)
191+
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
192+
def test_attention_layer_output_shape(self, rate):
193+
self._setup_mhc(rate)
194+
inputs_shape = (
195+
self.config.per_device_batch_size,
196+
self.config.max_target_length,
197+
self.config.emb_dim,
198+
)
185199
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
186200
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
187201
layer = attention_mla.MLA(
@@ -219,6 +233,66 @@ def test_attention_layer_output_shape(self):
219233
self.assertDictEqual(metadata, {})
220234
self.assertEqual(output.shape, (b, s, k, d))
221235

236+
def test_mhc_lite_doubly_stochastic(self):
237+
"""Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1)."""
238+
self._setup_mhc(4, enable_mhc_lite=True)
239+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
240+
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
241+
242+
b, s, k, d = self.x.shape
243+
244+
# Generate random input X
245+
random_x = jax.random.normal(jax.random.PRNGKey(42), (b, s, k * d))
246+
norm_x = module.mhc_norm(random_x)
247+
248+
# Output from mHC-lite mapping
249+
res_mapping_out = module.res_mapping(norm_x)
250+
251+
row_sums = jnp.sum(res_mapping_out, axis=-1)
252+
col_sums = jnp.sum(res_mapping_out, axis=-2)
253+
254+
# Check if sums are close to 1.0
255+
np.testing.assert_allclose(row_sums, jnp.ones_like(row_sums), atol=1e-2)
256+
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-2)
257+
258+
def test_feature_flag_gates_lite(self):
259+
"""Verify that setting enable_mhc_lite=False falls back to Sinkhorn."""
260+
self.dim = 16
261+
self.config = pyconfig.initialize(
262+
[None, get_test_config_path()],
263+
skip_jax_distributed_system=True,
264+
run_name="test_mhc_lite_gated",
265+
enable_checkpointing=False,
266+
model_name="deepseek-custom",
267+
per_device_batch_size=4,
268+
max_target_length=7,
269+
max_prefill_predict_length=7,
270+
attention="dot_product",
271+
routed_bias_update_rate=0.01,
272+
load_balance_loss_weight=0.02,
273+
# override
274+
override_model_config=True,
275+
base_emb_dim=self.dim,
276+
mhc_expansion_rate=4,
277+
enable_mhc_lite=False,
278+
num_experts=4,
279+
num_experts_per_tok=2,
280+
engram_layers=[],
281+
)
282+
devices_array = maxtext_utils.create_device_mesh(self.config)
283+
self.mesh = Mesh(devices_array, self.config.mesh_axes)
284+
self.rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(42))
285+
286+
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
287+
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
288+
289+
# Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24)
290+
self.assertEqual(module.res_alpha.shape, (64, 16))
291+
# Shape of res_beta should be (4, 4) instead of (24,)
292+
self.assertEqual(module.res_beta.shape, (4, 4))
293+
# Permutation matrices shouldn't be defined
294+
self.assertFalse(hasattr(module, "permutation_matrices"))
295+
222296

223297
if __name__ == "__main__":
224-
unittest.main()
298+
absltest.main()

0 commit comments

Comments
 (0)