Skip to content

Commit 98e654d

Browse files
Tao Lei authored and wangkuiyi committed
Fix MoE OOM: move all-gather inside shard_map with jax.checkpoint
GitOrigin-RevId: 002a682fefb98243f7a7a265738f7d373145d006
1 parent e08fe75 commit 98e654d

1 file changed

Lines changed: 124 additions & 68 deletions

File tree

axlearn/common/mixture_of_experts.py

Lines changed: 124 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -2515,6 +2515,26 @@ def _dispatch_and_combine(self, x: Tensor) -> Tensor:
25152515
# in the implementation, which is inefficient.
25162516
assert cfg.input_dim_to_partition_spec["bsm"][-1] is None
25172517
assert cfg.output_dim_to_partition_spec["emh"][-2] is None
2518+
assert cfg.output_dim_to_partition_spec["ehm"][-1] is None
2519+
2520+
# Pass expert weights into shard_map with their fsdp sharding preserved,
2521+
# then explicitly all-gather along fsdp inside a jax.checkpoint wrapper.
2522+
# This prevents SPMD from hoisting one giant all-gather over the stacked
2523+
# [num_layers, E, M/fsdp, H] parameter tensor outside the scan loop.
2524+
# The JAX-level all_gather inside checkpoint is recomputed per-iteration in
2525+
# backward instead of being saved as a stacked residual.
2526+
2527+
# The shard_map() implementation below assumes dim H can only be
2528+
# sharded along the model axis.
2529+
assert cfg.output_dim_to_partition_spec["emh"][-1] == "model"
2530+
assert cfg.output_dim_to_partition_spec["ehm"][-2] == "model"
2531+
assert cfg.dim_to_mesh_axis_map["emh"][-1] == "model"
2532+
assert cfg.dim_to_mesh_axis_map["ehm"][-2] == "model"
2533+
2534+
emh_gather_axes = cfg.dim_to_mesh_axis_map["emh"][1]
2535+
ehm_gather_axes = cfg.dim_to_mesh_axis_map["ehm"][2]
2536+
bsm_out_spec = cfg.output_dim_to_partition_spec["bsm"]
2537+
bskm_out_spec = PartitionSpec(*bsm_out_spec[:2], None, bsm_out_spec[2])
25182538

25192539
mesh = thread_resources.env.physical_mesh
25202540

@@ -2524,13 +2544,12 @@ def _dispatch_and_combine(self, x: Tensor) -> Tensor:
25242544
in_specs=(
25252545
cfg.input_dim_to_partition_spec["bsm"],
25262546
cfg.input_dim_to_partition_spec["bsm"],
2527-
cfg.input_dim_to_partition_spec["bsm"],
2528-
cfg.output_dim_to_partition_spec["emh"],
2529-
cfg.output_dim_to_partition_spec["emh"],
2530-
cfg.output_dim_to_partition_spec["ehm"],
2547+
cfg.dim_to_mesh_axis_map["emh"],
2548+
cfg.dim_to_mesh_axis_map["emh"],
2549+
cfg.dim_to_mesh_axis_map["ehm"],
25312550
),
25322551
out_specs=(
2533-
cfg.output_dim_to_partition_spec["bsm"],
2552+
bskm_out_spec,
25342553
*self._additional_shmap_output_sharding(mesh),
25352554
),
25362555
# Disables a checking pass which jax can't apply when there's a triton | pallas
@@ -2540,10 +2559,9 @@ def _dispatch_and_combine(self, x: Tensor) -> Tensor:
25402559
def wrapper(
25412560
x: Tensor,
25422561
gate_assignment: Tensor,
2543-
expert_weights: Tensor,
2544-
wi_0: Tensor,
2545-
wi_1: Tensor,
2546-
wo: Tensor,
2562+
wi_0_sharded: Tensor,
2563+
wi_1_sharded: Tensor,
2564+
wo_sharded: Tensor,
25472565
) -> tuple[Tensor, ...]:
25482566
"""Computes unsorted outputs for one sharded block.
25492567
@@ -2554,84 +2572,122 @@ def wrapper(
25542572
Args:
25552573
x: the sharded input batch of the shape [G=B', S', M].
25562574
gate_assignment: [G=B', S', K].
2557-
expert_weights: [G=B', S', K].
2558-
wi_0: the input projection of [E, M, H'].
2559-
wi_1: the input projection of [E, M, H'].
2560-
wo: the output projection of [E, H', M'].
2575+
wi_0_sharded: the input projection of [E, M/fsdp, H/model] (local).
2576+
wi_1_sharded: the input projection of [E, M/fsdp, H/model] (local).
2577+
wo_sharded: the output projection of [E, H/model, M/fsdp] (local).
25612578
25622579
Returns:
25632580
A tuple of
2564-
- A tensor of shape [G=B', S', M'].
2581+
- A tensor of shape [G=B', S', K, M'].
25652582
- ... optional series of tensors from `_dispatch_hook()[2]`.
25662583
"""
25672584
logging.info("Setting the effective group_size=%r", x.shape[0])
25682585
B, S, M = x.shape # pylint: disable=invalid-name
2569-
# [B' x S' x K]
2570-
gate_assignment = gate_assignment.reshape((-1))
2571-
# x[sorted_indices[:, i]] for i in range(S * K) represents tokens sorted
2572-
# by which experts they are assigned to.
2573-
# [B' x S' x K]
2574-
sorted_indices = jnp.argsort(gate_assignment)
2575-
token_indices = sorted_indices // num_experts_per_token
2576-
# Dispatch the tokens.
2577-
combine_indices = jnp.argsort(sorted_indices)
2578-
# [B' x S' x K, M]
2579-
sorted_inputs = _custom_gather(
2580-
x.reshape(-1, M), token_indices, combine_indices, unique_indices=False
2581-
)
2582-
tokens_per_expert = jnp.bincount(gate_assignment, length=cfg.num_experts)
2583-
2584-
sorted_inputs, tokens_per_expert, additional_outputs, residuals = self._dispatch_hook(
2585-
sorted_inputs=sorted_inputs,
2586-
tokens_per_expert=tokens_per_expert,
2587-
)
25882586

