@@ -2515,6 +2515,26 @@ def _dispatch_and_combine(self, x: Tensor) -> Tensor:
25152515 # in the implementation, which is inefficient.
25162516 assert cfg .input_dim_to_partition_spec ["bsm" ][- 1 ] is None
25172517 assert cfg .output_dim_to_partition_spec ["emh" ][- 2 ] is None
2518+ assert cfg .output_dim_to_partition_spec ["ehm" ][- 1 ] is None
2519+
2520+ # Pass expert weights into shard_map with their fsdp sharding preserved,
2521+ # then explicitly all-gather along fsdp inside a jax.checkpoint wrapper.
2522+ # This prevents SPMD from hoisting one giant all-gather over the stacked
2523+ # [num_layers, E, M/fsdp, H] parameter tensor outside the scan loop.
2524+ # The JAX-level all_gather inside checkpoint is recomputed per-iteration in
2525+ # backward instead of being saved as a stacked residual.
2526+
2527+ # The shard_map() implementation below assumes dim H can only be
2528+ # sharded along the model axis.
2529+ assert cfg .output_dim_to_partition_spec ["emh" ][- 1 ] == "model"
2530+ assert cfg .output_dim_to_partition_spec ["ehm" ][- 2 ] == "model"
2531+ assert cfg .dim_to_mesh_axis_map ["emh" ][- 1 ] == "model"
2532+ assert cfg .dim_to_mesh_axis_map ["ehm" ][- 2 ] == "model"
2533+
2534+ emh_gather_axes = cfg .dim_to_mesh_axis_map ["emh" ][1 ]
2535+ ehm_gather_axes = cfg .dim_to_mesh_axis_map ["ehm" ][2 ]
2536+ bsm_out_spec = cfg .output_dim_to_partition_spec ["bsm" ]
2537+ bskm_out_spec = PartitionSpec (* bsm_out_spec [:2 ], None , bsm_out_spec [2 ])
25182538
25192539 mesh = thread_resources .env .physical_mesh
25202540
@@ -2524,13 +2544,12 @@ def _dispatch_and_combine(self, x: Tensor) -> Tensor:
25242544 in_specs = (
25252545 cfg .input_dim_to_partition_spec ["bsm" ],
25262546 cfg .input_dim_to_partition_spec ["bsm" ],
2527- cfg .input_dim_to_partition_spec ["bsm" ],
2528- cfg .output_dim_to_partition_spec ["emh" ],
2529- cfg .output_dim_to_partition_spec ["emh" ],
2530- cfg .output_dim_to_partition_spec ["ehm" ],
2547+ cfg .dim_to_mesh_axis_map ["emh" ],
2548+ cfg .dim_to_mesh_axis_map ["emh" ],
2549+ cfg .dim_to_mesh_axis_map ["ehm" ],
25312550 ),
25322551 out_specs = (
2533- cfg . output_dim_to_partition_spec [ "bsm" ] ,
2552+ bskm_out_spec ,
25342553 * self ._additional_shmap_output_sharding (mesh ),
25352554 ),
25362555 # Disables a checking pass which jax can't apply when there's a triton | pallas
@@ -2540,10 +2559,9 @@ def _dispatch_and_combine(self, x: Tensor) -> Tensor:
25402559 def wrapper (
25412560 x : Tensor ,
25422561 gate_assignment : Tensor ,
2543- expert_weights : Tensor ,
2544- wi_0 : Tensor ,
2545- wi_1 : Tensor ,
2546- wo : Tensor ,
2562+ wi_0_sharded : Tensor ,
2563+ wi_1_sharded : Tensor ,
2564+ wo_sharded : Tensor ,
25472565 ) -> tuple [Tensor , ...]:
25482566 """Computes unsorted outputs for one sharded block.
25492567
@@ -2554,84 +2572,122 @@ def wrapper(
25542572 Args:
25552573 x: the sharded input batch of the shape [G=B', S', M].
25562574 gate_assignment: [G=B', S', K].
2557- expert_weights: [G=B', S', K].
2558- wi_0: the input projection of [E, M, H'].
2559- wi_1: the input projection of [E, M, H'].
2560- wo: the output projection of [E, H', M'].
2575+ wi_0_sharded: the input projection of [E, M/fsdp, H/model] (local).
2576+ wi_1_sharded: the input projection of [E, M/fsdp, H/model] (local).
2577+ wo_sharded: the output projection of [E, H/model, M/fsdp] (local).
25612578
25622579 Returns:
25632580 A tuple of
2564- - A tensor of shape [G=B', S', M'].
2581+ - A tensor of shape [G=B', S', K, M'].
25652582 - ... optional series of tensors from `_dispatch_hook()[2]`.
25662583 """
25672584 logging .info ("Setting the effective group_size=%r" , x .shape [0 ])
25682585 B , S , M = x .shape # pylint: disable=invalid-name
2569- # [B' x S' x K]
2570- gate_assignment = gate_assignment .reshape ((- 1 ))
2571- # x[sorted_indices[:, i]] for i in range(S * K) represents tokens sorted
2572- # by which experts they are assigned to.
2573- # [B' x S' x K]
2574- sorted_indices = jnp .argsort (gate_assignment )
2575- token_indices = sorted_indices // num_experts_per_token
2576- # Dispatch the tokens.
2577- combine_indices = jnp .argsort (sorted_indices )
2578- # [B' x S' x K, M]
2579- sorted_inputs = _custom_gather (
2580- x .reshape (- 1 , M ), token_indices , combine_indices , unique_indices = False
2581- )
2582- tokens_per_expert = jnp .bincount (gate_assignment , length = cfg .num_experts )
2583-
2584- sorted_inputs , tokens_per_expert , additional_outputs , residuals = self ._dispatch_hook (
2585- sorted_inputs = sorted_inputs ,
2586- tokens_per_expert = tokens_per_expert ,
2587- )
25882586
2589- # [B' x S' x K, H']
2590- activation_0 = self ._padded_gmm (sorted_inputs , wi_0 , tokens_per_expert )
2591- activation_0 = get_activation_fn (cfg .activation [0 ])(activation_0 )
2592-
2593- activation_1 = self ._padded_gmm (sorted_inputs , wi_1 , tokens_per_expert )
2594- activation_1 = get_activation_fn (cfg .activation [1 ])(activation_1 )
2595-
2596- intermediate = activation_0 * activation_1
2597-
2598- if cfg .structure in ["prenorm" , "hybridnorm" , "nonorm" , "v2" ]:
2599- intermediate = self .dropout1 (intermediate )
2587+ @jax .jit
2588+ @partial (jax .checkpoint , prevent_cse = False )
2589+ def _gather_and_compute (
2590+ x : Tensor ,
2591+ gate_assignment : Tensor ,
2592+ wi_0_sharded : Tensor ,
2593+ wi_1_sharded : Tensor ,
2594+ wo_sharded : Tensor ,
2595+ ) -> tuple [Tensor , ...]:
2596+ # Explicitly all-gather expert weights along fsdp inside the
2597+ # checkpoint scope. Backward will recompute this gather rather
2598+ # than saving the gathered tensor.
2599+ if emh_gather_axes is not None :
2600+ wi_0 = jax .lax .all_gather (wi_0_sharded , emh_gather_axes , axis = 1 , tiled = True )
2601+ wi_1 = jax .lax .all_gather (wi_1_sharded , emh_gather_axes , axis = 1 , tiled = True )
2602+ else :
2603+ wi_0 = wi_0_sharded
2604+ wi_1 = wi_1_sharded
2605+ if ehm_gather_axes is not None :
2606+ wo = jax .lax .all_gather (wo_sharded , ehm_gather_axes , axis = 2 , tiled = True )
2607+ else :
2608+ wo = wo_sharded
2609+
2610+ # [B' x S' x K]
2611+ gate_assignment = gate_assignment .reshape ((- 1 ))
2612+ # x[sorted_indices[:, i]] for i in range(S * K) represents tokens sorted
2613+ # by which experts they are assigned to.
2614+ # [B' x S' x K]
2615+ sorted_indices = jnp .argsort (gate_assignment )
2616+ token_indices = sorted_indices // num_experts_per_token
2617+ # Dispatch the tokens.
2618+ combine_indices = jnp .argsort (sorted_indices )
2619+ # [B' x S' x K, M]
2620+ sorted_inputs = _custom_gather (
2621+ x .reshape (- 1 , M ), token_indices , combine_indices , unique_indices = False
2622+ )
2623+ tokens_per_expert = jnp .bincount (gate_assignment , length = cfg .num_experts )
26002624
2601- # [B' x S x K, M]
2602- sorted_output = self ._padded_gmm (intermediate , wo , tokens_per_expert )
2603- if thread_resources .env .physical_mesh .shape ["model" ] > 1 :
2604- # If output is partitioned across "model", we need to reduce-scatter. Otherwise,
2605- # we do an allreduce.
2606- spec = cfg .output_dim_to_partition_spec ["bsm" ][2 ]
2607- if spec and "model" in spec :
2608- sorted_output = jax .lax .psum_scatter (
2609- sorted_output , "model" , scatter_dimension = 1 , tiled = True
2625+ sorted_inputs , tokens_per_expert , additional_outputs , residuals = (
2626+ self ._dispatch_hook (
2627+ sorted_inputs = sorted_inputs ,
2628+ tokens_per_expert = tokens_per_expert ,
26102629 )
2611- else :
2612- sorted_output = jax .lax .psum (sorted_output , "model" )
2613- # [B' x S' x K, M']
2614- sorted_output = self ._combine_hook (sorted_output = sorted_output , residuals = residuals )
2615- # Gather the tokens to their original positions.
2616- unsorted_output = _custom_gather (sorted_output , combine_indices , sorted_indices )
2617- output = unsorted_output .reshape (B , S , num_experts_per_token , unsorted_output .shape [- 1 ])
2618- # Apply the expert weights.
2619- output *= expert_weights [..., None ]
2620- assert expert_weights .dtype == jnp .float32
2621- assert output .dtype == jnp .float32
2622- # [B', S', M']
2623- output = jnp .sum (output , axis = - 2 ).astype (x .dtype )
2624- return output , * additional_outputs
2630+ )
2631+
2632+ # [B' x S' x K, H']
2633+ activation_0 = self ._padded_gmm (sorted_inputs , wi_0 , tokens_per_expert )
2634+ activation_0 = get_activation_fn (cfg .activation [0 ])(activation_0 )
2635+
2636+ activation_1 = self ._padded_gmm (sorted_inputs , wi_1 , tokens_per_expert )
2637+ activation_1 = get_activation_fn (cfg .activation [1 ])(activation_1 )
2638+
2639+ intermediate = activation_0 * activation_1
2640+
2641+ if cfg .structure in ["prenorm" , "hybridnorm" , "nonorm" , "v2" ]:
2642+ intermediate = self .dropout1 (intermediate )
2643+
2644+ # [B' x S' x K, M]
2645+ sorted_output = self ._padded_gmm (intermediate , wo , tokens_per_expert )
2646+ if thread_resources .env .physical_mesh .shape ["model" ] > 1 :
2647+ # If output is partitioned across "model", we need to reduce-scatter. Otherwise,
2648+ # we do an allreduce.
2649+ spec = cfg .output_dim_to_partition_spec ["bsm" ][2 ]
2650+ if spec and "model" in spec :
2651+ sorted_output = jax .lax .psum_scatter (
2652+ sorted_output , "model" , scatter_dimension = 1 , tiled = True
2653+ )
2654+ else :
2655+ sorted_output = jax .lax .psum (sorted_output , "model" )
2656+ # [B' x S' x K, M']
2657+ sorted_output = self ._combine_hook (sorted_output = sorted_output , residuals = residuals )
2658+ # Gather the tokens to their original positions.
2659+ unsorted_output = _custom_gather (sorted_output , combine_indices , sorted_indices )
2660+ # [B', S', K, M']
2661+ output = unsorted_output .reshape (
2662+ B , S , num_experts_per_token , unsorted_output .shape [- 1 ]
2663+ )
2664+ return output , * additional_outputs
2665+
2666+ return _gather_and_compute (
2667+ x ,
2668+ gate_assignment ,
2669+ wi_0_sharded ,
2670+ wi_1_sharded ,
2671+ wo_sharded ,
2672+ )
26252673
26262674 out , * additional_outputs = wrapper (
26272675 x ,
26282676 gate_assignment ,
2629- expert_weights ,
26302677 self .parameters ["wi_0_weight" ],
26312678 self .parameters ["wi_1_weight" ],
26322679 self .parameters ["wo_weight" ],
26332680 )
26342681 self ._additional_shmap_output_hook (additional_outputs )
2682+
2683+ # Apply expert weights outside shard_map to avoid passing an extra tensor
2684+ # through the shard_map boundary and recomputing it in the backward pass.
2685+ # out: [B, S, K, M'], expert_weights: [B, S, K]
2686+ out *= expert_weights [..., None ]
2687+ assert expert_weights .dtype == jnp .float32
2688+ assert out .dtype == jnp .float32
2689+ # [B, S, M']
2690+ out = jnp .sum (out , axis = - 2 ).astype (x .dtype )
26352691 return out
26362692
26372693
0 commit comments