Skip to content

Commit b3a8bb4

Browse files
committed
Optimize mHC for expansion rate 4 using convex combination of permutations and add enable_mhc_k4_shortcut feature gate
1 parent 9cc5820 commit b3a8bb4

5 files changed

Lines changed: 204 additions & 47 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,8 @@ force_q_layout: false
12171217
mhc_expansion_rate: 1
12181218
# The number of iterations for the Sinkhorn-Knopp algorithm.
12191219
sinkhorn_iterations: 20
1220+
# Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4.
1221+
enable_mhc_k4_shortcut: True
12201222

12211223
################################## DeepSeek Engram ##################################
12221224
# Indices of transformer layers where Engram is integrated; leave empty [] to disable.

src/maxtext/configs/types.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,10 @@ class ModelArchitecture(BaseModel):
506506
True,
507507
description="Whether to apply scale on query and key normalizations (default True).",
508508
)
509-
v_norm_with_scale: bool = Field(True, description="Whether to apply scale on value normalization (default True).")
509+
v_norm_with_scale: bool = Field(
510+
True,
511+
description="Whether to apply scale on value normalization (default True).",
512+
)
510513

511514

512515
class MTP(BaseModel):
@@ -685,14 +688,18 @@ class MoEGeneral(BaseModel):
685688
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
686689
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
687690
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
688-
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
691+
ragged_buffer_factor: float = Field(
692+
-1.0,
693+
description="Ragged buffer factor. If < 0, ragged buffer is worst case size.",
694+
)
689695
moe_expert_input_dim: int = Field(
690696
-1,
691697
description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.",
692698
)
693699
base_moe_mlp_dim: int = Field(-1, description="Intermediate dimension at MoE layer.")
694700
padded_base_moe_mlp_dim: Optional[int] = Field(
695-
None, description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution."
701+
None,
702+
description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution.",
696703
)
697704
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
698705
use_custom_sort_vjp: bool = Field(
@@ -873,7 +880,8 @@ class HardwareAndMesh(BaseModel):
873880
)
874881
custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")
875882
custom_mesh_and_rule: CustomRule = Field(
876-
CustomRule.DEFAULT, description="Customized mesh and logical rules for granularity."
883+
CustomRule.DEFAULT,
884+
description="Customized mesh and logical rules for granularity.",
877885
)
878886
allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")
879887
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
@@ -882,7 +890,8 @@ class HardwareAndMesh(BaseModel):
882890
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
883891
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
884892
remove_size_one_mesh_axis_from_type: bool = Field(
885-
True, description="Whether to remove size one mesh axis from type through jax.config."
893+
True,
894+
description="Whether to remove size one mesh axis from type through jax.config.",
886895
)
887896

888897

@@ -903,7 +912,10 @@ class LayoutAndSharding(BaseModel):
903912
description="Allowed percentage of non-sharded parameters.",
904913
)
905914
shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
906-
internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
915+
internal_compile: bool = Field(
916+
False,
917+
description="Use internal_compile to bypass open-source topology mappings.",
918+
)
907919
internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
908920
compile_xla_flags: str = Field("", description="Compiler options for compilation only.")
909921

@@ -950,7 +962,8 @@ class PipelineParallelism(BaseModel):
950962
"""Configuration for pipeline parallelism."""
951963

