Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,15 @@ force_q_layout: false
mhc_expansion_rate: 1
# The number of iterations for the Sinkhorn-Knopp algorithm.
sinkhorn_iterations: 20
# mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
# - Controls whether to generate the MHC doubly stochastic matrix via
# permutation-based convex combination rather than Sinkhorn-Knopp.
# - Practical only for a small mhc_expansion_rate (e.g., k=4), as the number
# of permutation matrices scales factorially (k!).
# - Ideally, this should be used whenever expansion rate is small as it removes
# the expensive sinkhorn iterations, the downside to this approach is that
# it is factorial in k.
enable_mhc_lite: False

################################## DeepSeek Engram ##################################
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,14 @@ class ManifoldConstrainedHyperConnections(BaseModel):

mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
enable_mhc_lite: bool = Field(
False,
description=(
"Whether to generate the MHC doubly stochastic matrix via "
"permutation-based convex combination rather than Sinkhorn-Knopp. "
"Practical only for a small mhc_expansion_rate (e.g., k=4)."
),
)


class DilocoParams(BaseModel):
Expand Down
62 changes: 52 additions & 10 deletions src/maxtext/layers/mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""

import itertools
import math
from typing import Callable

from flax import nnx
import jax
import jax.numpy as jnp
Expand All @@ -25,6 +28,17 @@
from maxtext.layers.normalizations import RMSNorm


def get_permutation_matrices(k: int) -> Array:
"""Generates all permutation matrices of size k.

Reference: mHC-lite: https://openreview.net/pdf?id=5IJX6kvOif
Shape: (k!, k, k)
"""
perms = list(itertools.permutations(range(k)))
perms_array = jnp.array(perms)
return jnp.eye(k)[perms_array]


