Commit 53d0ecd

[JAX] MoE: re-fuse wi_0/wi_1 via jnp.concatenate (replaces our un-fuse)
Adopt jberchtold's concat-along-trailing-axis fusion of the gate/up projections (df61642) in place of our two-separate-GEMM un-fuse from ``fe3e4ff9``. Both approaches address the same upstream ``tex.grouped_gemm`` constraint (the kernel only supports the 3D ``(G, K, N)`` weight layout with ``contracting_dims=((1,), (1,))``, so the previous ``jnp.stack([wi_0, wi_1], axis=-2)`` 4D variant silently produced NaN).

His version is materially cheaper: one quantize + one GEMM in fwd and one quantize + two GEMMs + one dbias in bwd, vs our two quantizes + two GEMMs in fwd and two quantizes + four GEMMs + two dbias in bwd.

Concretely on the fwd path:

* Pack the two weight tensors with ``jnp.concatenate([wi_0, wi_1], axis=-1)`` so the combined weight has shape ``(num_local_experts, hidden, 2*H_inter)`` -- still 3D, so the kernel's contract is preserved.
* Run a single ``tex.grouped_gemm`` against the concatenated weight to get a ``(num_rows, 2*H_inter)`` output, then ``jnp.split(..., 2, axis=-1)`` to recover ``gate_proj_out`` / ``up_proj_out``.
* Save only ``casted_wi_rhs_trans`` in the residual (single 3D tensor) instead of the two halves we used before.

The bwd path mirrors that: concatenate the two activation cotangents, quantize once, run the dgrad against the fused ``casted_wi_rhs_trans`` residual to produce ``d_sorted_x``, run the wgrad against the same casted-d-combined RHS to produce ``d_wi_combined``, split into ``d_wi_0`` / ``d_wi_1``. ``tex.grouped_dbias`` likewise runs once on ``d_combined`` and splits.

What we keep from our local work (NOT pulled from his branch):

* The ``shard_map``-wrapped FFN body in ``_moe_fwd_rule`` and ``_moe_bwd_rule``. His ``2210702a`` deletes this and calls the FFN per-shard helpers directly via grouped-GEMM custom partitioning; that lands in a later integration sweep once ``jberchtold/gmm-custom-partition-rules`` merges to main.
* The per-shard ``jax.lax.cond`` zero-init of ``r_tok`` inside the fwd ``_body`` -- still required to work around the NCCL EP HT dispatch zero-init gap for fully-empty-receiver ranks (``sigmoid-bias-strong`` regression).
* All ``_inspect`` probes in both the fwd and bwd FFN. These stay on for this round so we can confirm the cherry-pick didn't regress anything; they will come out in a follow-up cleanup commit before PR.

Plumbing impact: ``_Ctx`` collapses ``casted_wi_0_rhs_trans`` + ``casted_wi_1_rhs_trans`` into a single ``casted_wi_rhs_trans`` field; the fwd ``shard_map``'s ``residuals_spec`` and the bwd ``shard_map``'s ``bwd_in_specs`` / ``bwd_in_args`` both lose one slot accordingly.
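The equivalence that the fusion relies on can be checked in plain JAX. The following is a minimal sketch under stated assumptions, not the TE code path: ``ragged_matmul`` is a hypothetical pure-``jnp`` stand-in for ``tex.grouped_gemm``, quantization and bias are skipped, and the sizes are toy values. It compares the concat-fused fwd and bwd against the two-GEMM un-fused reference, and shows why the ``jnp.stack(..., axis=-2)`` variant leaves the 3D ``(G, K, N)`` layout.

```python
import jax
import jax.numpy as jnp


def ragged_matmul(x, w, group_sizes):
    """Hypothetical reference grouped GEMM (NOT tex.grouped_gemm).

    Rows of x are sorted by expert; the g-th block of rows is multiplied
    by w[g]. x: (num_rows, K), w: (G, K, N) -> (num_rows, N).
    """
    outs, start = [], 0
    for g, size in enumerate(group_sizes):
        outs.append(x[start:start + size] @ w[g])
        start += size
    return jnp.concatenate(outs, axis=0)


# Toy sizes (assumptions for illustration only).
G, hidden, h_inter = 3, 8, 4
group_sizes = [5, 0, 2]  # includes an expert that received no rows
num_rows = sum(group_sizes)

keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (num_rows, hidden))
wi_0 = jax.random.normal(keys[1], (G, hidden, h_inter))
wi_1 = jax.random.normal(keys[2], (G, hidden, h_inter))
d_gate = jax.random.normal(keys[3], (num_rows, h_inter))
d_up = jax.random.normal(keys[4], (num_rows, h_inter))

# Concatenating on the trailing axis keeps the weight 3D: (G, hidden, 2*h_inter).
wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1)
# Stacking on axis=-2 inserts an extra axis instead: (G, hidden, 2, h_inter).
wi_stacked_shape = jnp.stack([wi_0, wi_1], axis=-2).shape

# Fused fwd: one GEMM, then split back into the gate / up halves.
combined_out = ragged_matmul(x, wi_combined, group_sizes)
gate, up = jnp.split(combined_out, 2, axis=-1)

# Un-fused reference: two separate GEMMs.
gate_ref = ragged_matmul(x, wi_0, group_sizes)
up_ref = ragged_matmul(x, wi_1, group_sizes)

# Fused bwd: concatenate the cotangents, differentiate once, split the wgrad.
d_combined = jnp.concatenate([d_gate, d_up], axis=-1)
_, fused_vjp = jax.vjp(lambda a, w: ragged_matmul(a, w, group_sizes), x, wi_combined)
d_x, d_wi_combined = fused_vjp(d_combined)
d_wi_0, d_wi_1 = jnp.split(d_wi_combined, 2, axis=-1)


# Un-fused bwd reference: the two dgrad contributions are summed by the VJP.
def unfused(a, w0, w1):
    return ragged_matmul(a, w0, group_sizes), ragged_matmul(a, w1, group_sizes)


_, ref_vjp = jax.vjp(unfused, x, wi_0, wi_1)
d_x_ref, d_wi_0_ref, d_wi_1_ref = ref_vjp((d_gate, d_up))

fwd_ok = bool(jnp.allclose(gate, gate_ref, atol=1e-4) and jnp.allclose(up, up_ref, atol=1e-4))
bwd_ok = bool(
    jnp.allclose(d_x, d_x_ref, atol=1e-4)
    and jnp.allclose(d_wi_0, d_wi_0_ref, atol=1e-4)
    and jnp.allclose(d_wi_1, d_wi_1_ref, atol=1e-4)
)
print(wi_combined.shape, wi_stacked_shape, fwd_ok, bwd_ok)
```

Note that the fused dgrad needs no explicit ``d_sorted_x_from_gate + d_sorted_x_from_up`` sum: contracting over the concatenated ``2*H_inter`` axis adds the gate and up contributions in a single GEMM.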
1 parent fecb0ed commit 53d0ecd

1 file changed

Lines changed: 59 additions & 81 deletions

transformer_engine/jax/moe.py

@@ -224,8 +224,7 @@ class _Ctx:
224224
token_counts: jnp.ndarray
225225
recv_topk_weights: jnp.ndarray
226226
casted_sorted_x_lhs_trans: Any
227-
casted_wi_0_rhs_trans: Any
228-
casted_wi_1_rhs_trans: Any
227+
casted_wi_rhs_trans: Any
229228
gate_proj_out: jnp.ndarray
230229
up_proj_out: jnp.ndarray
231230
casted_intermediate_lhs_trans: Any
@@ -249,8 +248,7 @@ def tree_flatten(self):
249248
self.token_counts,
250249
self.recv_topk_weights,
251250
self.casted_sorted_x_lhs_trans,
252-
self.casted_wi_0_rhs_trans,
253-
self.casted_wi_1_rhs_trans,
251+
self.casted_wi_rhs_trans,
254252
self.gate_proj_out,
255253
self.up_proj_out,
256254
self.casted_intermediate_lhs_trans,
@@ -278,8 +276,7 @@ def tree_unflatten(cls, aux_data, children):
278276
token_counts,
279277
recv_topk_weights,
280278
casted_sorted_x_lhs_trans,
281-
casted_wi_0_rhs_trans,
282-
casted_wi_1_rhs_trans,
279+
casted_wi_rhs_trans,
283280
gate_proj_out,
284281
up_proj_out,
285282
casted_intermediate_lhs_trans,
@@ -302,8 +299,7 @@ def tree_unflatten(cls, aux_data, children):
302299
token_counts=token_counts,
303300
recv_topk_weights=recv_topk_weights,
304301
casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans,
305-
casted_wi_0_rhs_trans=casted_wi_0_rhs_trans,
306-
casted_wi_1_rhs_trans=casted_wi_1_rhs_trans,
302+
casted_wi_rhs_trans=casted_wi_rhs_trans,
307303
gate_proj_out=gate_proj_out,
308304
up_proj_out=up_proj_out,
309305
casted_intermediate_lhs_trans=casted_intermediate_lhs_trans,
@@ -353,41 +349,42 @@ def _ffn_fwd_per_shard(
353349
wi_1 = wi_1.astype(sorted_x.dtype)
354350
wo = wo.astype(sorted_x.dtype)
355351

352+
# wi GEMM uses ONE fused grouped_gemm with the gate/up weights
353+
# concatenated along the trailing (output) axis: wi_combined has
354+
# shape ``(num_local_experts, hidden, 2*H_inter)`` and the resulting
355+
# combined_out has shape ``(num_rows, 2*H_inter)``, which jnp.split
356+
# cleanly slices back into gate / up halves. tex.grouped_gemm only
357+
# supports the canonical (G, K, N) 3D weight layout with
358+
# contracting_dims=((1,),(1,)) -- see the docstring on
359+
# transformer_engine.jax.dense.grouped_dense ("currently only
360+
# supports ((1,), (1,))") and the CI test
361+
# tests/jax/test_multi_process_distributed_grouped_gemm.py.
362+
# An older fused 4D variant built via jnp.stack([wi_0, wi_1], axis=-2)
363+
# put a non-contracting axis in the middle of the RHS, which the
364+
# kernel walked as if it were 3D and read off the end -> NaN.
365+
# Confirmed via TE_MOE_INSPECT bisect: the stack-axis variant
366+
# produced all-NaN output, while the concat-axis variant (this
367+
# path) produces finite outputs matching the jnp.einsum reference.
368+
wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1)
369+
wi_combined_bias = (
370+
jnp.concatenate([wi_0_bias, wi_1_bias], axis=-1) if wi_0_bias is not None else None
371+
)
372+
356373
q_set = noop_quantizer_set
357-
# wi GEMM uses TWO separate 3D grouped_gemm calls (one per wi_0 / wi_1)
358-
# instead of one fused 4D call. tex.grouped_gemm only supports the
359-
# canonical (G, K, N) 3D weight layout with contracting_dims=((1,),(1,))
360-
# -- see the docstring on transformer_engine.jax.dense.grouped_dense
361-
# ("currently only supports ((1,), (1,))") and the CI test
362-
# tests/jax/test_multi_process_distributed_grouped_gemm.py. A 4D
363-
# weight built via jnp.stack([wi_0, wi_1], axis=-2) puts a
364-
# non-contracting axis in the middle of the RHS, which the kernel
365-
# walks as if it were 3D and reads off the end -> NaN. Confirmed
366-
# via TE_MOE_INSPECT bisect: clean LHS + clean fused-4D RHS still
367-
# produced all-NaN output, while the same inputs through two
368-
# 3D calls produced finite outputs matching the jnp.einsum reference.
369374
sorted_x = _inspect(sorted_x, "ffn_fwd/sorted_x_in")
370375
casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, local_group_sizes, flatten_axis=-1)
371-
casted_wi_0 = tex.grouped_quantize(wi_0, q_set.kernel, flatten_axis=-1)
372-
casted_wi_1 = tex.grouped_quantize(wi_1, q_set.kernel, flatten_axis=-1)
373-
_casted_x_lhs = casted_sorted_x.get_tensor(usage=TensorUsage.LHS)
374-
gate_proj_out = tex.grouped_gemm(
375-
_casted_x_lhs,
376-
casted_wi_0.get_tensor(usage=TensorUsage.RHS),
376+
casted_wi = tex.grouped_quantize(wi_combined, q_set.kernel, flatten_axis=-1)
377+
combined_out = tex.grouped_gemm(
378+
casted_sorted_x.get_tensor(usage=TensorUsage.LHS),
379+
casted_wi.get_tensor(usage=TensorUsage.RHS),
377380
contracting_dims=((1,), (1,)),
378-
bias=wi_0_bias,
379-
)
380-
up_proj_out = tex.grouped_gemm(
381-
_casted_x_lhs,
382-
casted_wi_1.get_tensor(usage=TensorUsage.RHS),
383-
contracting_dims=((1,), (1,)),
384-
bias=wi_1_bias,
381+
bias=wi_combined_bias,
385382
)
383+
gate_proj_out, up_proj_out = jnp.split(combined_out, 2, axis=-1)
386384
gate_proj_out = _inspect(gate_proj_out, "ffn_fwd/gate_proj_out")
387385
up_proj_out = _inspect(up_proj_out, "ffn_fwd/up_proj_out")
388386
casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
389-
casted_wi_0_rhs_trans = casted_wi_0.get_tensor(usage=TensorUsage.RHS_TRANS)
390-
casted_wi_1_rhs_trans = casted_wi_1.get_tensor(usage=TensorUsage.RHS_TRANS)
387+
casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS)
391388

392389
# Promote the silu+multiply to fp32 to match the pure-JAX reference
393390
# (and ML common practice). bf16 silu accumulation alone drifts ~1%
@@ -429,8 +426,7 @@ def _ffn_fwd_per_shard(
429426
expert_outputs_3d = expert_outputs.reshape(1, expert_outputs.shape[0], expert_outputs.shape[1])
430427
residuals = (
431428
casted_sorted_x_lhs_trans,
432-
casted_wi_0_rhs_trans,
433-
casted_wi_1_rhs_trans,
429+
casted_wi_rhs_trans,
434430
gate_proj_out,
435431
up_proj_out,
436432
casted_intermediate_lhs_trans,
@@ -443,8 +439,7 @@ def _ffn_fwd_per_shard(
443439
def _ffn_bwd_per_shard(
444440
d_expert_outputs_local: jnp.ndarray,
445441
casted_sorted_x_lhs_trans,
446-
casted_wi_0_rhs_trans,
447-
casted_wi_1_rhs_trans,
442+
casted_wi_rhs_trans,
448443
gate_proj_out: jnp.ndarray,
449444
up_proj_out: jnp.ndarray,
450445
casted_intermediate_lhs_trans,
@@ -522,46 +517,32 @@ def _ffn_bwd_per_shard(
522517
d_up_proj_out = _inspect(d_up_proj_out, "ffn_bwd/d_up_proj_out_after_act_bwd")
523518
d_gate_proj_out = _inspect(d_gate_proj_out, "ffn_bwd/d_gate_proj_out_after_act_bwd")
524519

525-
# wi bwd (split gate/up). Two separate 3D grouped_gemm calls each
526-
# for d_sorted_x and d_w_i, mirroring the un-fused fwd. The fused
527-
# 4D path was buggy in fwd (NaN-from-clean-inputs); the same
528-
# ((1,2),(2,3)) bwd shape on a 4D RHS would silently produce NaN
529-
# too if it ever fired on clean inputs.
530-
d_gate_proj_out_b = d_gate_proj_out.astype(gate_proj_out.dtype)
531-
d_up_proj_out_b = d_up_proj_out.astype(up_proj_out.dtype)
532-
casted_d_gate = tex.grouped_quantize(
533-
d_gate_proj_out_b, q_set.dgrad, local_group_sizes, flatten_axis=-1
534-
)
535-
casted_d_up = tex.grouped_quantize(
536-
d_up_proj_out_b, q_set.dgrad, local_group_sizes, flatten_axis=-1
520+
# wi bwd (fused gate/up via concat). Mirror the fused fwd: pack the
521+
# gate/up cotangents along the trailing axis, run a single
522+
# grouped_quantize + two grouped_gemm pair (one dgrad, one wgrad)
523+
# against the fused casted_wi_rhs_trans residual, then split the
524+
# wgrad result back into d_wi_0 / d_wi_1 halves with jnp.split.
525+
d_combined = jnp.concatenate([d_gate_proj_out, d_up_proj_out], axis=-1)
526+
casted_d_combined = tex.grouped_quantize(
527+
d_combined, q_set.dgrad, local_group_sizes, flatten_axis=-1
537528
)
538-
d_sorted_x_from_gate = tex.grouped_gemm(
539-
casted_d_gate.get_tensor(usage=TensorUsage.LHS),
540-
casted_wi_0_rhs_trans,
529+
d_sorted_x = tex.grouped_gemm(
530+
casted_d_combined.get_tensor(usage=TensorUsage.LHS),
531+
casted_wi_rhs_trans,
541532
contracting_dims=((1,), (2,)),
542533
)
543-
d_sorted_x_from_up = tex.grouped_gemm(
544-
casted_d_up.get_tensor(usage=TensorUsage.LHS),
545-
casted_wi_1_rhs_trans,
546-
contracting_dims=((1,), (2,)),
547-
)
548-
d_sorted_x = d_sorted_x_from_gate + d_sorted_x_from_up
549534
d_sorted_x = _inspect(d_sorted_x, "ffn_bwd/d_sorted_x_after_wi_dgrad_sum")
550-
d_wi_0 = tex.grouped_gemm(
551-
casted_sorted_x_lhs_trans,
552-
casted_d_gate.get_tensor(usage=TensorUsage.RHS),
553-
contracting_dims=((0,), (0,)),
554-
)
555-
d_wi_1 = tex.grouped_gemm(
535+
d_wi_combined = tex.grouped_gemm(
556536
casted_sorted_x_lhs_trans,
557-
casted_d_up.get_tensor(usage=TensorUsage.RHS),
537+
casted_d_combined.get_tensor(usage=TensorUsage.RHS),
558538
contracting_dims=((0,), (0,)),
559539
)
540+
d_wi_0, d_wi_1 = jnp.split(d_wi_combined, 2, axis=-1)
560541
d_wi_0 = _inspect(d_wi_0, "ffn_bwd/d_wi_0_after_wgrad_pre_psum")
561542
d_wi_1 = _inspect(d_wi_1, "ffn_bwd/d_wi_1_after_wgrad_pre_psum")
562543
if has_bias:
563-
d_wi_0_bias = tex.grouped_dbias(d_gate_proj_out_b, local_group_sizes)
564-
d_wi_1_bias = tex.grouped_dbias(d_up_proj_out_b, local_group_sizes)
544+
d_wi_combined_bias = tex.grouped_dbias(d_combined, local_group_sizes)
545+
d_wi_0_bias, d_wi_1_bias = jnp.split(d_wi_combined_bias, 2, axis=-1)
565546
else:
566547
d_wi_0_bias = None
567548
d_wi_1_bias = None
@@ -819,12 +800,13 @@ def _moe_fwd_rule(
819800

820801
# FFN residuals live entirely on the local ep rank, so the leading
821802
# "experts" / "rows" dims map to P() (already shard-local). wi is
822-
# un-fused into wi_0 / wi_1 (see _ffn_fwd_per_shard for rationale);
823-
# each carries its own RHS_TRANS residual.
803+
# fused via jnp.concatenate along the trailing (output) axis
804+
# (see _ffn_fwd_per_shard for rationale), so the residual is a
805+
# single 3D casted_wi_rhs_trans of shape
806+
# (num_local_experts, hidden, 2*H_inter).
824807
residuals_spec = (
825808
P(), # casted_sorted_x_lhs_trans
826-
P(ep_axis, None, None), # casted_wi_0_rhs_trans
827-
P(ep_axis, None, None), # casted_wi_1_rhs_trans
809+
P(ep_axis, None, None), # casted_wi_rhs_trans
828810
P(), # gate_proj_out
829811
P(), # up_proj_out
830812
P(), # casted_intermediate_lhs_trans
@@ -921,8 +903,7 @@ def _body(*args):
921903

922904
(
923905
casted_sorted_x_lhs_trans,
924-
casted_wi_0_rhs_trans,
925-
casted_wi_1_rhs_trans,
906+
casted_wi_rhs_trans,
926907
gate_proj_out,
927908
up_proj_out,
928909
casted_intermediate_lhs_trans,
@@ -942,8 +923,7 @@ def _body(*args):
942923
token_counts=token_counts,
943924
recv_topk_weights=recv_topk_weights,
944925
casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans,
945-
casted_wi_0_rhs_trans=casted_wi_0_rhs_trans,
946-
casted_wi_1_rhs_trans=casted_wi_1_rhs_trans,
926+
casted_wi_rhs_trans=casted_wi_rhs_trans,
947927
gate_proj_out=gate_proj_out,
948928
up_proj_out=up_proj_out,
949929
casted_intermediate_lhs_trans=casted_intermediate_lhs_trans,
@@ -1075,8 +1055,7 @@ def _moe_bwd_rule(
10751055
bwd_in_specs = (
10761056
ep3_spec, # d_expert_outputs
10771057
P(), # casted_sorted_x_lhs_trans
1078-
P(ep_axis, None, None), # casted_wi_0_rhs_trans
1079-
P(ep_axis, None, None), # casted_wi_1_rhs_trans
1058+
P(ep_axis, None, None), # casted_wi_rhs_trans
10801059
P(), # gate_proj_out
10811060
P(), # up_proj_out
10821061
P(), # casted_intermediate_lhs_trans
@@ -1087,8 +1066,7 @@ def _moe_bwd_rule(
10871066
bwd_in_args = [
10881067
d_expert_outputs,
10891068
ctx.casted_sorted_x_lhs_trans,
1090-
ctx.casted_wi_0_rhs_trans,
1091-
ctx.casted_wi_1_rhs_trans,
1069+
ctx.casted_wi_rhs_trans,
10921070
ctx.gate_proj_out,
10931071
ctx.up_proj_out,
10941072
ctx.casted_intermediate_lhs_trans,
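
The ``_Ctx`` plumbing in the hunks above follows the usual JAX custom-pytree pattern, where the field list, ``tree_flatten`` and ``tree_unflatten`` must all agree on the number and order of children. A minimal sketch of that pattern, assuming a hypothetical three-field ``ToyCtx`` (not the real ``_Ctx``, which carries many more fields):

```python
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class ToyCtx:
    """Hypothetical stand-in for _Ctx with a single fused wi residual."""

    token_counts: jnp.ndarray
    casted_wi_rhs_trans: Any  # one fused residual instead of wi_0 / wi_1 halves
    gate_proj_out: jnp.ndarray

    def tree_flatten(self):
        # Order here must match the unpacking order in tree_unflatten.
        children = (self.token_counts, self.casted_wi_rhs_trans, self.gate_proj_out)
        return children, None  # no static aux data in this sketch

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        token_counts, casted_wi_rhs_trans, gate_proj_out = children
        return cls(
            token_counts=token_counts,
            casted_wi_rhs_trans=casted_wi_rhs_trans,
            gate_proj_out=gate_proj_out,
        )


ctx = ToyCtx(
    token_counts=jnp.array([5, 0, 2]),
    casted_wi_rhs_trans=jnp.zeros((3, 8, 8)),  # (G, hidden, 2*h_inter), toy sizes
    gate_proj_out=jnp.zeros((7, 4)),
)

# Round-trip through the pytree machinery, as custom_vjp residuals would.
leaves, treedef = jax.tree_util.tree_flatten(ctx)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
print(len(leaves), restored.casted_wi_rhs_trans.shape)
```

Dropping a field therefore means touching the dataclass body, the flatten tuple, and both the unpacking and constructor call in ``tree_unflatten``, which is why the diff edits ``_Ctx`` in four places.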
