Commit 80e45b4

Introduce a batch-split config for DSv3
PiperOrigin-RevId: 883159766
1 parent 00ef5de commit 80e45b4

3 files changed: 196 additions & 0 deletions

New model config file (DeepSeek V3 671B, batch-split)

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# model config for DeepSeek V3 - 671B that uses fsdp on two logical axes
+
+# For DeepSeek default device-limited routing,
+# please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.
+
+base_emb_dim: 7168
+base_num_query_heads: 128
+base_num_kv_heads: 128
+base_mlp_dim: 18432
+base_moe_mlp_dim: 2048
+base_num_decoder_layers: 61
+first_num_dense_layers: 3
+mlp_activations: ["silu","linear"]
+vocab_size: 129280
+enable_dropout: False
+logits_via_embedding: False
+normalization_layer_epsilon: 1.0e-6
+num_experts: 256
+num_experts_per_tok: 8
+shared_experts: 1
+routed_scaling_factor: 2.5
+routed_score_func: "sigmoid"
+routed_bias: True
+decoder_block: "deepseek"
+# MLA
+attention_type: "mla"
+q_lora_rank: 1536
+kv_lora_rank: 512
+qk_nope_head_dim: 128
+qk_rope_head_dim: 64
+v_head_dim: 128
+mscale: 1.0
+# RoPE
+rope_type: "yarn"
+rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
+max_position_embeddings: 163840
+original_max_position_embeddings: 4096
+rope_factor: 40
+beta_fast: 32
+rope_interleave: True
+rope_truncate: True
+rope_attention_scaling: False
+
+override_logical_axis_rules: True
+mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
+logical_axis_rules: [
+  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_norm_length', ['context']],
+  ['activation_heads', []],
+  ['activation_stage', 'stage'],
+  ['embed', ['fsdp']],
+  ['embed_no_exp', ['fsdp']],
+  ['q_lora', ['fsdp']],
+  ['kv_lora', ['fsdp']],
+  ['layers', 'stage'],
+  ['q_lora_up_proj', ['fsdp_transpose']],
+  ['kv_lora_up_proj', ['fsdp_transpose']],
+  ['q_heads', ['fsdp_transpose']],
+  ['kv_heads', ['fsdp_transpose']],
+  ['heads', ['fsdp_transpose']],
+  ['mlp', ['fsdp_transpose']],
+  ['fsdp_transpose_and_expert', ['fsdp_transpose', 'expert']],
+  ['fsdp_transpose_only', ['fsdp_transpose']],
+  ['expert_only', ['expert']],
+]
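The logical_axis_rules above shard weights over two FSDP-like mesh axes: embedding-style dimensions map to 'fsdp' while head/MLP dimensions map to 'fsdp_transpose'. The following is a minimal, illustrative JAX sketch (not MaxText code) of what one such rule amounts to; the 8-device assumption, the 4x2 mesh shape, and the example tensor are mine, chosen only to make the mapping concrete.

# Hypothetical 8-device mesh: 4-way "fsdp" x 2-way "fsdp_transpose".
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("fsdp", "fsdp_transpose"))

# An MLP weight has logical axes ('embed', 'mlp'); the rules
# ['embed', ['fsdp']] and ['mlp', ['fsdp_transpose']] turn that into this spec.
spec = P("fsdp", "fsdp_transpose")

w = jnp.zeros((7168, 18432))  # base_emb_dim x base_mlp_dim
w = jax.device_put(w, NamedSharding(mesh, spec))
print(w.sharding)  # each device holds a (1792, 9216) shard

Giving FSDP two named axes in this way is what lets a single weight be sharded along two different dimensions at once, which is what the config's comment ("fsdp on two logical axes") refers to.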

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ class ProfilerType(str, Enum):
     "deepseek2-236b",
     "deepseek3-671b",
     "deepseek3-671b-2dfsdp",
+    "deepseek3-671b-batchsplit",
     "deepseek3-test",
     "deepseek3-tiny",
     "deepseek3.2-671b",

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 112 additions & 0 deletions
@@ -168,6 +168,117 @@ def merge(x, split_factor=2):
   return jnp.reshape(x, (-1,) + x.shape[2:])


+def gather_weights(weights, mesh):
+  """all-gathers FSDP sharded weights."""
+  def fn(weights):
+    (
+        (pre_attn_norm, post_attn_norm),
+        (wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
+    ), (
+        (gate, bias),
+        (routed_wi_0, routed_wi_1, routed_wo),
+        (shared_wi_0, shared_wi_1, shared_wo),
+    ) = weights
+    # All-gather across FSDP axis.
+    wq_a = jax.lax.all_gather(wq_a, axis_name="fsdp", tiled=True)
+    wq_b = jax.lax.all_gather(wq_b, axis_name="fsdp", tiled=True)
+    wkv_a = jax.lax.all_gather(wkv_a, axis_name="fsdp", tiled=True)
+    wkv_b = jax.lax.all_gather(wkv_b, axis_name="fsdp", tiled=True)
+    out = jax.lax.all_gather(out, axis_name="fsdp", tiled=True, axis=2)
+    gate = jax.lax.all_gather(gate, axis_name="fsdp", tiled=True)
+    routed_wi_0 = jax.lax.all_gather(routed_wi_0, axis_name="fsdp", tiled=True)
+    routed_wi_1 = jax.lax.all_gather(routed_wi_1, axis_name="fsdp", tiled=True)
+    routed_wo = jax.lax.all_gather(routed_wo, axis_name="fsdp", tiled=True)
+    shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="fsdp", tiled=True)
+    shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="fsdp", tiled=True)
+    shared_wo = jax.lax.all_gather(shared_wo, axis_name="fsdp", tiled=True, axis=1)
+    return (
+        (
+            (pre_attn_norm, post_attn_norm),
+            (wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
+        ),
+        (
+            (gate, bias),
+            (routed_wi_0, routed_wi_1, routed_wo),
+            (shared_wi_0, shared_wi_1, shared_wo),
+        ),
+    )
+
+  return jax.shard_map(
+      fn,
+      mesh=mesh,
+      in_specs=(
+          (
+              (
+                  (
+                      jax.sharding.PartitionSpec(None),
+                      jax.sharding.PartitionSpec(None),
+                  ),
+                  (
+                      jax.sharding.PartitionSpec("fsdp", None),
+                      jax.sharding.PartitionSpec("fsdp", None, None),
+                      jax.sharding.PartitionSpec(None),
+                      jax.sharding.PartitionSpec("fsdp", None),
+                      jax.sharding.PartitionSpec("fsdp", None, None),
+                      jax.sharding.PartitionSpec(None),
+                      jax.sharding.PartitionSpec(None, None, "fsdp"),
+                  ),
+              ),
+              (
+                  (
+                      jax.sharding.PartitionSpec("fsdp", None),
+                      jax.sharding.PartitionSpec(None),
+                  ),
+                  (
+                      jax.sharding.PartitionSpec("fsdp", None, "expert"),
+                      jax.sharding.PartitionSpec("fsdp", None, "expert"),
+                      jax.sharding.PartitionSpec("fsdp", "expert", None),
+                  ),
+                  (
+                      jax.sharding.PartitionSpec("fsdp", None),
+                      jax.sharding.PartitionSpec("fsdp", None),
+                      jax.sharding.PartitionSpec(None, "fsdp"),
+                  ),
+              ),
+          ),
+      ),
+      out_specs=(
+          (
+              (
+                  jax.sharding.PartitionSpec(None),
+                  jax.sharding.PartitionSpec(None),
+              ),
+              (
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None, None),
+                  jax.sharding.PartitionSpec(None),
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None, None),
+                  jax.sharding.PartitionSpec(None),
+                  jax.sharding.PartitionSpec(None, None, None),
+              ),
+          ),
+          (
+              (
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None),
+              ),
+              (
+                  jax.sharding.PartitionSpec(None, None, "expert"),
+                  jax.sharding.PartitionSpec(None, None, "expert"),
+                  jax.sharding.PartitionSpec(None, "expert", None),
+              ),
+              (
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None),
+                  jax.sharding.PartitionSpec(None, None),
+              ),
+          ),
+      ),
+      check_vma=False,
+  )(weights)
+
+
 def scan_batch_split_layers(
     inputs,
     params,

@@ -183,6 +294,7 @@ def scan_batch_split_layers(
   """Scans the layers with batch-split schedule."""

   def batch_split_scan_fn(inputs, weights, dpos, dseg):
+    weights = gather_weights(weights, mesh)
     xs = batch_split_schedule(
         inputs,
         weights,
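gather_weights wraps the per-layer all-gathers in jax.shard_map: each weight enters sharded along the "fsdp" mesh axis (and, for MoE weights, "expert") as declared by in_specs, is reassembled along the fsdp axis with tiled all-gathers, and leaves with only the expert sharding remaining, as declared by out_specs, so the batch-split compute that follows sees full (fsdp-unsharded) weights. Below is a minimal, self-contained sketch of that pattern on a single hypothetical "fsdp" axis; the mesh and shapes are assumptions for illustration, not MaxText code.

# Hypothetical single-axis "fsdp" mesh over all local devices.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

def gather(w):
  # Each device sees its (rows // num_devices, cols) shard; the tiled
  # all-gather concatenates the shards along axis 0 into the full array.
  return jax.lax.all_gather(w, axis_name="fsdp", tiled=True)

rows = 8 * jax.device_count()
full = jax.shard_map(
    gather,
    mesh=mesh,
    in_specs=P("fsdp", None),  # input sharded over "fsdp" on dim 0
    out_specs=P(None, None),   # output replicated on every device
)(jnp.ones((rows, 4)))
assert full.shape == (rows, 4)

This mirrors the usual FSDP pattern: parameters stay sharded at rest and are gathered just before use inside the scanned layer body.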