@@ -168,6 +168,117 @@ def merge(x, split_factor=2):
168168 return jnp .reshape (x , (- 1 ,) + x .shape [2 :])
169169
170170
171+ def gather_weights (weights , mesh ):
172+ """all-gathers FSDP sharded weights."""
173+ def fn (weights ):
174+ (
175+ (pre_attn_norm , post_attn_norm ),
176+ (wq_a , wq_b , q_norm , wkv_a , wkv_b , kv_norm , out ),
177+ ), (
178+ (gate , bias ),
179+ (routed_wi_0 , routed_wi_1 , routed_wo ),
180+ (shared_wi_0 , shared_wi_1 , shared_wo ),
181+ ) = weights
182+ # All-gather across FSDP axis.
183+ wq_a = jax .lax .all_gather (wq_a , axis_name = "fsdp" , tiled = True )
184+ wq_b = jax .lax .all_gather (wq_b , axis_name = "fsdp" , tiled = True )
185+ wkv_a = jax .lax .all_gather (wkv_a , axis_name = "fsdp" , tiled = True )
186+ wkv_b = jax .lax .all_gather (wkv_b , axis_name = "fsdp" , tiled = True )
187+ out = jax .lax .all_gather (out , axis_name = "fsdp" , tiled = True , axis = 2 )
188+ gate = jax .lax .all_gather (gate , axis_name = "fsdp" , tiled = True )
189+ routed_wi_0 = jax .lax .all_gather (routed_wi_0 , axis_name = "fsdp" , tiled = True )
190+ routed_wi_1 = jax .lax .all_gather (routed_wi_1 , axis_name = "fsdp" , tiled = True )
191+ routed_wo = jax .lax .all_gather (routed_wo , axis_name = "fsdp" , tiled = True )
192+ shared_wi_0 = jax .lax .all_gather (shared_wi_0 , axis_name = "fsdp" , tiled = True )
193+ shared_wi_1 = jax .lax .all_gather (shared_wi_1 , axis_name = "fsdp" , tiled = True )
194+ shared_wo = jax .lax .all_gather (shared_wo , axis_name = "fsdp" , tiled = True , axis = 1 )
195+ return (
196+ (
197+ (pre_attn_norm , post_attn_norm ),
198+ (wq_a , wq_b , q_norm , wkv_a , wkv_b , kv_norm , out ),
199+ ),
200+ (
201+ (gate , bias ),
202+ (routed_wi_0 , routed_wi_1 , routed_wo ),
203+ (shared_wi_0 , shared_wi_1 , shared_wo ),
204+ ),
205+ )
206+
207+ return jax .shard_map (
208+ fn ,
209+ mesh = mesh ,
210+ in_specs = (
211+ (
212+ (
213+ (
214+ jax .sharding .PartitionSpec (None ),
215+ jax .sharding .PartitionSpec (None ),
216+ ),
217+ (
218+ jax .sharding .PartitionSpec ("fsdp" , None ),
219+ jax .sharding .PartitionSpec ("fsdp" , None , None ),
220+ jax .sharding .PartitionSpec (None ),
221+ jax .sharding .PartitionSpec ("fsdp" , None ),
222+ jax .sharding .PartitionSpec ("fsdp" , None , None ),
223+ jax .sharding .PartitionSpec (None ),
224+ jax .sharding .PartitionSpec (None , None , "fsdp" ),
225+ ),
226+ ),
227+ (
228+ (
229+ jax .sharding .PartitionSpec ("fsdp" , None ),
230+ jax .sharding .PartitionSpec (None ),
231+ ),
232+ (
233+ jax .sharding .PartitionSpec ("fsdp" , None , "expert" ),
234+ jax .sharding .PartitionSpec ("fsdp" , None , "expert" ),
235+ jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
236+ ),
237+ (
238+ jax .sharding .PartitionSpec ("fsdp" , None ),
239+ jax .sharding .PartitionSpec ("fsdp" , None ),
240+ jax .sharding .PartitionSpec (None , "fsdp" ),
241+ ),
242+ ),
243+ ),
244+ ),
245+ out_specs = (
246+ (
247+ (
248+ jax .sharding .PartitionSpec (None ),
249+ jax .sharding .PartitionSpec (None ),
250+ ),
251+ (
252+ jax .sharding .PartitionSpec (None , None ),
253+ jax .sharding .PartitionSpec (None , None , None ),
254+ jax .sharding .PartitionSpec (None ),
255+ jax .sharding .PartitionSpec (None , None ),
256+ jax .sharding .PartitionSpec (None , None , None ),
257+ jax .sharding .PartitionSpec (None ),
258+ jax .sharding .PartitionSpec (None , None , None ),
259+ ),
260+ ),
261+ (
262+ (
263+ jax .sharding .PartitionSpec (None , None ),
264+ jax .sharding .PartitionSpec (None ),
265+ ),
266+ (
267+ jax .sharding .PartitionSpec (None , None , "expert" ),
268+ jax .sharding .PartitionSpec (None , None , "expert" ),
269+ jax .sharding .PartitionSpec (None , "expert" , None ),
270+ ),
271+ (
272+ jax .sharding .PartitionSpec (None , None ),
273+ jax .sharding .PartitionSpec (None , None ),
274+ jax .sharding .PartitionSpec (None , None ),
275+ ),
276+ ),
277+ ),
278+ check_vma = False ,
279+ )(weights )
280+
281+
171282def scan_batch_split_layers (
172283 inputs ,
173284 params ,
@@ -183,6 +294,7 @@ def scan_batch_split_layers(
183294 """Scans the layers with batch-split schedule."""
184295
185296 def batch_split_scan_fn (inputs , weights , dpos , dseg ):
297+ weights = gather_weights (weights , mesh )
186298 xs = batch_split_schedule (
187299 inputs ,
188300 weights ,
0 commit comments