Skip to content

Commit 5f72ab6

Browse files
Kevin Wang, Google-ML-Automation
authored and committed
Apply optimization barrier for MLA output to a single microbatch only.
PiperOrigin-RevId: 899284823
1 parent 5ec17c0 commit 5f72ab6

1 file changed

Lines changed: 8 additions & 11 deletions

File tree

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -821,11 +821,6 @@ def batch_split_schedule(
821821
dtype=cfg.dtype,
822822
activation_pspec=activation_pspec,
823823
)
824-
# Prevent fusion with MoE ops, especially the RMS norm.
825-
# Unfortunately, this seems to be needed to avoid slight numerical differences
826-
# between the fwd pass and remat.
827-
xs = jax.lax.optimization_barrier(xs)
828-
829824
xs, moe_res = moe(
830825
xs,
831826
moe_ws,
@@ -876,10 +871,6 @@ def batch_split_schedule_bwd(
876871
dtype=cfg.dtype,
877872
activation_pspec=activation_pspec,
878873
)
879-
# Prevent fusion with MoE ops, especially the RMS norm.
880-
# Unfortunately, this seems to be needed to avoid slight numerical differences
881-
# between the fwd pass and remat.
882-
mla_out = jax.lax.optimization_barrier(mla_out)
883874
residuals["mla_out"] = mla_out
884875
attn_out_grad, moe_ws_grad = moe_bwd(
885876
residuals,
@@ -970,7 +961,10 @@ def fn(args):
970961
mesh=mesh,
971962
activation_pspec=activation_pspec,
972963
)
973-
return mla_out + x, mla_res
964+
# Prevent fusion with MoE ops, especially the RMS norm.
965+
# Unfortunately, this seems to be needed to avoid slight numerical differences
966+
# between the fwd pass and remat.
967+
return jax.lax.optimization_barrier(mla_out + x), mla_res
974968

975969
return staggered_call(fn, list(zip(inputs, yarn_freqs)))
976970

@@ -1032,7 +1026,10 @@ def remat_fn(args):
10321026
activation_pspec=activation_pspec,
10331027
)
10341028
out = x + mla_out
1035-
return out, (pre_attn_rms_norm_bwd, mla_bwds)
1029+
# Prevent fusion with MoE ops, especially the RMS norm.
1030+
# Unfortunately, this seems to be needed to avoid slight numerical differences
1031+
# between the fwd pass and remat.
1032+
return jax.lax.optimization_barrier(out), (pre_attn_rms_norm_bwd, mla_bwds)
10361033

10371034
bwds = [None] * len(xs)
10381035
for i, x in enumerate(zip(xs, yarn_freqs, residuals.pop("attn_out"), residuals.pop("lse"))):

0 commit comments

Comments
 (0)