2589-
# [B' x S' x K, H']
2590-
activation_0 = self._padded_gmm(sorted_inputs, wi_0, tokens_per_expert)
2591-
activation_0 = get_activation_fn(cfg.activation[0])(activation_0)
2592-
2593-
activation_1 = self._padded_gmm(sorted_inputs, wi_1, tokens_per_expert)
2594-
activation_1 = get_activation_fn(cfg.activation[1])(activation_1)
2595-
2596-
intermediate = activation_0 * activation_1
2597-
2598-
if cfg.structure in ["prenorm", "hybridnorm", "nonorm", "v2"]:
2599-
intermediate = self.dropout1(intermediate)
2587+
@jax.jit
2588+
@partial(jax.checkpoint, prevent_cse=False)
2589+
def _gather_and_compute(
2590+
x: Tensor,
2591+
gate_assignment: Tensor,
2592+
wi_0_sharded: Tensor,
2593+
wi_1_sharded: Tensor,
2594+
wo_sharded: Tensor,
2595+
) -> tuple[Tensor, ...]:
2596+
# Explicitly all-gather expert weights along fsdp inside the
2597+
# checkpoint scope. Backward will recompute this gather rather
2598+
# than saving the gathered tensor.
2599+
if emh_gather_axes is not None:
2600+
wi_0 = jax.lax.all_gather(wi_0_sharded, emh_gather_axes, axis=1, tiled=True)
2601+
wi_1 = jax.lax.all_gather(wi_1_sharded, emh_gather_axes, axis=1, tiled=True)
2602+
else:
2603+
wi_0 = wi_0_sharded
2604+
wi_1 = wi_1_sharded
2605+
if ehm_gather_axes is not None:
2606+
wo = jax.lax.all_gather(wo_sharded, ehm_gather_axes, axis=2, tiled=True)
2607+
else:
2608+
wo = wo_sharded
2609+
2610+
# [B' x S' x K]
2611+
gate_assignment = gate_assignment.reshape((-1))
2612+
# x[sorted_indices[:, i]] for i in range(S * K) represents tokens sorted
2613+
# by which experts they are assigned to.
2614+
# [B' x S' x K]
2615+
sorted_indices = jnp.argsort(gate_assignment)
2616+
token_indices = sorted_indices // num_experts_per_token
2617+
# Dispatch the tokens.
2618+
combine_indices = jnp.argsort(sorted_indices)
2619+
# [B' x S' x K, M]
2620+
sorted_inputs = _custom_gather(
2621+
x.reshape(-1, M), token_indices, combine_indices, unique_indices=False
2622+
)
2623+
tokens_per_expert = jnp.bincount(gate_assignment, length=cfg.num_experts)
26002624

