@@ -168,6 +168,125 @@ def merge(x, split_factor=2):
168168 return jnp .reshape (x , (- 1 ,) + x .shape [2 :])
169169
170170
171+ def gather_weights (weights , mesh ):
172+ """all-gathers FSDP sharded weights."""
173+ def fn (weights ):
174+ (
175+ (pre_attn_norm , post_attn_norm ),
176+ (wq_a , wq_b , q_norm , wkv_a , wkv_b , kv_norm , out ),
177+ ), (
178+ (gate , bias ),
179+ (routed_wi_0 , routed_wi_1 , routed_wo ),
180+ (shared_wi_0 , shared_wi_1 , shared_wo ),
181+ ) = weights
182+ # All-gather across FSDP axis. Expert axis is used for FSDP in attention.
183+ wq_a = jax .lax .all_gather (wq_a , axis_name = "expert" , tiled = True , axis = 1 )
184+ wq_a = jax .lax .all_gather (wq_a , axis_name = "fsdp" , tiled = True )
185+ wq_b = jax .lax .all_gather (wq_b , axis_name = "expert" , tiled = True , axis = 1 )
186+ wq_b = jax .lax .all_gather (wq_b , axis_name = "fsdp" , tiled = True )
187+ wkv_a = jax .lax .all_gather (wkv_a , axis_name = "expert" , tiled = True , axis = 1 )
188+ wkv_a = jax .lax .all_gather (wkv_a , axis_name = "fsdp" , tiled = True )
189+ wkv_b = jax .lax .all_gather (wkv_b , axis_name = "expert" , tiled = True , axis = 1 )
190+ wkv_b = jax .lax .all_gather (wkv_b , axis_name = "fsdp" , tiled = True )
191+ out = jax .lax .all_gather (out , axis_name = "expert" , tiled = True )
192+ out = jax .lax .all_gather (out , axis_name = "fsdp" , tiled = True , axis = 2 )
193+ gate = jax .lax .all_gather (gate , axis_name = "fsdp" , tiled = True )
194+ routed_wi_0 = jax .lax .all_gather (routed_wi_0 , axis_name = "fsdp" , tiled = True )
195+ routed_wi_1 = jax .lax .all_gather (routed_wi_1 , axis_name = "fsdp" , tiled = True )
196+ routed_wo = jax .lax .all_gather (routed_wo , axis_name = "fsdp" , tiled = True )
197+ shared_wi_0 = jax .lax .all_gather (shared_wi_0 , axis_name = "expert" , tiled = True , axis = 1 )
198+ shared_wi_0 = jax .lax .all_gather (shared_wi_0 , axis_name = "fsdp" , tiled = True )
199+ shared_wi_1 = jax .lax .all_gather (shared_wi_1 , axis_name = "expert" , tiled = True , axis = 1 )
200+ shared_wi_1 = jax .lax .all_gather (shared_wi_1 , axis_name = "fsdp" , tiled = True )
201+ shared_wo = jax .lax .all_gather (shared_wo , axis_name = "expert" , tiled = True )
202+ shared_wo = jax .lax .all_gather (shared_wo , axis_name = "fsdp" , tiled = True , axis = 1 )
203+ return (
204+ (
205+ (pre_attn_norm , post_attn_norm ),
206+ (wq_a , wq_b , q_norm , wkv_a , wkv_b , kv_norm , out ),
207+ ),
208+ (
209+ (gate , bias ),
210+ (routed_wi_0 , routed_wi_1 , routed_wo ),
211+ (shared_wi_0 , shared_wi_1 , shared_wo ),
212+ ),
213+ )
214+
215+ return jax .shard_map (
216+ fn ,
217+ mesh = mesh ,
218+ in_specs = (
219+ (
220+ (
221+ (
222+ jax .sharding .PartitionSpec (None ),
223+ jax .sharding .PartitionSpec (None ),
224+ ),
225+ (
226+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
227+ jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
228+ jax .sharding .PartitionSpec (None ),
229+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
230+ jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
231+ jax .sharding .PartitionSpec (None ),
232+ jax .sharding .PartitionSpec ("expert" , None , "fsdp" ),
233+ ),
234+ ),
235+ (
236+ (
237+ jax .sharding .PartitionSpec ("fsdp" , None ),
238+ jax .sharding .PartitionSpec (None ),
239+ ),
240+ (
241+ jax .sharding .PartitionSpec ("fsdp" , None , "expert" ),
242+ jax .sharding .PartitionSpec ("fsdp" , None , "expert" ),
243+ jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
244+ ),
245+ (
246+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
247+ jax .sharding .PartitionSpec ("fsdp" , "expert" ),
248+ jax .sharding .PartitionSpec ("expert" , "fsdp" ),
249+ ),
250+ ),
251+ ),
252+ ),
253+ out_specs = (
254+ (
255+ (
256+ jax .sharding .PartitionSpec (None ),
257+ jax .sharding .PartitionSpec (None ),
258+ ),
259+ (
260+ jax .sharding .PartitionSpec (None , None ),
261+ jax .sharding .PartitionSpec (None , None , None ),
262+ jax .sharding .PartitionSpec (None ),
263+ jax .sharding .PartitionSpec (None , None ),
264+ jax .sharding .PartitionSpec (None , None , None ),
265+ jax .sharding .PartitionSpec (None ),
266+ jax .sharding .PartitionSpec (None , None , None ),
267+ ),
268+ ),
269+ (
270+ (
271+ jax .sharding .PartitionSpec (None , None ),
272+ jax .sharding .PartitionSpec (None ),
273+ ),
274+ (
275+ jax .sharding .PartitionSpec (None , None , "expert" ),
276+ jax .sharding .PartitionSpec (None , None , "expert" ),
277+ jax .sharding .PartitionSpec (None , "expert" , None ),
278+ ),
279+ (
280+ jax .sharding .PartitionSpec (None , None ),
281+ jax .sharding .PartitionSpec (None , None ),
282+ jax .sharding .PartitionSpec (None , None ),
283+ ),
284+ ),
285+ ),
286+ check_vma = False ,
287+ )(weights )
288+
289+
171290def scan_batch_split_layers (
172291 inputs ,
173292 params ,
@@ -183,6 +302,7 @@ def scan_batch_split_layers(
183302 """Scans the layers with batch-split schedule."""
184303
185304 def batch_split_scan_fn (inputs , weights , dpos , dseg ):
305+ weights = gather_weights (weights , mesh )
186306 xs = batch_split_schedule (
187307 inputs ,
188308 weights ,
0 commit comments