952964
pipeline_fsdp_ag_per_repeat: bool = Field(
953-
False, description="Enable weight prefetching for circular pipeline parallelism."
965+
False,
966+
description="Enable weight prefetching for circular pipeline parallelism.",
954967
)
955968
num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.")
956969
num_pipeline_repeats: int = Field(
@@ -1194,7 +1207,10 @@ class OlmoGrainDataset(BaseModel):
11941207
``data_shuffle_seed``); only OLMo-specific fields are listed here.
11951208
"""
11961209

1197-
olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.")
1210+
olmo_index_path: PathStr = Field(
1211+
"",
1212+
description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.",
1213+
)
11981214
olmo_path_remap_from: PathStr = Field(
11991215
"",
12001216
description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.",
@@ -1279,19 +1295,24 @@ class Distillation(BaseModel):
12791295
distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")
12801296
distill_alpha_end: Optional[float] = Field(None, description="Target alpha at end of training. None keeps alpha fixed.")
12811297
distill_alpha_schedule: Literal["constant", "linear", "cosine"] = Field(
1282-
"constant", description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine')."
1298+
"constant",
1299+
description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine').",
12831300
)
12841301
distill_temperature_end: Optional[float] = Field(
1285-
None, description="Target temperature at end of training. None keeps temperature fixed."
1302+
None,
1303+
description="Target temperature at end of training. None keeps temperature fixed.",
12861304
)
12871305
distill_temperature_schedule: Literal["constant", "linear", "cosine"] = Field(
1288-
"constant", description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine')."
1306+
"constant",
1307+
description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine').",
12891308
)
12901309
distill_beta_end: Optional[float] = Field(
1291-
None, description="Target beta_feature at end of training. None keeps beta fixed."
1310+
None,
1311+
description="Target beta_feature at end of training. None keeps beta fixed.",
12921312
)
12931313
distill_beta_schedule: Literal["constant", "linear", "cosine"] = Field(
1294-
"constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')."
1314+
"constant",
1315+
description="Schedule type for beta annealing ('constant', 'linear', or 'cosine').",
12951316
)
12961317

12971318
# --- Learn to init related parameters --
@@ -1314,11 +1335,13 @@ class Distillation(BaseModel):
13141335
)
13151336

13161337
attn_module_name: Optional[str] = Field(
1317-
None, description="Attention nnx module attribute name to augment with LTI logic"
1338+
None,
1339+
description="Attention nnx module attribute name to augment with LTI logic",
13181340
)
13191341

13201342
lti_layer_indices: Optional[list[int]] = Field(
1321-
None, description="List of layer indices to apply LTI modifications. If None, applied to all layers."
1343+
None,
1344+
description="List of layer indices to apply LTI modifications. If None, applied to all layers.",
13221345
)
13231346
# ---------------------------------------
13241347

@@ -1365,6 +1388,10 @@ class ManifoldConstrainedHyperConnections(BaseModel):
13651388

13661389
mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
13671390
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
1391+
enable_mhc_k4_shortcut: bool = Field(
1392+
True,
1393+
description="Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4.",
1394+
)
13681395

13691396

13701397
class DilocoParams(BaseModel):
@@ -1655,7 +1682,8 @@ class Profiling(BaseModel):
16551682
tpu_num_chips_to_profile_per_task: int = Field(1, description="Specifies the number of TPU chips to profile per task.")
16561683
tpu_num_sparse_cores_to_trace: int = Field(2, description="Specifies the number of TPU chips to profile per task.")
16571684
tpu_num_sparse_core_tiles_to_trace: int = Field(
1658-
1, description="Specifies the number of tiles within each sparse core to trace on the TPU."
1685+
1,
1686+
description="Specifies the number of tiles within each sparse core to trace on the TPU.",
16591687
)
16601688
xprof_tpu_power_trace_level: XProfTPUPowerTraceMode = Field(
16611689
XProfTPUPowerTraceMode.POWER_TRACE_NONE,
@@ -2491,7 +2519,11 @@ def validate_and_set_hlo_dump_defaults():
24912519
)
24922520
for param_name, schedule, end_value in [
24932521
("distill_alpha", self.distill_alpha_schedule, self.distill_alpha_end),
2494-
("distill_temperature", self.distill_temperature_schedule, self.distill_temperature_end),
2522+
(
2523+
"distill_temperature",
2524+
self.distill_temperature_schedule,
2525+
self.distill_temperature_end,
2526+
),
24952527
("distill_beta", self.distill_beta_schedule, self.distill_beta_end),
24962528
]:
24972529
if schedule != "constant" and end_value is None:
@@ -3004,7 +3036,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
30043036
self.use_grpo = False
30053037

30063038
if self.use_batch_split_schedule:
3007-
if self.quantization and not self.quantization == "fp8_full":
3039+
if self.quantization and self.quantization != "fp8_full":
30083040
raise ValueError("Batch split quantization only supports `quantization=fp8_full`")
30093041

30103042
if self.opt_type == "muon" and self.decoder_block not in [

src/maxtext/layers/mhc.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,29 @@
1414

1515
"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""
1616

17+
import itertools
1718
from typing import Callable
19+
1820
from flax import nnx
1921
import jax
2022
import jax.numpy as jnp
2123
from jax.sharding import Mesh
2224
from maxtext.common.common_types import Array, Config
2325
from maxtext.common.common_types import HyperConnectionType
24-
from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init
26+
from maxtext.layers.initializers import (
27+
default_bias_init,
28+
default_scalar_init,
29+
nd_dense_init,
30+
)
2531
from maxtext.layers.normalizations import RMSNorm
2632

2733

34+
def get_4x4_permutation_matrices():
35+
perms = list(itertools.permutations(range(4)))
36+
perms_array = jnp.array(perms)
37+
return jnp.eye(4)[perms_array]
38+
39+
2840
def get_functions(expansion_rate: int):
2941
"""Creates functions to broadcast a single feature stream into multiple
3042
@@ -118,6 +130,15 @@ def __init__(
118130
out_sharding=(None,),
119131
)
120132

