[JAX] MoE: re-fuse wi_0/wi_1 via jnp.concatenate (replaces our un-fuse)
Adopt jberchtold's concat-along-trailing-axis fusion of the gate/up
projections (``df61642``) in place of our two-separate-GEMM un-fuse from
``fe3e4ff9``. Both approaches address the same upstream
``tex.grouped_gemm`` constraint (the kernel only supports the 3D
``(G, K, N)`` weight layout with ``contracting_dims=((1,), (1,))``,
so the previous ``jnp.stack([wi_0, wi_1], axis=-2)`` 4D variant
silently produced NaN). His version is materially cheaper: one
quantize + one GEMM in fwd and one quantize + two GEMMs + one dbias
in bwd, vs our two quantizes + two GEMMs in fwd and two quantizes +
four GEMMs + two dbias in bwd.
Concretely on the fwd path:
* Pack the two weight tensors with ``jnp.concatenate([wi_0, wi_1],
axis=-1)`` so the combined weight has shape
``(num_local_experts, hidden, 2*H_inter)`` -- still 3D, so the
kernel's contract is preserved.
* Run a single ``tex.grouped_gemm`` against the concatenated weight
to get a ``(num_rows, 2*H_inter)`` output, then ``jnp.split(..., 2,
axis=-1)`` to recover ``gate_proj_out`` / ``up_proj_out``.
* Save only ``casted_wi_rhs_trans`` in the residual (single 3D
tensor) instead of the two halves we used before.
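
A minimal sketch of the fwd fusion above, for illustration only. It uses
a dense ``jnp.einsum`` over equal-sized groups as a stand-in for
``tex.grouped_gemm`` (an assumption: the real kernel takes ragged
``(num_rows, hidden)`` input and quantized operands), and the helper
names and sizes are hypothetical:

```python
import jax
import jax.numpy as jnp


def fused_gate_up(x, wi_0, wi_1):
    """One GEMM against the concatenated weight, then split the output."""
    # (G, hidden, H_inter) x 2 -> (G, hidden, 2*H_inter): still 3D.
    wi = jnp.concatenate([wi_0, wi_1], axis=-1)
    # Dense stand-in for the grouped GEMM (assumes equal rows per expert).
    out = jnp.einsum("grk,gkn->grn", x, wi)
    gate_proj_out, up_proj_out = jnp.split(out, 2, axis=-1)
    return gate_proj_out, up_proj_out


def unfused_gate_up(x, wi_0, wi_1):
    """Reference: two separate GEMMs, one per projection."""
    gate = jnp.einsum("grk,gkn->grn", x, wi_0)
    up = jnp.einsum("grk,gkn->grn", x, wi_1)
    return gate, up


# Hypothetical sizes: experts, rows per expert, hidden, H_inter.
G, R, K, H = 4, 8, 16, 32
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k0, (G, R, K))
wi_0 = jax.random.normal(k1, (G, K, H))
wi_1 = jax.random.normal(k2, (G, K, H))

gate_f, up_f = fused_gate_up(x, wi_0, wi_1)
gate_u, up_u = unfused_gate_up(x, wi_0, wi_1)
```

Concatenating along the last axis keeps the contraction over ``hidden``
untouched, so each output column depends on exactly one of the two
weights and the split recovers the two projections.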
The bwd path mirrors that: concatenate the two activation cotangents
into ``d_combined``, quantize once, run the dgrad against the fused
``casted_wi_rhs_trans`` residual to produce ``d_sorted_x``, run the
wgrad against the same casted ``d_combined`` RHS to produce
``d_wi_combined``, and split that into ``d_wi_0`` / ``d_wi_1``.
``tex.grouped_dbias`` likewise runs once on ``d_combined`` and its
output is split.
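
A minimal sketch of the fused bwd described above, checked against
``jax.vjp`` of the un-fused forward. As an assumption, dense
``jnp.einsum`` over equal-sized groups stands in for
``tex.grouped_gemm``, a per-group row sum stands in for
``tex.grouped_dbias``, and quantization is omitted; the helper names
and sizes are hypothetical:

```python
import jax
import jax.numpy as jnp


def unfused_fwd(x, wi_0, wi_1, b_0, b_1):
    """Reference forward: two GEMMs plus per-expert biases."""
    gate = jnp.einsum("grk,gkn->grn", x, wi_0) + b_0[:, None, :]
    up = jnp.einsum("grk,gkn->grn", x, wi_1) + b_1[:, None, :]
    return gate, up


def fused_bwd(x, wi_0, wi_1, d_gate, d_up):
    """One dgrad, one wgrad and one dbias on the concatenated cotangent."""
    wi = jnp.concatenate([wi_0, wi_1], axis=-1)            # (G, K, 2H)
    d_combined = jnp.concatenate([d_gate, d_up], axis=-1)  # (G, R, 2H)
    # dgrad: contract over the fused output axis.
    d_x = jnp.einsum("grn,gkn->grk", d_combined, wi)
    # wgrad: contract over rows, then split back into the two halves.
    d_wi = jnp.einsum("grk,grn->gkn", x, d_combined)
    d_wi_0, d_wi_1 = jnp.split(d_wi, 2, axis=-1)
    # dbias stand-in: sum the cotangent over each expert's rows, then split.
    d_b_0, d_b_1 = jnp.split(d_combined.sum(axis=1), 2, axis=-1)
    return d_x, d_wi_0, d_wi_1, d_b_0, d_b_1


# Hypothetical sizes: experts, rows per expert, hidden, H_inter.
G, R, K, H = 4, 8, 16, 32
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (G, R, K))
wi_0 = jax.random.normal(keys[1], (G, K, H))
wi_1 = jax.random.normal(keys[2], (G, K, H))
d_gate = jax.random.normal(keys[3], (G, R, H))
d_up = jax.random.normal(keys[4], (G, R, H))
b_0 = jnp.zeros((G, H))
b_1 = jnp.zeros((G, H))

fused = fused_bwd(x, wi_0, wi_1, d_gate, d_up)
_, vjp_fn = jax.vjp(unfused_fwd, x, wi_0, wi_1, b_0, b_1)
reference = vjp_fn((d_gate, d_up))
```

The dgrad needs no split because ``d_x`` is the sum of both
projections' contributions, which the contraction over the fused
``2*H_inter`` axis already computes.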
What we keep from our local work (NOT pulled from his branch):
* The ``shard_map``-wrapped FFN body in ``_moe_fwd_rule`` and
``_moe_bwd_rule``. His ``2210702a`` deletes this and calls the FFN
per-shard helpers directly via grouped-GEMM custom partitioning;
that lands in a later integration sweep once
``jberchtold/gmm-custom-partition-rules`` merges to main.
* The per-shard ``jax.lax.cond`` zero-init of ``r_tok`` inside the
fwd ``_body`` -- still required to work around the NCCL EP HT
dispatch zero-init gap for fully-empty-receiver ranks
(``sigmoid-bias-strong`` regression).
* All ``_inspect`` probes in both the fwd and bwd FFN. These stay
on for this round so we can confirm the cherry-pick didn't
regress anything; they will come out in a follow-up cleanup
commit before PR.
Plumbing impact: ``_Ctx`` collapses ``casted_wi_0_rhs_trans`` +
``casted_wi_1_rhs_trans`` into a single ``casted_wi_rhs_trans`` field;
the fwd ``shard_map``'s ``residuals_spec`` and the bwd ``shard_map``'s
``bwd_in_specs`` / ``bwd_in_args`` both lose one slot accordingly.