@@ -180,26 +180,18 @@ def fn(weights):
180180 (routed_wi_0 , routed_wi_1 , routed_wo ),
181181 (shared_wi_0 , shared_wi_1 , shared_wo ),
182182 ) = weights
183- # All-gather across FSDP axis. Expert axis is used for FSDP in attention.
184- wq_a = jax .lax .all_gather (wq_a , axis_name = "expert" , tiled = True , axis = 1 )
183+ # All-gather across FSDP axis.
185184 wq_a = jax .lax .all_gather (wq_a , axis_name = "fsdp" , tiled = True )
186- wq_b = jax .lax .all_gather (wq_b , axis_name = "expert" , tiled = True , axis = 1 )
187185 wq_b = jax .lax .all_gather (wq_b , axis_name = "fsdp" , tiled = True )
188- wkv_a = jax .lax .all_gather (wkv_a , axis_name = "expert" , tiled = True , axis = 1 )
189186 wkv_a = jax .lax .all_gather (wkv_a , axis_name = "fsdp" , tiled = True )
190- wkv_b = jax .lax .all_gather (wkv_b , axis_name = "expert" , tiled = True , axis = 1 )
191187 wkv_b = jax .lax .all_gather (wkv_b , axis_name = "fsdp" , tiled = True )
192- out = jax .lax .all_gather (out , axis_name = "expert" , tiled = True )
193188 out = jax .lax .all_gather (out , axis_name = "fsdp" , tiled = True , axis = 2 )
194189 gate = jax .lax .all_gather (gate , axis_name = "fsdp" , tiled = True )
195190 routed_wi_0 = jax .lax .all_gather (routed_wi_0 , axis_name = "fsdp" , tiled = True )
196191 routed_wi_1 = jax .lax .all_gather (routed_wi_1 , axis_name = "fsdp" , tiled = True )
197192 routed_wo = jax .lax .all_gather (routed_wo , axis_name = "fsdp" , tiled = True )
198- shared_wi_0 = jax .lax .all_gather (shared_wi_0 , axis_name = "expert" , tiled = True , axis = 1 )
199193 shared_wi_0 = jax .lax .all_gather (shared_wi_0 , axis_name = "fsdp" , tiled = True )
200- shared_wi_1 = jax .lax .all_gather (shared_wi_1 , axis_name = "expert" , tiled = True , axis = 1 )
201194 shared_wi_1 = jax .lax .all_gather (shared_wi_1 , axis_name = "fsdp" , tiled = True )
202- shared_wo = jax .lax .all_gather (shared_wo , axis_name = "expert" , tiled = True )
203195 shared_wo = jax .lax .all_gather (shared_wo , axis_name = "fsdp" , tiled = True , axis = 1 )
204196 return (
205197 (
@@ -224,13 +216,13 @@ def fn(weights):
224216 jax .sharding .PartitionSpec (None ),
225217 ),
226218 (
227- jax .sharding .PartitionSpec ("fsdp" , "expert" ),
228- jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
219+ jax .sharding .PartitionSpec ("fsdp" , None ),
220+ jax .sharding .PartitionSpec ("fsdp" , None , None ),
229221 jax .sharding .PartitionSpec (None ),
230- jax .sharding .PartitionSpec ("fsdp" , "expert" ),
231- jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
222+ jax .sharding .PartitionSpec ("fsdp" , None ),
223+ jax .sharding .PartitionSpec ("fsdp" , None , None ),
232224 jax .sharding .PartitionSpec (None ),
233- jax .sharding .PartitionSpec ("expert" , None , "fsdp" ),
225+ jax .sharding .PartitionSpec (None , None , "fsdp" ),
234226 ),
235227 ),
236228 (
@@ -244,9 +236,9 @@ def fn(weights):
244236 jax .sharding .PartitionSpec ("fsdp" , "expert" , None ),
245237 ),
246238 (
247- jax .sharding .PartitionSpec ("fsdp" , "expert" ),
248- jax .sharding .PartitionSpec ("fsdp" , "expert" ),
249- jax .sharding .PartitionSpec ("expert" , "fsdp" ),
239+ jax .sharding .PartitionSpec ("fsdp" , None ),
240+ jax .sharding .PartitionSpec ("fsdp" , None ),
241+ jax .sharding .PartitionSpec (None , "fsdp" ),
250242 ),
251243 ),
252244 ),
0 commit comments