Skip to content

Commit 7b75993

Browse files
committed
remove no exp in moe
1 parent 7e3f19f commit 7b75993

37 files changed

Lines changed: 520 additions & 606 deletions

src/maxtext/configs/base.yml

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,7 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y
455455
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
456456
logical_axis_rules: [
457457
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
458-
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
459-
['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']],
458+
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']],
460459
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
461460
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
462461
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']],
@@ -477,6 +476,7 @@ logical_axis_rules: [
477476
['activation_embed', ['tensor', 'tensor_transpose']],
478477
['activation_embed_moe', ['tensor', 'tensor_transpose']],
479478
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
479+
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
480480
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
481481
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
482482
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']],
@@ -490,6 +490,7 @@ logical_axis_rules: [
490490
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
491491
['decode_length', ['sequence']],
492492
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
493+
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
493494
['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
494495
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
495496
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -499,18 +500,10 @@ logical_axis_rules: [
499500
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
500501
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
501502
['embed', ['fsdp', 'sequence', 'context', 'expert']],
502-
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
503-
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
504-
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
505-
['embed_no_exp', ['fsdp', 'sequence', 'context']],
506-
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']],
507-
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']],
508-
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
509-
['embed_moe', ['fsdp', 'sequence', 'context', 'expert']],
510-
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
511-
['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
512-
['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
513-
['embed_no_exp_moe', ['fsdp', 'sequence', 'context']],
503+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']],
504+
['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
505+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
506+
['embed_moe', ['fsdp', 'sequence', 'context']],
514507
['embed_tensor_transpose', ['tensor_transpose']],
515508
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
516509
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']
3030
data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']]
3131
logical_axis_rules: [
3232
['activation_batch', ['data', 'fsdp', 'expert']],
33-
['activation_batch_moe', ['data', 'fsdp', 'expert']],
34-
['activation_batch_no_exp_moe', ['data', 'fsdp']],
33+
['activation_batch_moe', ['data', 'fsdp']],
3534
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
3635
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']],
3736
['activation_heads', ['tensor']],
@@ -45,22 +44,22 @@ logical_axis_rules: [
4544
['activation_embed', ['tensor']],
4645
['activation_embed_moe', ['tensor']],
4746
['activation_mlp', ['tensor']],
47+
['activation_mlp_moe', ['tensor']],
4848
['activation_kv', ['tensor']],
4949
['activation_kv_batch', ['data', 'fsdp']],
5050
['activation_kv_head_dim', ['tensor']],
5151
['activation_vocab', ['tensor']],
5252
['activation_stage', 'stage'],
5353
['activation_exp', ['expert']],
5454
['mlp', ['tensor']],
55+
['mlp_moe', ['tensor']],
5556
['mlp_no_fsdp', ['tensor']],
5657
['vocab', ['tensor']],
5758
['heads', ['tensor']],
5859
['q_heads', ['tensor']],
5960
['kv_heads', ['tensor']],
6061
['embed', ['fsdp', 'expert']], # remove context from embed sharding
61-
['embed_moe', ['fsdp', 'expert']],
62-
['embed_no_exp', ['fsdp']],
63-
['embed_no_exp_moe', ['fsdp']],
62+
['embed_moe', ['fsdp']],
6463
['q_lora', ['fsdp']],
6564
['kv_lora', ['fsdp']],
6665
['norm', ['tensor']],

src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,13 @@ data_sharding: [['fsdp']]
1919
logical_axis_rules: [
2020
['activation_batch', ['fsdp']],
2121
['activation_batch_moe', ['fsdp']],
22-
['activation_batch_no_exp_moe', ['fsdp']],
2322
['activation_embed_and_logits_batch', ['fsdp']],
2423
['activation_embed_and_logits_batch_sequence', ['fsdp']],
2524
['activation_prefill_kv_batch', ['fsdp']],
2625
['activation_kv_batch', ['fsdp']],
2726
['decode_batch', ['fsdp']],
2827
['embed', ['fsdp']],
29-
['embed_no_exp', ['fsdp']],
3028
['embed_moe', ['fsdp']],
31-
['embed_no_exp_moe', ['fsdp']],
3229
['q_lora', ['fsdp']],
3330
['kv_lora', ['fsdp']],
3431
['exp_with_fsdp', 'fsdp'],

src/maxtext/configs/inference/inference.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ logical_axis_rules: [
1212
['activation_norm_length', ['tensor_sequence', 'sequence']],
1313
['activation_embed', ['tensor_transpose']],
1414
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
15+
['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']],
1516
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
1617
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
1718
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
@@ -25,6 +26,7 @@ logical_axis_rules: [
2526
['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
2627
['decode_length', []],
2728
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
29+
['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
2830
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
2931
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
3032
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
@@ -33,10 +35,10 @@ logical_axis_rules: [
3335
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
3436
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
3537
['embed', ['fsdp', 'sequence', 'expert']],
36-
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
37-
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
38-
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
39-
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
38+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
39+
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
40+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
41+
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
4042
['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
4143
['layers', 'stage'],
4244
['kv', []],

src/maxtext/configs/inference/vllm.yml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
3131
logical_axis_rules: [
3232
['activation_batch', ['data']],
3333
['activation_batch_moe', []],
34-
['activation_batch_no_exp_moe', []],
3534
['activation_embed_and_logits_batch', ['data', 'expert']],
3635
['activation_embed_and_logits_batch_sequence', ['data', 'expert']],
3736
['activation_heads', ['model', 'expert']],
@@ -45,6 +44,7 @@ logical_axis_rules: [
4544
['activation_embed', ['model', 'attn_dp']],
4645
['activation_embed_moe', ['model', 'attn_dp']],
4746
['activation_mlp', ['model', 'attn_dp']],
47+
['activation_mlp_moe', ['model', 'attn_dp']],
4848
['activation_kv', ['model']],
4949
['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
5050
['activation_kv_batch', ['data']],
@@ -56,8 +56,8 @@ logical_axis_rules: [
5656
['decode_batch', ['expert', 'attn_dp_expert']],
5757
['decode_length', []],
5858
['mlp', ['model', 'attn_dp']],
59+
['mlp_moe', ['model', 'attn_dp']],
5960
['mlp_no_fsdp', ['model', 'attn_dp']],
60-
['moe_mlp', ['model', 'attn_dp']],
6161
['vocab', ['model', 'attn_dp']],
6262
['heads', ['model']],
6363
['q_heads', ['model', 'expert']],
@@ -66,11 +66,9 @@ logical_axis_rules: [
6666
['kv', []],
6767
['embed', ['expert', 'attn_dp_expert']],
6868
['embed', ['attn_dp_expert']],
69-
['embed_moe', ['expert', 'attn_dp_expert']],
70-
['embed_moe', ['attn_dp_expert']],
69+
['embed_moe', []],
70+
['embed_moe', []],
7171
['embed_tensor_transpose', ['attn_dp', 'model']],
72-
['embed_no_exp', []],
73-
['embed_no_exp_moe', []],
7472
['q_lora', ['expert', 'attn_dp_expert']],
7573
['kv_lora', ['expert', 'attn_dp_expert']],
7674
['norm', []],

src/maxtext/configs/post_train/rl_mt_jt.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ logical_axis_rules: [
4949
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
5050
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
5151
['embed', ['fsdp', 'sequence', 'expert']],
52-
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
53-
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
54-
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
55-
['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
52+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
53+
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
54+
['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
55+
['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']],
5656
['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
5757
['layers', 'stage'],
5858
['kv', []],

0 commit comments

Comments (0)