Skip to content

Commit 1799cc1

Browse files
committed
Optimize mHC for expansion rate 4 using convex combination of permutations and add enable_mhc_k4_shortcut feature gate
1 parent 4d9f390 commit 1799cc1

5 files changed

Lines changed: 372 additions & 63 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,8 @@ force_q_layout: false
12161216
mhc_expansion_rate: 1
12171217
# The number of iterations for the Sinkhorn-Knopp algorithm.
12181218
sinkhorn_iterations: 20
1219+
# Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4.
1220+
enable_mhc_k4_shortcut: True
12191221

12201222
################################## DeepSeek Engram ##################################
12211223
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.

src/maxtext/configs/types.py

Lines changed: 92 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ class Checkpointing(BaseModel):
345345
description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
346346
)
347347
enable_autocheckpoint: bool = Field(
348-
False, description="If True, enables autocheckpoint or preemption induced checkpointing."
348+
False,
349+
description="If True, enables autocheckpoint or preemption induced checkpointing.",
349350
)
350351

351352

@@ -495,9 +496,13 @@ class ModelArchitecture(BaseModel):
495496
)
496497
fused_mlp: bool = Field(False, description="If supported, fuse the MLP layers.")
497498
qk_norm_with_scale: bool = Field(
498-
True, description="Whether to apply scale on query and key normalizations (default True)."
499+
True,
500+
description="Whether to apply scale on query and key normalizations (default True).",
501+
)
502+
v_norm_with_scale: bool = Field(
503+
True,
504+
description="Whether to apply scale on value normalization (default True).",
499505
)
500-
v_norm_with_scale: bool = Field(True, description="Whether to apply scale on value normalization (default True).")
501506

502507

503508
class MTP(BaseModel):
@@ -542,9 +547,13 @@ class Attention(BaseModel):
542547
"global", description="The variant of attention to use."
543548
)
544549
share_kv_projections: bool = Field(
545-
False, description="If True, for global attention, Key and Value projections share the same weights."
550+
False,
551+
description="If True, for global attention, Key and Value projections share the same weights.",
552+
)
553+
global_num_kv_heads: int = Field(
554+
0,
555+
description="If greater than 0, sets the number of KV heads for global attention.",
546556
)
547-
global_num_kv_heads: int = Field(0, description="If greater than 0, sets the number of KV heads for global attention.")
548557
attention_sink: bool = Field(False, description="If True, enables attention sinks.")
549558
float32_qk_product: bool = Field(False, description="In dot-product attention, cast query-key product to fp32.")
550559
float32_logits: bool = Field(
@@ -672,14 +681,18 @@ class MoEGeneral(BaseModel):
672681
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
673682
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
674683
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
675-
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
684+
ragged_buffer_factor: float = Field(
685+
-1.0,
686+
description="Ragged buffer factor. If < 0, ragged buffer is worst case size.",
687+
)
676688
moe_expert_input_dim: int = Field(
677689
-1,
678690
description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.",
679691
)
680692
base_moe_mlp_dim: int = Field(-1, description="Intermediate dimension at MoE layer.")
681693
padded_base_moe_mlp_dim: Optional[int] = Field(
682-
None, description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution."
694+
None,
695+
description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution.",
683696
)
684697
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
685698
use_custom_sort_vjp: bool = Field(
@@ -860,7 +873,8 @@ class HardwareAndMesh(BaseModel):
860873
)
861874
custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")
862875
custom_mesh_and_rule: CustomRule = Field(
863-
CustomRule.DEFAULT, description="Customized mesh and logical rules for granularity."
876+
CustomRule.DEFAULT,
877+
description="Customized mesh and logical rules for granularity.",
864878
)
865879
allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")
866880
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
@@ -869,7 +883,8 @@ class HardwareAndMesh(BaseModel):
869883
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
870884
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
871885
remove_size_one_mesh_axis_from_type: bool = Field(
872-
True, description="Whether to remove size one mesh axis from type through jax.config."
886+
True,
887+
description="Whether to remove size one mesh axis from type through jax.config.",
873888
)
874889

875890

@@ -890,7 +905,10 @@ class LayoutAndSharding(BaseModel):
890905
description="Allowed percentage of non-sharded parameters.",
891906
)
892907
shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
893-
internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
908+
internal_compile: bool = Field(
909+
False,
910+
description="Use internal_compile to bypass open-source topology mappings.",
911+
)
894912
internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
895913
compile_xla_flags: str = Field("", description="Compiler options for compilation only.")
896914

@@ -937,7 +955,8 @@ class PipelineParallelism(BaseModel):
937955
"""Configuration for pipeline parallelism."""
938956

