
Commit 632436c

split logical names in moe module

1 parent c4b5e64

4 files changed: 79 additions & 58 deletions
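Editor's note (not part of the commit): the change gives the MoE code path its own copies of the logical sharding names it uses — activation_batch → activation_batch_moe, activation_batch_no_exp → activation_batch_no_exp_moe, activation_length_no_exp → activation_length_no_exp_moe, activation_norm_length → activation_norm_length_moe, activation_embed → activation_embed_moe, embed → embed_moe, embed_no_exp → embed_no_exp_moe. The new rules map to the same mesh axes as the originals, so default behavior is unchanged, but MoE sharding can now be tuned independently of the dense path.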


src/maxtext/configs/base.yml

Lines changed: 14 additions & 0 deletions
@@ -435,7 +435,9 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
 mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
   ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
@@ -448,14 +450,18 @@ logical_axis_rules: [
   ['activation_attn_length_no_exp', ['context']],
   ['activation_length_no_exp', ['sequence', 'context']],
   ['activation_length_no_exp', ['context']],
+  ['activation_length_no_exp_moe', ['sequence', 'context']],
+  ['activation_length_no_exp_moe', ['context']],
   ['activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
+  ['activation_norm_length_moe', ['tensor_sequence', 'context', 'sequence']],
   ['activation_q_length', ['context', 'expert']],
   ['activation_q_length_no_exp', ['context']],
   ['prefill_activation_length', ['sequence', 'context']],
   ['prefill_activation_norm_length', ['tensor_sequence', 'context', 'sequence']],
   ['activation_kv_length', []],
   ['activation_attn_embed', ['tensor', 'tensor_transpose']],
   ['activation_embed', ['tensor', 'tensor_transpose']],
+  ['activation_embed_moe', ['tensor', 'tensor_transpose']],
   ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
@@ -484,6 +490,14 @@ logical_axis_rules: [
   ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
   ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
   ['embed_no_exp', ['fsdp', 'sequence', 'context']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context', 'expert']],
+  ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
+  ['embed_moe', ['fsdp', 'sequence', 'context', 'expert']],
+  ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
+  ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
+  ['embed_no_exp_moe', ['fsdp', 'sequence', 'context']],
   ['embed_tensor_transpose', ['tensor_transpose']],
   ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
   ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
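Editor's note (not part of the commit): MaxText resolves these logical names to mesh axes through flax's logical-axis rules, and repeated entries for the same name act as ordered fallbacks — a later rule applies when an earlier rule's mesh axes are already taken by another dimension of the same tensor, which is why embed_moe and embed_no_exp_moe each get several entries. A minimal sketch of that resolution, assuming flax's nn.logical_to_mesh_axes semantics (not MaxText code):

# Minimal sketch: how a logical_axis_rules entry resolves to a PartitionSpec.
# Duplicate entries for one logical name are ordered fallbacks: flax skips a
# rule whose mesh axes are already consumed by another dimension of the tensor.
import flax.linen as nn

rules = (
    ("activation_batch_moe", ("data", "fsdp", "fsdp_transpose", "expert")),
    ("activation_length_no_exp_moe", ("sequence", "context")),
    ("activation_length_no_exp_moe", ("context",)),  # fallback entry
    ("activation_embed_moe", ("tensor", "tensor_transpose")),
)

pspec = nn.logical_to_mesh_axes(
    ("activation_batch_moe", "activation_length_no_exp_moe", "activation_embed_moe"),
    rules,
)
# pspec == PartitionSpec(('data', 'fsdp', 'fsdp_transpose', 'expert'),
#                        ('sequence', 'context'), ('tensor', 'tensor_transpose'))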

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ logical_axis_rules: [
   ['activation_q_length', ['expert']],
   ['activation_attn_embed', ['tensor']],
   ['activation_embed', ['tensor']],
+  ['activation_embed_moe', ['tensor']],
   ['activation_mlp', ['tensor']],
   ['activation_kv', ['tensor']],
   ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
@@ -56,6 +57,7 @@ logical_axis_rules: [
   ['kv_heads', ['tensor']],
   ['embed', ['fsdp', 'expert']],
   ['embed_no_exp', ['fsdp']],
+  ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['norm', ['tensor']],

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ logical_axis_rules: [
   ['decode_batch', ['fsdp']],
   ['embed', ['fsdp']],
   ['embed_no_exp', ['fsdp']],
+  ['embed_moe', ['fsdp']],
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['exp_with_fsdp', 'fsdp'],
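Editor's note (not part of the commit): a custom_mesh_and_rule file replaces base.yml's rule table, and a logical name with no matching rule resolves to an unsharded dimension — which is why both custom files above must add the new *_moe entries. A minimal sketch, again assuming flax's rule-resolution behavior:

# Minimal sketch (not MaxText code): a logical name missing from the rule
# table resolves to None, i.e. that dimension is silently left unsharded.
import flax.linen as nn

custom_rules = (
    ("embed", ("fsdp",)),
    ("embed_no_exp", ("fsdp",)),
)
print(nn.logical_to_mesh_axes(("embed_moe", "mlp"), custom_rules))
# PartitionSpec(None, None) -- without an 'embed_moe' rule, a kernel
# annotated with it would lose its FSDP sharding.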

src/maxtext/layers/moe.py

Lines changed: 62 additions & 58 deletions
@@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.
 
     contract_ind = tuple(range(0, len(norm_axis)))
     output_sharding = (
-        create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None))
+        create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
         if self.shard_mode == ShardMode.EXPLICIT
         else None
     )
@@ -351,16 +351,16 @@ def __init__(
 
     if self.config.shard_exp_on_fsdp:
       # special sharding for dsv3
-      self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp")
+      self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None)
     elif self.config.use_2d_fsdp_sharding:
-      self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
-      self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
+      self.wi_kernel_axes = ("embed_no_exp_moe", "mlp", None)
+      self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None)
     elif self.config.use_batch_split_schedule:
       self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes()
     else:
-      self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
-      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
+      self.wi_kernel_axes = ("exp", "embed_no_exp_moe", "mlp")
+      self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp_moe")
 
     if self.config.attention == "vllm_rpa":
       # vLLM uses 'model' as the tensor parallelism axis name
@@ -437,7 +437,7 @@ def __init__(
 
     if self.config.mlp_bias:
       wi_bias_axes = ("exp", "activation_mlp")
-      wo_bias_axes = ("exp", "activation_embed")
+      wo_bias_axes = ("exp", "activation_embed_moe")
      wi_bias_shape = (self.num_experts, self.intermediate_dim)
      wo_bias_shape = (self.num_experts, self.config.emb_dim)
      self.wi_0_bias = nnx.Param(
@@ -1018,7 +1018,7 @@ def gmm(
           self._expert_parallelism_name
           in tuple(
               filter(
-                  lambda tup: tup[0] == "activation_batch",
+                  lambda tup: tup[0] == "activation_batch_moe",
                   self.config.logical_axis_rules,
               )
           )[
@@ -1028,26 +1028,26 @@ def gmm(
     except:  # pylint: disable=bare-except
       is_batch_sharded_by_expert = False
     if is_batch_sharded_by_expert and inputs.shape[0] > 1:
-      batch_logical_axis = "activation_batch"
+      batch_logical_axis = "activation_batch_moe"
     else:
-      batch_logical_axis = "activation_batch_no_exp"
+      batch_logical_axis = "activation_batch_no_exp_moe"
 
     if self.get_tensor_transpose_parallelism_size() > 1:
       input_partition_pspec = self._logical_to_mesh_axes(
-          (batch_logical_axis, "activation_norm_length", "activation_embed")
+          (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
       )
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
     else:
-      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
       w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
       w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
-      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed"))
+      wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
 
-    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+    gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
+      pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
     else:
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits_pspec = None
@@ -1099,7 +1099,7 @@ def gmm(
           P(),  # Replicate the input key
       ),
       out_specs=(
-          self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")),
+          self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")),
          P(),  # Handle None or replicate the output
          P(),  # Handle None or replicate the output
      ),
@@ -1411,13 +1411,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose"))
 
     if self.get_tensor_transpose_parallelism_size() > 1:
-      input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed")
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
     else:
-      input_axes = (batch_logical_axis, "activation_norm_length", None)
+      input_axes = (batch_logical_axis, "activation_norm_length_moe", None)
 
-    gate_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+    gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     if self.config.model_name.startswith("deepseek3"):
-      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None)
+      pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None)
     else:
       pre_bias_logits_axes = None
 
@@ -1436,13 +1436,13 @@ def reshape_and_update_weights(self, weights, indices):
     update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype)
     index_update = (
         self._maybe_shard_with_logical(
-            jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None)
+            jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None)
         ),
-        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)),
+        self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)),
         indices,
     )
     weight_sharding = (
-        create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None))
+        create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None))
         if self.config.shard_mode == ShardMode.EXPLICIT
         else None
     )
@@ -1497,15 +1497,15 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs):
         expert_mask,
         (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts),
     )
-    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None))
+    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None, None))
     expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2)
     expert_token_count = jnp.reshape(
         expert_token_count_fused,
         ((batch_size, cp, sub_seq, self.num_experts_per_tok, self.num_experts)),
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None, None),
+        ("activation_batch_moe", "activation_norm_length_moe", None, None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3)
@@ -1585,15 +1585,15 @@ def generate_masks(self, top_k_indices, softmax_probs):
         expert_mask,
         (batch_size, seq_len * self.num_experts_per_tok, self.num_experts),
     )
-    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None))
+    expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None))
     expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=1)
     expert_token_count = jnp.reshape(
         expert_token_count_fused,
         ((batch_size, seq_len, self.num_experts_per_tok, self.num_experts)),
     )
     expert_token_count = self._maybe_shard_with_logical(
         expert_token_count,
-        ("activation_batch", "activation_norm_length", None, None),
+        ("activation_batch_moe", "activation_norm_length_moe", None, None),
     )
     trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch)
     combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2)
@@ -1691,11 +1691,13 @@ def dense_matmul(
   ) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
     """Dense matrix multiplication."""
     # gate_logits: batch, length, expert
-    gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
+    gate_logits = self._maybe_shard_with_logical(
+        gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
+    )
     if self.config.model_name.startswith("deepseek3"):
       # pre_bias_logits is None for non-DeepSeek v3 models
       pre_bias_logits = self._maybe_shard_with_logical(
-          pre_bias_logits, ("activation_batch", "activation_norm_length", None)
+          pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
       )
     top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
     is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
@@ -1735,16 +1737,16 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks(
          top_k_indices, weights  # pylint: disable=undefined-variable,possibly-used-before-assignment
      )
-      mask_axes = ("activation_batch", "activation_norm_length", None, None)
+      mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None)
       dispatch_axis = (
           "activation_exp",
-          "activation_batch_no_exp",
+          "activation_batch_no_exp_moe",
           None,
-          "activation_embed",
+          "activation_embed_moe",
       )
       mlp_axis = (
           "activation_exp",
-          "activation_batch_no_exp",
+          "activation_batch_no_exp_moe",
          None,
          "activation_mlp",
      )
@@ -1759,56 +1761,56 @@ def dense_matmul(
       dispatch_mask, combine_mask = self.generate_masks_subgroup(top_k_indices, softmax_probs)
       if self.get_context_autoregressive_parallelism_size() > 0 and cp == 1:
         mask_axes = (
-            "activation_norm_length",
-            "activation_batch",
+            "activation_norm_length_moe",
+            "activation_batch_moe",
             None,
             None,
             None,
         )
         input_axis = (
-            "activation_norm_length",
-            "activation_batch",
+            "activation_norm_length_moe",
+            "activation_batch_moe",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         dispatch_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         mlp_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
             "activation_mlp",
         )
       else:
         mask_axes = (
-            "activation_batch",
-            "activation_norm_length",
+            "activation_batch_moe",
+            "activation_norm_length_moe",
             None,
             None,
             None,
         )
         input_axis = (
-            "activation_batch",
-            "activation_norm_length",
+            "activation_batch_moe",
+            "activation_norm_length_moe",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         dispatch_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
-            "activation_embed",
+            "activation_embed_moe",
         )
         mlp_axis = (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
             None,
             "activation_mlp",
@@ -1834,10 +1836,10 @@ def dense_matmul(
         dispatch,
         (
             None,
-            "activation_batch_no_exp",
-            "activation_norm_length",
+            "activation_batch_no_exp_moe",
+            "activation_norm_length_moe",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         ),
     )
     dispatch = self._maybe_shard_with_logical(
@@ -1897,9 +1899,9 @@ def dense_matmul(
         intermediate_layer,
         (
             "activation_exp",
-            "activation_batch_no_exp",
+            "activation_batch_no_exp_moe",
             None,
-            "activation_embed",
+            "activation_embed_moe",
         ),
     )
     intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
@@ -1922,7 +1924,9 @@ def dense_matmul(
       )
       return output, lb_loss, bias_updates
     else:
-      inputs = self._maybe_shard_with_logical(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
+      inputs = self._maybe_shard_with_logical(
+          inputs, ("activation_batch_moe", "activation_norm_length_moe", "activation_embed_moe")
+      )
       with jax.named_scope("wi_0"):
         layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
             "BSM,EMH -> BSEH", inputs, w0_kernel, precision=matmul_precision
@@ -2082,7 +2086,7 @@ def __init__(
         num_experts_per_tok=self.config.num_experts_per_tok,
         mesh=self.mesh,
         kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"),
-        kernel_axes=("embed", None),
+        kernel_axes=("embed_moe", None),
         intermediate_dim=self.config.moe_mlp_dim,
         dtype=self.config.dtype,
         weight_dtype=self.config.weight_dtype,
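Editor's note (not part of the commit): _maybe_shard_with_logical and create_sharding are MaxText helpers. As a rough single-device stand-in, the renamed activation names can be applied with flax's logical-constraint API; the rule subset and the mesh below are assumptions for illustration, not the actual implementation:

# Rough stand-in (not the MaxText implementation) for what a helper like
# _maybe_shard_with_logical does with the renamed *_moe activation names.
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh

rules = (  # assumed subset of the renamed rules from this commit
    ("activation_batch_moe", ("data", "fsdp")),
    ("activation_norm_length_moe", ("sequence",)),
    ("activation_embed_moe", ("tensor",)),
)

# Toy mesh: every axis has size 1 so the sketch runs on a single device.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1, 1, 1), ("data", "fsdp", "sequence", "tensor"))

@jax.jit
def constrain(x):
  # Attach the logical sharding annotation; XLA propagates it from here.
  return nn.with_logical_constraint(
      x, ("activation_batch_moe", "activation_norm_length_moe", "activation_embed_moe")
  )

with mesh, nn.logical_axis_rules(rules):
  y = constrain(jnp.ones((2, 8, 16)))  # [batch, length, embed]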
