@@ -821,11 +821,6 @@ def batch_split_schedule(
821821 dtype = cfg .dtype ,
822822 activation_pspec = activation_pspec ,
823823 )
824- # Prevent fusion with MoE ops, especially the RMS norm.
825- # Unfortunately, this seems to be needed to avoid slight numerical differences
826- # between the fwd pass and remat.
827- xs = jax .lax .optimization_barrier (xs )
828-
829824 xs , moe_res = moe (
830825 xs ,
831826 moe_ws ,
@@ -876,10 +871,6 @@ def batch_split_schedule_bwd(
876871 dtype = cfg .dtype ,
877872 activation_pspec = activation_pspec ,
878873 )
879- # Prevent fusion with MoE ops, especially the RMS norm.
880- # Unfortunately, this seems to be needed to avoid slight numerical differences
881- # between the fwd pass and remat.
882- mla_out = jax .lax .optimization_barrier (mla_out )
883874 residuals ["mla_out" ] = mla_out
884875 attn_out_grad , moe_ws_grad = moe_bwd (
885876 residuals ,
@@ -970,7 +961,10 @@ def fn(args):
970961 mesh = mesh ,
971962 activation_pspec = activation_pspec ,
972963 )
973- return mla_out + x , mla_res
964+ # Prevent fusion with MoE ops, especially the RMS norm.
965+ # Unfortunately, this seems to be needed to avoid slight numerical differences
966+ # between the fwd pass and remat.
967+ return jax .lax .optimization_barrier (mla_out + x ), mla_res
974968
975969 return staggered_call (fn , list (zip (inputs , yarn_freqs )))
976970
@@ -1032,7 +1026,10 @@ def remat_fn(args):
10321026 activation_pspec = activation_pspec ,
10331027 )
10341028 out = x + mla_out
1035- return out , (pre_attn_rms_norm_bwd , mla_bwds )
1029+ # Prevent fusion with MoE ops, especially the RMS norm.
1030+ # Unfortunately, this seems to be needed to avoid slight numerical differences
1031+ # between the fwd pass and remat.
1032+ return jax .lax .optimization_barrier (out ), (pre_attn_rms_norm_bwd , mla_bwds )
10361033
10371034 bwds = [None ] * len (xs )
10381035 for i , x in enumerate (zip (xs , yarn_freqs , residuals .pop ("attn_out" ), residuals .pop ("lse" ))):
0 commit comments