Skip to content

Commit 8022d86

Browse files
Make weight all-gathers explicit for DSv3 batch-split
PiperOrigin-RevId: 883159768
1 parent 8d4f13a commit 8022d86

1 file changed

Lines changed: 120 additions & 0 deletions

File tree

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,125 @@ def merge(x, split_factor=2):
168168
return jnp.reshape(x, (-1,) + x.shape[2:])
169169

170170

171+
def gather_weights(weights, mesh):
172+
"""all-gathers FSDP sharded weights."""
173+
def fn(weights):
174+
(
175+
(pre_attn_norm, post_attn_norm),
176+
(wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
177+
), (
178+
(gate, bias),
179+
(routed_wi_0, routed_wi_1, routed_wo),
180+
(shared_wi_0, shared_wi_1, shared_wo),
181+
) = weights
182+
# All-gather across FSDP axis. Expert axis is used for FSDP in attention.
183+
wq_a = jax.lax.all_gather(wq_a, axis_name="expert", tiled=True, axis=1)
184+
wq_a = jax.lax.all_gather(wq_a, axis_name="fsdp", tiled=True)
185+
wq_b = jax.lax.all_gather(wq_b, axis_name="expert", tiled=True, axis=1)
186+
wq_b = jax.lax.all_gather(wq_b, axis_name="fsdp", tiled=True)
187+
wkv_a = jax.lax.all_gather(wkv_a, axis_name="expert", tiled=True, axis=1)
188+
wkv_a = jax.lax.all_gather(wkv_a, axis_name="fsdp", tiled=True)
189+
wkv_b = jax.lax.all_gather(wkv_b, axis_name="expert", tiled=True, axis=1)
190+
wkv_b = jax.lax.all_gather(wkv_b, axis_name="fsdp", tiled=True)
191+
out = jax.lax.all_gather(out, axis_name="expert", tiled=True)
192+
out = jax.lax.all_gather(out, axis_name="fsdp", tiled=True, axis=2)
193+
gate = jax.lax.all_gather(gate, axis_name="fsdp", tiled=True)
194+
routed_wi_0 = jax.lax.all_gather(routed_wi_0, axis_name="fsdp", tiled=True)
195+
routed_wi_1 = jax.lax.all_gather(routed_wi_1, axis_name="fsdp", tiled=True)
196+
routed_wo = jax.lax.all_gather(routed_wo, axis_name="fsdp", tiled=True)
197+
shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="expert", tiled=True, axis=1)
198+
shared_wi_0 = jax.lax.all_gather(shared_wi_0, axis_name="fsdp", tiled=True)
199+
shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="expert", tiled=True, axis=1)
200+
shared_wi_1 = jax.lax.all_gather(shared_wi_1, axis_name="fsdp", tiled=True)
201+
shared_wo = jax.lax.all_gather(shared_wo, axis_name="expert", tiled=True)
202+
shared_wo = jax.lax.all_gather(shared_wo, axis_name="fsdp", tiled=True, axis=1)
203+
return (
204+
(
205+
(pre_attn_norm, post_attn_norm),
206+
(wq_a, wq_b, q_norm, wkv_a, wkv_b, kv_norm, out),
207+
),
208+
(
209+
(gate, bias),
210+
(routed_wi_0, routed_wi_1, routed_wo),
211+
(shared_wi_0, shared_wi_1, shared_wo),
212+
),
213+
)
214+
215+
return jax.shard_map(
216+
fn,
217+
mesh=mesh,
218+
in_specs=(
219+
(
220+
(
221+
(
222+
jax.sharding.PartitionSpec(None),
223+
jax.sharding.PartitionSpec(None),
224+
),
225+
(
226+
jax.sharding.PartitionSpec("fsdp", "expert"),
227+
jax.sharding.PartitionSpec("fsdp", "expert", None),
228+
jax.sharding.PartitionSpec(None),
229+
jax.sharding.PartitionSpec("fsdp", "expert"),
230+
jax.sharding.PartitionSpec("fsdp", "expert", None),
231+
jax.sharding.PartitionSpec(None),
232+
jax.sharding.PartitionSpec("expert", None, "fsdp"),
233+
),
234+
),
235+
(
236+
(
237+
jax.sharding.PartitionSpec("fsdp", None),
238+
jax.sharding.PartitionSpec(None),
239+
),
240+
(
241+
jax.sharding.PartitionSpec("fsdp", None, "expert"),
242+
jax.sharding.PartitionSpec("fsdp", None, "expert"),
243+
jax.sharding.PartitionSpec("fsdp", "expert", None),
244+
),
245+
(
246+
jax.sharding.PartitionSpec("fsdp", "expert"),
247+
jax.sharding.PartitionSpec("fsdp", "expert"),
248+
jax.sharding.PartitionSpec("expert", "fsdp"),
249+
),
250+
),
251+
),
252+
),
253+
out_specs=(
254+
(
255+
(
256+
jax.sharding.PartitionSpec(None),
257+
jax.sharding.PartitionSpec(None),
258+
),
259+
(
260+
jax.sharding.PartitionSpec(None, None),
261+
jax.sharding.PartitionSpec(None, None, None),
262+
jax.sharding.PartitionSpec(None),
263+
jax.sharding.PartitionSpec(None, None),
264+
jax.sharding.PartitionSpec(None, None, None),
265+
jax.sharding.PartitionSpec(None),
266+
jax.sharding.PartitionSpec(None, None, None),
267+
),
268+
),
269+
(
270+
(
271+
jax.sharding.PartitionSpec(None, None),
272+
jax.sharding.PartitionSpec(None),
273+
),
274+
(
275+
jax.sharding.PartitionSpec(None, None, "expert"),
276+
jax.sharding.PartitionSpec(None, None, "expert"),
277+
jax.sharding.PartitionSpec(None, "expert", None),
278+
),
279+
(
280+
jax.sharding.PartitionSpec(None, None),
281+
jax.sharding.PartitionSpec(None, None),
282+
jax.sharding.PartitionSpec(None, None),
283+
),
284+
),
285+
),
286+
check_vma=False,
287+
)(weights)
288+
289+
171290
def scan_batch_split_layers(
172291
inputs,
173292
params,
@@ -183,6 +302,7 @@ def scan_batch_split_layers(
183302
"""Scans the layers with batch-split schedule."""
184303

185304
def batch_split_scan_fn(inputs, weights, dpos, dseg):
305+
weights = gather_weights(weights, mesh)
186306
xs = batch_split_schedule(
187307
inputs,
188308
weights,

0 commit comments

Comments
 (0)