Skip to content

Commit d972510

Browse files
Introduce a batch-split config for DSv3
PiperOrigin-RevId: 883159766
1 parent c6b84c1 commit d972510

3 files changed

Lines changed: 93 additions & 17 deletions

File tree

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek V3 - 671B that uses fsdp on two logical axes

# For DeepSeek default device-limited routing,
# please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.

base_emb_dim: 7168
base_num_query_heads: 128
base_num_kv_heads: 128
base_mlp_dim: 18432
base_moe_mlp_dim: 2048
base_num_decoder_layers: 61
first_num_dense_layers: 3
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 256
num_experts_per_tok: 8
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek"
# MLA
attention_type: "mla"
q_lora_rank: 1536
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
mscale: 1.0
# RoPE
rope_type: "yarn"
rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
max_position_embeddings: 163840
original_max_position_embeddings: 4096
rope_factor: 40
beta_fast: 32
rope_interleave: True
rope_truncate: True
rope_attention_scaling: False

override_logical_axis_rules: True
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
logical_axis_rules: [
  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
  # NOTE(review): 'activation_embed_and_logits_batch' appears twice; logical axis
  # rules are matched in order, so this second entry presumably acts as a fallback
  # when a mesh axis from the first entry is unavailable — confirm this is intended.
  ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
  ['activation_norm_length', ['context']],
  ['activation_heads', []],
  ['activation_stage', 'stage'],
  ['embed', ['fsdp']],
  ['embed_no_exp', ['fsdp']],
  ['q_lora', ['fsdp']],
  ['kv_lora', ['fsdp']],
  ['layers', 'stage'],
  ['q_lora_up_proj', ['fsdp_transpose']],
  ['kv_lora_up_proj', ['fsdp_transpose']],
  ['q_heads', ['fsdp_transpose']],
  ['kv_heads', ['fsdp_transpose']],
  ['heads', ['fsdp_transpose']],
  ['mlp', ['fsdp_transpose']],
  ['fsdp_transpose_and_expert', ['fsdp_transpose', 'expert']],
  ['fsdp_transpose_only', ['fsdp_transpose']],
  ['expert_only', ['expert']],
]

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ class ProfilerType(str, Enum):
     "deepseek2-236b",
     "deepseek3-671b",
     "deepseek3-671b-2dfsdp",
+    "deepseek3-671b-batchsplit",
     "deepseek3-test",
     "deepseek3-tiny",
     "deepseek3.2-671b",

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -180,26 +180,18 @@ def fn(weights):
       (routed_wi_0, routed_wi_1, routed_wo),
       (shared_wi_0, shared_wi_1, shared_wo),
   ) = weights
-  # All-gather across FSDP axis. Expert axis is used for FSDP in attention.
-  wq_a = jax.lax.all_gather(wq_a, axis_name="expert", tiled=True, axis=1)
+  # All-gather across FSDP axis.
   wq_a = jax.lax.all_gather(wq_a, axis_name="fsdp", tiled=True)
-  wq_b = jax.lax.all_gather(wq_b, axis_name="expert", tiled=True, axis=1)
   wq_b = jax.lax.all_gather(wq_b, axis_name="fsdp", tiled=True)
-  wkv_a = jax.lax.all_gather(wkv_a, axis_name="expert", tiled=True, axis=1)
   wkv_a = jax.lax.all_gather(wkv_a, axis_name="fsdp", tiled=True)
-  wkv_b = jax.lax.all_gather(wkv_b, axis_name="expert", tiled=True, axis=1)
   wkv_b = jax.lax.all_gather(wkv_b, axis_name="fsdp", tiled=True)
-  out = jax.lax.all_gather(out, axis_name="expert", tiled=True)
   out = jax.lax.all_gather(out, axis_name="fsdp", tiled=True, axis=2)
   gate = jax.lax.all_gather(gate, axis_name="fsdp", tiled=True)
   routed_wi_0 = jax.lax.all_gather(routed_wi_0, axis_name="fsdp", tiled=True)
   routed_wi_1 = jax.lax.all_gather(routed_wi_1, axis_name="fsdp", tiled=True)
   routed_wo = jax.lax.all_gather(routed_wo, axis_name="fsdp", tiled=True)
-  shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="expert", tiled=True, axis=1)
   shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="fsdp", tiled=True)
-  shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="expert", tiled=True, axis=1)
   shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="fsdp", tiled=True)
-  shared_wo = jax.lax.all_gather(shared_wo, axis_name="expert", tiled=True)
   shared_wo = jax.lax.all_gather(shared_wo, axis_name="fsdp", tiled=True, axis=1)
   return (
     (
@@ -224,13 +216,13 @@ def fn(weights):
       jax.sharding.PartitionSpec(None),
     ),
     (
-      jax.sharding.PartitionSpec("fsdp", "expert"),
-      jax.sharding.PartitionSpec("fsdp", "expert", None),
+      jax.sharding.PartitionSpec("fsdp", None),
+      jax.sharding.PartitionSpec("fsdp", None, None),
       jax.sharding.PartitionSpec(None),
-      jax.sharding.PartitionSpec("fsdp", "expert"),
-      jax.sharding.PartitionSpec("fsdp", "expert", None),
+      jax.sharding.PartitionSpec("fsdp", None),
+      jax.sharding.PartitionSpec("fsdp", None, None),
       jax.sharding.PartitionSpec(None),
-      jax.sharding.PartitionSpec("expert", None, "fsdp"),
+      jax.sharding.PartitionSpec(None, None, "fsdp"),
     ),
   ),
   (
@@ -244,9 +236,9 @@ def fn(weights):
       jax.sharding.PartitionSpec("fsdp", "expert", None),
     ),
     (
-      jax.sharding.PartitionSpec("fsdp", "expert"),
-      jax.sharding.PartitionSpec("fsdp", "expert"),
-      jax.sharding.PartitionSpec("expert", "fsdp"),
+      jax.sharding.PartitionSpec("fsdp", None),
+      jax.sharding.PartitionSpec("fsdp", None),
+      jax.sharding.PartitionSpec(None, "fsdp"),
     ),
   ),
 ),

0 commit comments

Comments
 (0)