
Commit 4799cef

Merge pull request #2895 from AI-Hypercomputer:config_doc_update
PiperOrigin-RevId: 852387500
2 parents 60b0953 + fad0457 commit 4799cef

3 files changed

Lines changed: 24 additions & 19 deletions

File tree

docs/reference/core_concepts/moe_configuration.md

Lines changed: 6 additions & 6 deletions
```diff
@@ -50,7 +50,7 @@ Dropping:
 
 `first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.
 
-`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability.
+`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability. Recommended specifically when lower precision types cause convergence or quality issues.
 
 ### Routing Mechanism
 `use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.
```
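The `float32_weight_sum` behavior can be illustrated with a short sketch. This is a hypothetical stand-alone function, not MaxText's actual implementation: each token's top-k expert outputs are combined with its router weights, with the accumulation optionally done in float32 before casting back to the activation dtype.

```python
import numpy as np

def combine_expert_outputs(expert_outputs, routing_weights, float32_weight_sum=True):
    """Weighted sum of per-token top-k expert outputs (illustrative sketch).

    expert_outputs: [tokens, k, dim] array, typically in a low-precision dtype.
    routing_weights: [tokens, k] router weights for each token's top-k experts.
    """
    if float32_weight_sum:
        # Accumulate in float32 for numerical stability, then cast back
        # to the original low-precision dtype.
        out = np.einsum(
            "tkd,tk->td",
            expert_outputs.astype(np.float32),
            routing_weights.astype(np.float32),
        )
        return out.astype(expert_outputs.dtype)
    # Low-precision accumulation: cheaper, but can hurt convergence or quality.
    return np.einsum("tkd,tk->td", expert_outputs, routing_weights)
```

The only cost is a temporary float32 copy of the weighted-sum operands, which is why the flag defaults to enabled and is called out for cases where lower precision causes quality issues.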
```diff
@@ -82,11 +82,11 @@ Dropping:
 * Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
 * Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.
 
-`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul.
+`use_custom_sort_vjp`: If enabled, uses a custom Vector-Jacobian Product (VJP) for the sort in sparse matmul, giving an efficient backward pass. Recommended, as it replaces the inefficient scatter-add generated by `jax.numpy.take` in the backward pass.
 
-`mlp_bias`: If enabled, add bias terms within the expert MLP layers.
+`mlp_bias`: If enabled, adds learnable bias terms to the MLP matmuls. Originally implemented to support the GPT-OSS model architecture.
 
-`use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications.
+`use_batch_split_schedule` (experimental): If enabled, splits the batch into micro-batches to hide communications, which yields performance benefits.
 
 ## 2. Sharding
 `expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
```
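The idea behind `use_custom_sort_vjp` can be sketched in a few lines of JAX. The helper below is illustrative only (MaxText's actual kernel differs): differentiating `jnp.take` produces a scatter-add in the backward pass, but when the indices form a permutation, the gradient is itself just a gather with the inverse permutation, which `jax.custom_vjp` lets us supply directly.

```python
import jax
import jax.numpy as jnp

def make_permute(perm):
    """Builds a permutation op whose backward pass is a gather with the
    inverse permutation instead of the default scatter-add (sketch only)."""
    inv = jnp.argsort(perm)  # inverse permutation: perm[inv[i]] == i

    @jax.custom_vjp
    def permute(x):
        return jnp.take(x, perm, axis=0)

    def fwd(x):
        return permute(x), None  # no residuals needed

    def bwd(_, g):
        # Gradient of a permutation is the cotangent gathered by the
        # inverse permutation -- no scatter-add required.
        return (jnp.take(g, inv, axis=0),)

    permute.defvjp(fwd, bwd)
    return permute
```

Checking `jax.grad` of this op against the default `x[perm]` gradient confirms the two backward passes agree while the custom one stays a pure gather.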
```diff
@@ -95,9 +95,9 @@ Dropping:
 
 `use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication.
 
-`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
+`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
 
-`fsdp_shard_on_exp`: If enabled, shard MLP weights on expert dimension instead of embedding dimension during FSDP sharding.
+`fsdp_shard_on_exp`: If enabled, shards the expert dimension of the MLP weights on the FSDP axis; recommended when num_experts is a multiple of fsdp_parallelism.
 
 ## 3. Performance Tuning
 These parameters provide granular control over the tiling dimensions for sparse matmul Pallas kernel.
```
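The difference `fsdp_shard_on_exp` makes can be sketched with `jax.sharding.PartitionSpec`. The axis names and the `[experts, embed, mlp]` weight layout below are illustrative assumptions, not MaxText's exact specs:

```python
from jax.sharding import PartitionSpec as P

num_experts, fsdp_parallelism = 16, 8

# Default FSDP sharding: the "fsdp" mesh axis shards the embedding
# dimension of a hypothetical [experts, embed, mlp] MoE weight.
spec_default = P(None, "fsdp", None)

# fsdp_shard_on_exp: the "fsdp" mesh axis shards the expert dimension
# instead, which only divides evenly when num_experts is a multiple of
# the fsdp axis size.
spec_on_expert = P("fsdp", None, None)

assert num_experts % fsdp_parallelism == 0
```

Under this assumption, the expert-dimension variant keeps each embedding slice whole on every shard, at the cost of requiring the divisibility condition above.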

src/MaxText/configs/base.yml

Lines changed: 6 additions & 6 deletions
```diff
@@ -157,7 +157,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
-float32_weight_sum: true # whether to use full fp32 precision for weight_sum during final unpermute in moe
+float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability
 
 # multi-token prediction configs
 # the number of auxiliary prediction layers to use for mtp.
@@ -179,7 +179,7 @@ sparse_matmul: true
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
 load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
-use_custom_sort_vjp: true # whether to use a custom sort vjp for sparse matmul ops
+use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
 use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
 # tunable tiling dimensions used for mlp gmm
 # megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
@@ -212,7 +212,8 @@ expert_shard_attention_option: "fsdp"
 
 # when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
 moe_fsdp_use_two_stage_all_gather: false
-# shard the moe weights on num_expert_dim. this can be performanct when num_expert % fdsp_parallisum
+# Shard the expert dimension of the MLP weights on the FSDP axis.
+# This configuration is recommended only when num_experts is a multiple of fsdp_parallelism.
 fsdp_shard_on_exp: False
 # use fsdp and fsdp_transpose axes for sharding the moe weights
 use_2d_fsdp_sharding: False
@@ -225,13 +226,12 @@ routed_scaling_factor: 1.0 # scaling factor for routing scores
 routed_score_func: "" # scoring function for routing
 routed_bias: False # a flag if a learnable bias is added for routing
 routed_bias_update_rate: 0.0 # the update rate applied to the router bias term
-mlp_bias: False # a flag if a learnable bias is added for MLP matmul
+mlp_bias: False # whether to add a learnable bias for MLP matmul; originally implemented to support the GPT-OSS model architecture
 n_routing_groups: -1 # number of groups for routing, disabled by default
 topk_routing_group: -1 # number of top groups to route inputs. For EP,
 # Splits the batch to allow for better scheduling when using expert parallelism by overlapping the
 # all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
-use_batch_split_schedule: False # whether to use batch split schedule
-# sending activations to a maximum of topk_routing_group distinct devices can yield performance benefits.
+use_batch_split_schedule: False # whether to split the batch into micro-batches to hide communications, which yields performance benefits
 
 # For complex architectures like llama4 there are repeated sets of
 # inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]
```
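The `capacity_factor` semantics in `base.yml` can be made concrete with a small sketch (a hypothetical helper, not MaxText's code): each expert's token budget is the average number of tokens per expert scaled by the factor, with any negative value meaning dropless.

```python
import math

def expert_capacity(num_tokens, num_experts, capacity_factor):
    """Tokens each expert may keep before dropping (illustrative sketch).

    capacity_factor < 0 means dropless: every expert can keep every token.
    """
    if capacity_factor < 0:
        return num_tokens
    # Average tokens per expert, scaled by the factor and rounded up.
    return math.ceil(num_tokens / num_experts * capacity_factor)
```

For example, 1024 tokens routed over 8 experts with a factor of 1.25 gives each expert room for 160 tokens; tokens routed beyond that budget would be dropped.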

src/MaxText/configs/types.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -553,7 +553,9 @@ class MoEGeneral(BaseModel):
   num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
   capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
   load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
-  use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.")
+  use_custom_sort_vjp: bool = Field(
+      True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
+  )
   use_ring_of_experts: bool = Field(
       False,
       description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
@@ -570,8 +572,8 @@ class MoEGeneral(BaseModel):
   )
   fsdp_shard_on_exp: bool = Field(
       False,
-      description="Shard the MoE weights on the num_expert dimension. Can be performant when "
-      "num_experts % fsdp_parallelism != 0.",
+      description="Shard the expert dimension of the MLP weights on the FSDP axis; "
+      "recommended when num_experts is a multiple of fsdp_parallelism.",
   )
   use_2d_fsdp_sharding: bool = Field(
       False,
@@ -583,7 +585,7 @@ class MoEGeneral(BaseModel):
   )
   float32_weight_sum: bool = Field(
       True,
-      description="Whether to use full fp32 precision for weight_sum during final unpermute in MoE.",
+      description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
   )
 
 
@@ -640,13 +642,16 @@ class DeepSeekMoE(BaseModel):
   routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
   routed_bias: bool = Field(False, description="Whether to add a bias term for routing.")
   routed_bias_update_rate: float = Field(0.0, description="Update rate applied to the router bias term.")
-  mlp_bias: bool = Field(False, description="Whether to add a learnable bias for MLP matmul.")
+  mlp_bias: bool = Field(
+      False,
+      description="Whether to add a learnable bias for MLP matmul; "
+      "originally implemented to support the GPT-OSS model architecture.",
+  )
   n_routing_groups: int = Field(-1, description="Number of groups for routing, disabled by default.")
   topk_routing_group: int = Field(-1, description="Number of top groups to route inputs to.")
   use_batch_split_schedule: bool = Field(
       False,
-      description="Splits the batch to allow for better scheduling when using expert parallelism by overlapping all-to-all "
-      "with compute.",
+      description="Whether to split the batch into micro-batches to hide communications, which yields performance benefits.",
   )
 
 
```
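The constraint documented for `fsdp_shard_on_exp` lends itself to validation at config-construction time. Below is a toy stdlib sketch of that idea; MaxText's `types.py` uses pydantic models, and the class and field defaults here are hypothetical:

```python
from dataclasses import dataclass

@dataclass
class MoEShardingDemo:
    """Toy config illustrating the documented fsdp_shard_on_exp constraint."""
    num_experts: int = 8
    fsdp_parallelism: int = 1
    fsdp_shard_on_exp: bool = False

    def __post_init__(self):
        # Expert-dimension sharding only divides evenly when num_experts
        # is a multiple of fsdp_parallelism.
        if self.fsdp_shard_on_exp and self.num_experts % self.fsdp_parallelism != 0:
            raise ValueError(
                "fsdp_shard_on_exp is recommended only when num_experts "
                "is a multiple of fsdp_parallelism"
            )
```

In a pydantic model the same check would live in a `model_validator`, so a misconfigured sharding setup fails at parse time rather than deep inside the training loop.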