133+
if self.k == 4 and self.config.enable_mhc_k4_shortcut:
134+
res_out_dim = 24
135+
res_beta_shape = (24,)
136+
res_beta_sharding = (None,)
137+
else:
138+
res_out_dim = self.k * self.k
139+
res_beta_shape = (self.k, self.k)
140+
res_beta_sharding = (None, None)
141+
121142
# Weight matrices
122143
scale_init = nd_dense_init(1.0, "fan_in", "normal")
123144
in_axis = 0
@@ -126,7 +147,7 @@ def __init__(
126147
self.res_alpha = nnx.Param(
127148
scale_init(
128149
self.rngs.params(),
129-
(self.k * self.dim, self.k * self.k),
150+
(self.k * self.dim, res_out_dim),
130151
self.weight_dtype,
131152
in_axis=in_axis,
132153
out_axis=out_axis,
@@ -156,8 +177,8 @@ def __init__(
156177

157178
# Biases
158179
self.res_beta = nnx.Param(
159-
default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
160-
out_sharding=(None, None),
180+
default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype),
181+
out_sharding=res_beta_sharding,
161182
)
162183
self.pre_beta = nnx.Param(
163184
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
@@ -174,13 +195,30 @@ def res_mapping(self, x: Array):
174195
res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
175196
res_beta = jnp.asarray(self.res_beta[...], self.dtype)
176197
res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
177-
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
178-
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
179-
b, s, _ = h_res.shape
180-
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
181-
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
182-
output = sinkhorn(intermediate, self.sinkhorn_iterations)
183-
return output
198+
199+
if self.k == 4 and self.config.enable_mhc_k4_shortcut:
200+
# Apply projection: (b, s, k*d) @ (k*d, 24) -> (b, s, 24)
201+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
202+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :]
203+
# Use float32 for numerical stability during softmax
204+
weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype)
205+
# Weighted sum of the 24 permutation matrices; the softmax weights are non-negative and sum to 1, so this is a convex combination
206+
permutation_matrices_4x4 = get_4x4_permutation_matrices().astype(self.dtype)
207+
output = jnp.einsum(
208+
"bsn,nkm -> bskm",
209+
weights,
210+
permutation_matrices_4x4,
211+
precision=self.matmul_precision,
212+
)
213+
return output
214+
else:
215+
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
216+
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
217+
b, s, _ = h_res.shape
218+
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
219+
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
220+
output = sinkhorn(intermediate, self.sinkhorn_iterations)
221+
return output
184222

185223
def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
186224
"""Helper function for both pre and post mappings."""

tests/unit/grain_data_processing_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,6 @@ def setUp(self):
464464
tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
465465
enable_checkpointing=False,
466466
)
467-
self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)
468467

469468

470469
@pytest.mark.external_training

0 commit comments

Comments
 (0)