
Commit c6da8da

NuojCheng authored and Shuwen-Fang committed
add custom mesh and logical rule support
1 parent 1d1704d commit c6da8da

3 files changed: 8 additions & 17 deletions


src/maxtext/configs/base.yml

Lines changed: 8 additions & 7 deletions
@@ -352,13 +352,13 @@ moba_topk: 8
 
 # DeepSeek Sparse Attention (DSA)
 # deepseek3.2 introduces indexer in MLA
-use_indexer: False
-indexer_head_dim: 128
-indexer_n_heads: 64
-indexer_topk: 2048
-# Determines the training strategy for the indexer:
-# - False (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters.
-# - True (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization.
+use_sparse_indexer: False
+index_head_dim: 128
+index_n_heads: 64
+index_topk: 2048
+# Determines the token selection strategy for indexer loss:
+# - False: Uses all tokens (Dense Warm-up).
+# - True: Uses only top-k tokens (Sparse Training).
 # Note: This is only active when `indexer_loss_scaling_factor` > 0.
 indexer_sparse_training: False
 # Multiplier for the indexer KL divergence loss
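For orientation, the renamed DSA keys might be combined as below once the sparse-training phase starts. The key names are taken verbatim from this hunk; the values and the pairing with indexer_loss_scaling_factor are illustrative assumptions, not defaults from this commit.

use_sparse_indexer: True          # enable the DSA indexer in MLA (True only for illustration)
index_head_dim: 128               # per-head dimension of the indexer
index_n_heads: 64                 # number of indexer heads
index_topk: 2048                  # tokens kept by top-k selection
indexer_sparse_training: True     # compute indexer loss over top-k tokens only (Sparse Training)
indexer_loss_scaling_factor: 1.0  # assumed value; per the note above, must be > 0 for this to be active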
@@ -426,6 +426,7 @@ internal_compile_num_devices: -1 # You must specify the number of devices when u
 
 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
+custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying yml name under config/mesh_and_rule/.
 mesh_axes:
 [
 "diloco",
src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 0 additions & 6 deletions
@@ -28,9 +28,7 @@ mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert']
 data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']]
 logical_axis_rules: [
 ['activation_batch', ['data', 'fsdp', 'expert']],
-['activation_batch_moe', ['data', 'fsdp', 'expert']],
 ['activation_batch_no_exp', ['data', 'fsdp']],
-['activation_batch_no_exp_moe', ['data', 'fsdp']],
 ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
 ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
 ['activation_heads', ['tensor']],
@@ -40,7 +38,6 @@ logical_axis_rules: [
 ['activation_q_length', ['expert']],
 ['activation_attn_embed', ['tensor']],
 ['activation_embed', ['tensor']],
-['activation_embed_moe', ['tensor']],
 ['activation_mlp', ['tensor']],
 ['activation_kv', ['tensor']],
 ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']],
@@ -58,10 +55,7 @@ logical_axis_rules: [
 ['q_heads', ['tensor']],
 ['kv_heads', ['tensor']],
 ['embed', ['fsdp', 'expert']],
-['embed_moe', ['fsdp', 'expert']],
 ['embed_no_exp', ['fsdp']],
-['embed_no_exp_moe', ['fsdp']],
-['embed_moe', ['fsdp']],
 ['q_lora', ['fsdp']],
 ['kv_lora', ['fsdp']],
 ['norm', ['tensor']],

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 0 additions & 4 deletions
@@ -19,8 +19,6 @@ data_sharding: [['fsdp']]
 logical_axis_rules: [
 ['activation_batch', ['fsdp']],
 ['activation_batch_no_exp', ['fsdp']],
-['activation_batch_moe', ['fsdp']],
-['activation_batch_no_exp_moe', ['fsdp']],
 ['activation_embed_and_logits_batch', ['fsdp']],
 ['activation_embed_and_logits_batch_sequence', ['fsdp']],
 ['activation_prefill_kv_batch', ['fsdp']],
@@ -29,8 +27,6 @@ logical_axis_rules: [
 ['decode_batch', ['fsdp']],
 ['embed', ['fsdp']],
 ['embed_no_exp', ['fsdp']],
-['embed_moe', ['fsdp']],
-['embed_no_exp_moe', ['fsdp']],
 ['q_lora', ['fsdp']],
 ['kv_lora', ['fsdp']],
 ['exp_with_fsdp', 'fsdp'],
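Taken together, the two presets show the shape a custom mesh-and-rule file takes: it overrides mesh_axes, data_sharding, and logical_axis_rules wholesale. A minimal skeleton for a new preset, with field names from the diffs above and axis choices purely illustrative:

# my-preset.yml (hypothetical file under src/maxtext/configs/custom_mesh_and_rule/)
mesh_axes: ['data', 'fsdp']          # physical mesh axes for this preset
data_sharding: [['data', 'fsdp']]    # how the global batch is laid out across the mesh
logical_axis_rules: [                # map each logical tensor axis to mesh axes
  ['activation_batch', ['data', 'fsdp']],
  ['embed', ['fsdp']],
]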