939957
pipeline_fsdp_ag_per_repeat: bool = Field(
940-
False, description="Enable weight prefetching for circular pipeline parallelism."
958+
False,
959+
description="Enable weight prefetching for circular pipeline parallelism.",
941960
)
942961
num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.")
943962
num_pipeline_repeats: int = Field(
@@ -1046,7 +1065,8 @@ class Tokenizer(BaseModel):
10461065
use_chat_template: bool = Field(False, description="Whether to use the chat template for tokenization.")
10471066
chat_template_path: str = Field("", description="Path to chat template json file.")
10481067
chat_template: str = Field(
1049-
"", description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template."
1068+
"",
1069+
description="Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template.",
10501070
)
10511071
tokenize_train_data: bool = Field(True, description="If False, assumes the training dataset is pre-tokenized.")
10521072
tokenize_eval_data: bool = Field(True, description="If False, assumes the evaluation dataset is pre-tokenized.")
@@ -1138,7 +1158,8 @@ class GrainDataset(BaseModel):
11381158
description="Path to a JSON file specifying the mixture weights for Grain training data.",
11391159
)
11401160
grain_file_type: str = Field(
1141-
"arrayrecord", description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
1161+
"arrayrecord",
1162+
description="File type for Grain data. Supported: arrayrecord, tfrecord, parquet.",
11421163
)
11431164
grain_use_elastic_iterator: bool = Field(
11441165
False,
@@ -1179,7 +1200,10 @@ class OlmoGrainDataset(BaseModel):
11791200
``data_shuffle_seed``); only OLMo-specific fields are listed here.
11801201
"""
11811202

1182-
olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.")
1203+
olmo_index_path: PathStr = Field(
1204+
"",
1205+
description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.",
1206+
)
11831207
olmo_path_remap_from: PathStr = Field(
11841208
"",
11851209
description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.",
@@ -1229,32 +1253,39 @@ class Distillation(BaseModel):
12291253

12301254
# --- Offline Distillation Field ---
12311255
offline_data_dir: Optional[str] = Field(
1232-
None, description="GCS or local path to the pre-generated ArrayRecord teacher data."
1256+
None,
1257+
description="GCS or local path to the pre-generated ArrayRecord teacher data.",
12331258
)
12341259

12351260
# --- Loss Params ---
12361261
distill_alpha: float = Field(0.5, description="Weight for the distillation loss component.")
12371262
distill_temperature: float = Field(1.0, description="Temperature for distillation softening.")
12381263
distill_beta: float = Field(0.0, description="Weight for the feature loss component. Use 0.0 to disable")
12391264
distill_feature_loss_type: Literal["cosine", "l2"] = Field(
1240-
"cosine", description="The type of loss to use for feature distillation ('cosine' or 'l2')."
1265+
"cosine",
1266+
description="The type of loss to use for feature distillation ('cosine' or 'l2').",
12411267
)
12421268
distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")
12431269
distill_alpha_end: Optional[float] = Field(None, description="Target alpha at end of training. None keeps alpha fixed.")
12441270
distill_alpha_schedule: Literal["constant", "linear", "cosine"] = Field(
1245-
"constant", description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine')."
1271+
"constant",
1272+
description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine').",
12461273
)
12471274
distill_temperature_end: Optional[float] = Field(
1248-
None, description="Target temperature at end of training. None keeps temperature fixed."
1275+
None,
1276+
description="Target temperature at end of training. None keeps temperature fixed.",
12491277
)
12501278
distill_temperature_schedule: Literal["constant", "linear", "cosine"] = Field(
1251-
"constant", description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine')."
1279+
"constant",
1280+
description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine').",
12521281
)
12531282
distill_beta_end: Optional[float] = Field(
1254-
None, description="Target beta_feature at end of training. None keeps beta fixed."
1283+
None,
1284+
description="Target beta_feature at end of training. None keeps beta fixed.",
12551285
)
12561286
distill_beta_schedule: Literal["constant", "linear", "cosine"] = Field(
1257-
"constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')."
1287+
"constant",
1288+
description="Schedule type for beta annealing ('constant', 'linear', or 'cosine').",
12581289
)
12591290

12601291
# --- Learn to init related parameters --
@@ -1277,11 +1308,13 @@ class Distillation(BaseModel):
12771308
)
12781309

12791310
attn_module_name: Optional[str] = Field(
1280-
None, description="Attention nnx module attribute name to augment with LTI logic"
1311+
None,
1312+
description="Attention nnx module attribute name to augment with LTI logic",
12811313
)
12821314

12831315
lti_layer_indices: Optional[list[int]] = Field(
1284-
None, description="List of layer indices to apply LTI modifications. If None, applied to all layers."
1316+
None,
1317+
description="List of layer indices to apply LTI modifications. If None, applied to all layers.",
12851318
)
12861319
# ---------------------------------------
12871320

@@ -1328,6 +1361,10 @@ class ManifoldConstrainedHyperConnections(BaseModel):
13281361

13291362
mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
13301363
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
1364+
enable_mhc_k4_shortcut: bool = Field(
1365+
True,
1366+
description="Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4.",
1367+
)
13311368

13321369

13331370
class DilocoParams(BaseModel):
@@ -1344,10 +1381,12 @@ class Optimizer(BaseModel):
13441381

13451382
opt_type: OptimizerType = Field(OptimizerType.ADAMW, description="The type of optimizer to use.")
13461383
skip_step_on_spikes: bool = Field(
1347-
False, description="If True, skip the training step when loss or gradient spike is detected."
1384+
False,
1385+
description="If True, skip the training step when loss or gradient spike is detected.",
13481386
)
13491387
skip_step_interval: PositiveInt = Field(
1350-
128, description="The rolling interval to calculate the mean and standard deviation."
1388+
128,
1389+
description="The rolling interval to calculate the mean and standard deviation.",
13511390
)
13521391
skip_step_scaling_factor: float = Field(6.0, description="The scaling factor to determine if a spike occurred.")
13531392
gradient_accumulation_steps: PositiveInt = Field(
@@ -1616,7 +1655,8 @@ class Profiling(BaseModel):
16161655
tpu_num_chips_to_profile_per_task: int = Field(1, description="Specifies the number of TPU chips to profile per task.")
16171656
tpu_num_sparse_cores_to_trace: int = Field(2, description="Specifies the number of TPU chips to profile per task.")
16181657
tpu_num_sparse_core_tiles_to_trace: int = Field(
1619-
1, description="Specifies the number of tiles within each sparse core to trace on the TPU."
1658+
1,
1659+
description="Specifies the number of tiles within each sparse core to trace on the TPU.",
16201660
)
16211661
xprof_tpu_power_trace_level: XProfTPUPowerTraceMode = Field(
16221662
XProfTPUPowerTraceMode.POWER_TRACE_NONE,
@@ -1796,7 +1836,10 @@ class VisionTower(BaseModel):
17961836
temporal_patch_size_for_vit: int = Field(2, description="Temporal patch size for video inputs.")
17971837
num_position_embeddings_for_vit: int = Field(1024, description="Number of position embeddings for ViT.")
17981838
deepstack_visual_indexes_for_vit: list[int] = Field([], description="Layer indices to extract deep visual features.")
1799-
vision_output_length: int = Field(-1, description="The output length (number of soft tokens) from the vision encoder.")
1839+
vision_output_length: int = Field(
1840+
-1,
1841+
description="The output length (number of soft tokens) from the vision encoder.",
1842+
)
18001843

18011844

18021845
class VisionProjector(BaseModel):
@@ -1900,18 +1943,28 @@ class RL(BaseModel):
19001943
grpo_epsilon: float = Field(0.2, description="Epsilon value for clipping in the GRPO loss.")
19011944
loss_algo: Literal["grpo", "gspo-token"] = Field("grpo", description="Loss algorithm, i.e., 'grpo' or 'gspo-token'.")
19021945
use_agentic_rollout: bool = Field(
1903-
False, description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts."
1946+
False,
1947+
description="If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts.",
1948+
)
1949+
max_concurrency: int = Field(
1950+
256,
1951+
description="Maximum number of concurrent rollout requests (agentic rollout only).",
19041952
)
1905-
max_concurrency: int = Field(256, description="Maximum number of concurrent rollout requests (agentic rollout only).")
19061953
off_policy_steps: int = Field(
1907-
0, description="Number of off-policy steps tolerated before requiring a policy update (agentic only)."
1954+
0,
1955+
description="Number of off-policy steps tolerated before requiring a policy update (agentic only).",
1956+
)
1957+
system_prompt: str = Field(
1958+
"",
1959+
description="System prompt injected into the agent at rollout time (agentic only).",
19081960
)
1909-
system_prompt: str = Field("", description="System prompt injected into the agent at rollout time (agentic only).")
19101961
degenerate_group_masking: bool = Field(
1911-
True, description="Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only)."
1962+
True,
1963+
description="Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only).",
19121964
)
19131965
epsilon_high: Optional[float] = Field(
1914-
None, description="Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only)."
1966+
None,
1967+
description="Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only).",
19151968
)
19161969
reshard_chunk_size: Optional[int] = Field(
19171970
None,
@@ -2435,7 +2488,11 @@ def validate_and_set_hlo_dump_defaults():
24352488
)
24362489
for param_name, schedule, end_value in [
24372490
("distill_alpha", self.distill_alpha_schedule, self.distill_alpha_end),
2438-
("distill_temperature", self.distill_temperature_schedule, self.distill_temperature_end),
2491+
(
2492+
"distill_temperature",
2493+
self.distill_temperature_schedule,
2494+
self.distill_temperature_end,
2495+
),
24392496
("distill_beta", self.distill_beta_schedule, self.distill_beta_end),
24402497
]:
24412498
if schedule != "constant" and end_value is None:
@@ -2948,7 +3005,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
29483005
self.use_grpo = False
29493006

29503007
if self.use_batch_split_schedule:
2951-
if self.quantization and not self.quantization == "fp8_full":
3008+
if self.quantization and self.quantization != "fp8_full":
29523009
raise ValueError("Batch split quantization only supports `quantization=fp8_full`")
29533010

29543011
if self.opt_type == "muon" and self.decoder_block not in [

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
324324
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
325325
)
326326
self.model = nnx.data(model)
327+

0 commit comments

Comments
 (0)