2601-
# [B' x S x K, M]
2602-
sorted_output = self._padded_gmm(intermediate, wo, tokens_per_expert)
2603-
if thread_resources.env.physical_mesh.shape["model"] > 1:
2604-
# If output is partitioned across "model", we need to reduce-scatter. Otherwise,
2605-
# we do an allreduce.
2606-
spec = cfg.output_dim_to_partition_spec["bsm"][2]
2607-
if spec and "model" in spec:
2608-
sorted_output = jax.lax.psum_scatter(
2609-
sorted_output, "model", scatter_dimension=1, tiled=True
2625+
sorted_inputs, tokens_per_expert, additional_outputs, residuals = (
2626+
self._dispatch_hook(
2627+
sorted_inputs=sorted_inputs,
2628+
tokens_per_expert=tokens_per_expert,
26102629
)
2611-
else:
2612-
sorted_output = jax.lax.psum(sorted_output, "model")
2613-
# [B' x S' x K, M']
2614-
sorted_output = self._combine_hook(sorted_output=sorted_output, residuals=residuals)
2615-
# Gather the tokens to their original positions.
2616-
unsorted_output = _custom_gather(sorted_output, combine_indices, sorted_indices)
2617-
output = unsorted_output.reshape(B, S, num_experts_per_token, unsorted_output.shape[-1])
2618-
# Apply the expert weights.
2619-
output *= expert_weights[..., None]
2620-
assert expert_weights.dtype == jnp.float32
2621-
assert output.dtype == jnp.float32
2622-
# [B', S', M']
2623-
output = jnp.sum(output, axis=-2).astype(x.dtype)
2624-
return output, *additional_outputs
2630+
)
2631+
2632+
# [B' x S' x K, H']
2633+
activation_0 = self._padded_gmm(sorted_inputs, wi_0, tokens_per_expert)
2634+
activation_0 = get_activation_fn(cfg.activation[0])(activation_0)
2635+
2636+
activation_1 = self._padded_gmm(sorted_inputs, wi_1, tokens_per_expert)
2637+
activation_1 = get_activation_fn(cfg.activation[1])(activation_1)
2638+
2639+
intermediate = activation_0 * activation_1
2640+
2641+
if cfg.structure in ["prenorm", "hybridnorm", "nonorm", "v2"]:
2642+
intermediate = self.dropout1(intermediate)
2643+
2644+
# [B' x S' x K, M]
2645+
sorted_output = self._padded_gmm(intermediate, wo, tokens_per_expert)
2646+
if thread_resources.env.physical_mesh.shape["model"] > 1:
2647+
# If output is partitioned across "model", we need to reduce-scatter. Otherwise,
2648+
# we do an allreduce.
2649+
spec = cfg.output_dim_to_partition_spec["bsm"][2]
2650+
if spec and "model" in spec:
2651+
sorted_output = jax.lax.psum_scatter(
2652+
sorted_output, "model", scatter_dimension=1, tiled=True
2653+
)
2654+
else:
2655+
sorted_output = jax.lax.psum(sorted_output, "model")
2656+
# [B' x S' x K, M']
2657+
sorted_output = self._combine_hook(sorted_output=sorted_output, residuals=residuals)
2658+
# Gather the tokens to their original positions.
2659+
unsorted_output = _custom_gather(sorted_output, combine_indices, sorted_indices)
2660+
# [B', S', K, M']
2661+
output = unsorted_output.reshape(
2662+
B, S, num_experts_per_token, unsorted_output.shape[-1]
2663+
)
2664+
return output, *additional_outputs
2665+
2666+
return _gather_and_compute(
2667+
x,
2668+
gate_assignment,
2669+
wi_0_sharded,
2670+
wi_1_sharded,
2671+
wo_sharded,
2672+
)
26252673

26262674
out, *additional_outputs = wrapper(
26272675
x,
26282676
gate_assignment,
2629-
expert_weights,
26302677
self.parameters["wi_0_weight"],
26312678
self.parameters["wi_1_weight"],
26322679
self.parameters["wo_weight"],
26332680
)
26342681
self._additional_shmap_output_hook(additional_outputs)
2682+
2683+
# Apply expert weights outside shard_map to avoid passing an extra tensor
2684+
# through the shard_map boundary and recomputing it in backward.
2685+
# out: [B, S, K, M'], expert_weights: [B, S, K]
2686+
out *= expert_weights[..., None]
2687+
assert expert_weights.dtype == jnp.float32
2688+
assert out.dtype == jnp.float32
2689+
# [B, S, M']
2690+
out = jnp.sum(out, axis=-2).astype(x.dtype)
26352691
return out
26362692

26372693

0 commit comments

Comments
 (0)