def get_functions(expansion_rate: int):
"""Creates functions to broadcast a single feature stream into multiple

Expand Down Expand Up @@ -118,6 +132,17 @@ def __init__(
out_sharding=(None,),
)

if self.config.enable_mhc_lite:
num_perms = math.factorial(self.k)
res_out_dim = num_perms
res_beta_shape = (num_perms,)
res_beta_sharding = (None,)
self.permutation_matrices = get_permutation_matrices(self.k)
else:
res_out_dim = self.k * self.k
res_beta_shape = (self.k, self.k)
res_beta_sharding = (None, None)

# Weight matrices
scale_init = nd_dense_init(1.0, "fan_in", "normal")
in_axis = 0
Expand All @@ -126,7 +151,7 @@ def __init__(
self.res_alpha = nnx.Param(
scale_init(
self.rngs.params(),
(self.k * self.dim, self.k * self.k),
(self.k * self.dim, res_out_dim),
self.weight_dtype,
in_axis=in_axis,
out_axis=out_axis,
Expand Down Expand Up @@ -156,8 +181,8 @@ def __init__(

# Biases
self.res_beta = nnx.Param(
default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
out_sharding=(None, None),
default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype),
out_sharding=res_beta_sharding,
)
self.pre_beta = nnx.Param(
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
Expand All @@ -174,13 +199,30 @@ def res_mapping(self, x: Array):
res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
res_beta = jnp.asarray(self.res_beta[...], self.dtype)
res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
b, s, _ = h_res.shape
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
output = sinkhorn(intermediate, self.sinkhorn_iterations)
return output

if self.config.enable_mhc_lite:
Comment thread
dandragona marked this conversation as resolved.
# Apply projection: (b, s, k*d) @ (k*d, k!) -> (b, s, k!)
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
intermediate = res_alpha_scale * h_res + res_beta[None, None, :]
# Use float32 for numerical stability during softmax
weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype)
# Sum the permutation matrices with the weights
permutation_matrices = self.permutation_matrices.astype(self.dtype)
output = jnp.einsum(
"bsn,nkm -> bskm",
weights,
permutation_matrices,
precision=self.matmul_precision,
)
return output
else:
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
b, s, _ = h_res.shape
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
output = sinkhorn(intermediate, self.sinkhorn_iterations)
return output

def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
"""Helper function for both pre and post mappings."""
Expand Down
99 changes: 85 additions & 14 deletions tests/unit/mhc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,20 @@
"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""

import unittest
import pytest

from absl.testing import parameterized
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np

from maxtext.configs import pyconfig
from maxtext.common.common_types import HyperConnectionType
from maxtext.configs import pyconfig
from maxtext.layers import attention_mla, linears, mhc, moe
from maxtext.layers.initializers import nd_dense_init
from maxtext.layers.normalizations import RMSNorm
from maxtext.utils import maxtext_utils
import numpy as np
import pytest
from tests.utils.test_helpers import get_test_config_path


Expand Down Expand Up @@ -86,14 +85,15 @@ def test_doubly_stochastic_property(self):
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-3)


class TestMHC(unittest.TestCase):
class TestMHC(parameterized.TestCase):
"""Test for MHC module"""

def setUp(self):
def _setup_mhc(self, rate, enable_mhc_lite=False):
"""Sets up the common configurations and modules for MHC testing."""
self.dim = 16
self.config = pyconfig.initialize(
[None, get_test_config_path()],
run_name="test_mhc",
run_name=f"test_mhc_k{rate}",
enable_checkpointing=False,
model_name="deepseek-custom",
per_device_batch_size=jax.device_count(),
Expand All @@ -105,7 +105,8 @@ def setUp(self):
# override
override_model_config=True,
base_emb_dim=self.dim,
mhc_expansion_rate=3,
mhc_expansion_rate=rate,
enable_mhc_lite=enable_mhc_lite,
num_experts=4,
num_experts_per_tok=2,
engram_layers=[],
Expand Down Expand Up @@ -135,7 +136,10 @@ def setUp(self):

# Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend
@pytest.mark.tpu_only
def test_moe_layer_output_shape(self):
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
def test_moe_layer_output_shape(self, rate):
self._setup_mhc(rate)

with nn_partitioning.axis_rules(self.config.logical_axis_rules):
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
layer = moe.RoutedMoE(
Expand All @@ -154,12 +158,14 @@ def test_moe_layer_output_shape(self):
b, s, k, d = self.x.shape
output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE)
# metadata includes load_balance_loss & moe_bias_updates
self.assertEqual(len(metadata), 2)
self.assertLen(metadata, 2)
for key, value in metadata.items():
self.assertIsNotNone(value, f"Key '{key}' has a value of None")
self.assertEqual(output.shape, (b, s, k, d))

def test_dense_layer_output_shape(self):
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
def test_dense_layer_output_shape(self, rate):
self._setup_mhc(rate)
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
layer = linears.MlpBlock(
Expand All @@ -180,8 +186,14 @@ def test_dense_layer_output_shape(self):
self.assertDictEqual(metadata, {})
self.assertEqual(output.shape, (b, s, k, d))

def test_attention_layer_output_shape(self):
inputs_shape = (self.config.per_device_batch_size, self.config.max_target_length, self.config.emb_dim)
@parameterized.named_parameters(("Rate3", 3), ("Rate4", 4))
def test_attention_layer_output_shape(self, rate):
self._setup_mhc(rate)
inputs_shape = (
self.config.per_device_batch_size,
self.config.max_target_length,
self.config.emb_dim,
)
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)
layer = attention_mla.MLA(
Expand Down Expand Up @@ -219,6 +231,65 @@ def test_attention_layer_output_shape(self):
self.assertDictEqual(metadata, {})
self.assertEqual(output.shape, (b, s, k, d))

def test_mhc_lite_doubly_stochastic(self):
"""Verify that mHC-lite output is doubly stochastic (rows/cols sum to 1)."""
self._setup_mhc(4, enable_mhc_lite=True)
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)

b, s, k, d = self.x.shape

# Generate random input X
random_x = jax.random.normal(jax.random.PRNGKey(42), (b, s, k * d))
norm_x = module.mhc_norm(random_x)

# Output from mHC-lite mapping
res_mapping_out = module.res_mapping(norm_x)

row_sums = jnp.sum(res_mapping_out, axis=-1)
col_sums = jnp.sum(res_mapping_out, axis=-2)

# Check if sums are close to 1.0
np.testing.assert_allclose(row_sums, jnp.ones_like(row_sums), atol=1e-2)
np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-2)

def test_feature_flag_gates_lite(self):
"""Verify that setting enable_mhc_lite=False falls back to Sinkhorn."""
self.dim = 16
self.config = pyconfig.initialize(
[None, get_test_config_path()],
run_name="test_mhc_lite_gated",
enable_checkpointing=False,
model_name="deepseek-custom",
per_device_batch_size=4,
max_target_length=7,
max_prefill_predict_length=7,
attention="dot_product",
routed_bias_update_rate=0.01,
load_balance_loss_weight=0.02,
# override
override_model_config=True,
base_emb_dim=self.dim,
mhc_expansion_rate=4,
enable_mhc_lite=False,
num_experts=4,
num_experts_per_tok=2,
engram_layers=[],
)
devices_array = maxtext_utils.create_device_mesh(self.config)
self.mesh = Mesh(devices_array, self.config.mesh_axes)
self.rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(42))

with nn_partitioning.axis_rules(self.config.logical_axis_rules):
module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs)

# Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24)
self.assertEqual(module.res_alpha.shape, (64, 16))
# Shape of res_beta should be (4, 4) instead of (24,)
self.assertEqual(module.res_beta.shape, (4, 4))
# Permutation matrices shouldn't be defined
self.assertFalse(hasattr(module, "permutation_matrices"))


if __name__ == "__main__":
unittest.main()
Loading