Skip to content

Commit 968f625

Browse files
committed
Introduce cp-as-ep rule for long-context training or strong scaling
1 parent f67d8b1 commit 968f625

2 files changed

Lines changed: 81 additions & 0 deletions

File tree

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# "cp-as-ep" custom mesh rule: uses data, stage, FSDP, and expert axes.
# The expert axis acts as context parallelism in all components EXCEPT the
# core dMoE part (between the EP all-to-alls), where it shards experts.
# Intended for long-context training or strong scaling.
mesh_axes: ['data', 'stage', 'fsdp', 'context', 'expert']
data_sharding: [['data', 'stage', 'fsdp', 'context', 'expert']]
context_sharding: 'context'
logical_axis_rules: [
  # ==========================================
  # Vocabulary Embedding
  # ==========================================
  # Vocab activations: batch sharded over every axis except context
  # (context shards the sequence dimension instead).
  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']],
  ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
  # Vocab weights: vocab dim replicated; embed dim sharded FSDP-style
  # (context/expert fold into fsdp for weight sharding).
  ['vocab', []],
  ['embed_vocab', ['fsdp', 'context', 'expert']],
  # ==========================================
  # Attention
  # ==========================================
  # Attention activations: batch over data/fsdp/expert; sequence (query
  # length) over context — this is the CP behavior of the expert axis.
  ['activation_batch_attn', ['data', 'fsdp', 'expert']],
  ['activation_heads', []],
  ['activation_kv_heads', []],
  ['activation_length_attn', ['context']],
  ['activation_q_length', ['context']],
  ['activation_kv_length', []],
  ['activation_embed_attn', []],
  ['activation_kv', []],
  ['activation_kv_batch', ['data', 'fsdp', 'expert']],
  ['activation_kv_head_dim', []],
  # Attention weights: head/projection dims replicated; LoRA-style
  # low-rank dims sharded FSDP-style.
  ['heads', []],
  ['q_heads', []],
  ['kv_heads', []],
  ['qkv', []],
  ['kv', []],
  ['kv_head_dim', []],
  ['q_lora', ['fsdp', 'context', 'expert']],
  ['q_lora_up_proj', []],
  ['kv_lora', ['fsdp', 'context', 'expert']],
  ['kv_lora_up_proj', []],
  # ==========================================
  # Mixture of Experts (MoE)
  # ==========================================
  # MoE activations/weights: inside the dMoE core the context axis joins
  # expert to shard the expert dimension (cp-as-ep).
  ['activation_batch_moe', ['data', 'fsdp']],
  ['activation_exp', ['context', 'expert']],
  ['exp', ['context', 'expert']],
  ['embed_moe', ['fsdp']],
  # ==========================================
  # Standard MLP / Dense Layers / Model Structure
  # ==========================================
  # Dense activations: sequence/norm length over context; batch over
  # data/fsdp/expert; pipeline stages over stage.
  ['activation_mlp', []],
  ['activation_batch', ['data', 'fsdp', 'expert']],
  ['activation_length', ['context']],
  ['activation_norm_length', ['context']],
  ['activation_embed', []],
  ['activation_stage', 'stage'],
  # General weights: layers pipelined over stage; embed dim FSDP-style.
  ['mlp', []],
  ['layers', 'stage'],
  ['embed', ['fsdp', 'context', 'expert']],
]

src/maxtext/layers/moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,9 @@ def __init__(
374374

375375
if self.config.attention == "vllm_rpa" and self.config.enable_dp_attention:
376376
self._expert_parallelism_name = "attn_dp_expert"
377+
elif self.config.custom_mesh_and_rule == "cp-as-ep":
378+
# When custom_mesh_and_rule is "cp-as-ep", the context axis is treated the same as the expert axis in the MoE component
379+
self._expert_parallelism_name = ("context", "expert")
377380
else:
378381
self._expert_parallelism_name = "expert"
379382

0 commit comments

Comments
 (0)