Skip to content

Commit a260c4b

Browse files
committed
[JAX] MoE: remove debug knobs (inspect probes + TE_EP_MOE_K hook)
Strip everything that was carried only for the EP MoE bring-up debugging round and is not needed in normal use: * ``moe.py``: remove the ``TE_MOE_INSPECT`` env toggle, the ``_inspect`` shim, and every ``_inspect(...)`` callsite scattered through both ``_moe_fwd_rule`` / ``_moe_bwd_rule`` and the ``_ffn_*_per_shard`` helpers. Comments that referenced ``TE_MOE_INSPECT`` as the source of a particular conclusion are reworded to keep the technical rationale (silu fp32 promotion, NaN-at-padded-slots, concat-vs-stack bisection, dispatch zero-init gap) without naming the dev-only probe machinery. * ``run_te_ep_moe.sh``: remove the ``TE_EP_MOE_K`` pytest ``-k`` forwarding hook; the suite always runs full now. Preserved deliberately: * The per-rank ``jax.lax.cond`` zero-init of ``r_tok`` inside the fwd FFN ``_body`` -- still required around the NCCL EP HT dispatch zero-init gap (see in-line comment + TODO). * The two ``jnp.where`` NaN-overwrite blocks around the combine fwd/bwd, the ``sparse_probs`` underflow sanitize, and the fp32 silu+multiply promotion in fwd/bwd -- these are correctness fixes, not debug instrumentation. After this commit the MoE block has no dev-only env vars and no debug callbacks; it is ready for upstream PR.
1 parent 53d0ecd commit a260c4b

2 files changed

Lines changed: 23 additions & 103 deletions

File tree

tests/jax/run_te_ep_moe.sh

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,6 @@ for i in $(seq 0 $((NUM_GPUS - 1))); do
7070
--num-process="$NUM_GPUS"
7171
--process-id="$i"
7272
)
73-
# Optional pytest -k selector for scoping the run (e.g. to a single
74-
# failing test). Example:
75-
# TE_EP_MOE_K="test_backward and sigmoid-bias-strong" \
76-
# bash tests/jax/run_te_ep_moe.sh
77-
if [ -n "${TE_EP_MOE_K:-}" ]; then
78-
PYTEST_CMD+=( -k "$TE_EP_MOE_K" )
79-
fi
8073
if [ "$i" -eq 0 ]; then
8174
echo "=== Live output from process 0 ==="
8275
"${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" &

transformer_engine/jax/moe.py

Lines changed: 23 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737

3838
from dataclasses import dataclass
3939
from functools import partial
40-
import os
4140
from typing import Any, Dict, Optional, Tuple, Union
4241
import warnings
4342

@@ -47,30 +46,6 @@
4746
from jax.tree_util import register_pytree_node_class
4847

4948
from . import cpp_extensions as tex
50-
51-
52-
# Lazy / opt-in: setting TE_MOE_INSPECT=1 in the env wires TE's
53-
# inspect_array FFI through the fwd and bwd. When unset (default),
54-
# _inspect is the identity, so this has zero runtime cost in normal use.
55-
# Dumps land in the process CWD as
56-
# my_tensor_gpu{N}_{sanitized_name}.bin + ..._meta.json (one per probe
57-
# per rank, since 9cb4cfca threaded `name` through the FFI). Each call
58-
# also prints a labelled line ``[gpuN <name>]: ...`` to stdout. We use
59-
# the FFI rather than jax.debug.print because jax.debug.print can
60-
# deadlock under multi-process (callback ordering across processes is
61-
# not synchronised).
62-
_INSPECT_ENABLED = os.environ.get("TE_MOE_INSPECT", "0") == "1"
63-
if _INSPECT_ENABLED:
64-
from .debug.experimental import inspect_array as _te_inspect_array
65-
66-
def _inspect(x: jnp.ndarray, name: str) -> jnp.ndarray:
67-
return _te_inspect_array(x, name)
68-
69-
else:
70-
71-
def _inspect(x: jnp.ndarray, name: str) -> jnp.ndarray:
72-
del name
73-
return x
7449
from .quantize import (
7550
TensorUsage,
7651
noop_quantizer_set,
@@ -362,16 +337,15 @@ def _ffn_fwd_per_shard(
362337
# An older fused 4D variant built via jnp.stack([wi_0, wi_1], axis=-2)
363338
# put a non-contracting axis in the middle of the RHS, which the
364339
# kernel walked as if it were 3D and read off the end -> NaN.
365-
# Confirmed via TE_MOE_INSPECT bisect: the stack-axis variant
340+
# Bisected against a jnp.einsum reference: the stack-axis variant
366341
# produced all-NaN output, while the concat-axis variant (this
367-
# path) produces finite outputs matching the jnp.einsum reference.
342+
# path) produces finite outputs matching the reference.
368343
wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1)
369344
wi_combined_bias = (
370345
jnp.concatenate([wi_0_bias, wi_1_bias], axis=-1) if wi_0_bias is not None else None
371346
)
372347

373348
q_set = noop_quantizer_set
374-
sorted_x = _inspect(sorted_x, "ffn_fwd/sorted_x_in")
375349
casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, local_group_sizes, flatten_axis=-1)
376350
casted_wi = tex.grouped_quantize(wi_combined, q_set.kernel, flatten_axis=-1)
377351
combined_out = tex.grouped_gemm(
@@ -381,8 +355,6 @@ def _ffn_fwd_per_shard(
381355
bias=wi_combined_bias,
382356
)
383357
gate_proj_out, up_proj_out = jnp.split(combined_out, 2, axis=-1)
384-
gate_proj_out = _inspect(gate_proj_out, "ffn_fwd/gate_proj_out")
385-
up_proj_out = _inspect(up_proj_out, "ffn_fwd/up_proj_out")
386358
casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
387359
casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS)
388360

@@ -397,7 +369,6 @@ def _ffn_fwd_per_shard(
397369
act_fn(gate_proj_out.astype(jnp.float32))
398370
* up_proj_out.astype(jnp.float32)
399371
).astype(sorted_x.dtype)
400-
intermediate = _inspect(intermediate, "ffn_fwd/intermediate_after_silu_mul")
401372

402373
if apply_topk_weights_early:
403374
# Fold the per-token combine weights into the FFN intermediate;
@@ -419,7 +390,6 @@ def _ffn_fwd_per_shard(
419390
contracting_dims=((1,), (1,)),
420391
bias=wo_bias,
421392
)
422-
expert_outputs = _inspect(expert_outputs, "ffn_fwd/expert_outputs_after_wo_gemm")
423393
casted_intermediate_lhs_trans = casted_intermediate.get_tensor(usage=TensorUsage.LHS_TRANS)
424394
casted_wo_rhs_trans = casted_wo.get_tensor(usage=TensorUsage.RHS_TRANS)
425395

@@ -461,15 +431,6 @@ def _ffn_bwd_per_shard(
461431
recv_w_flat = recv_topk_weights_local.reshape(-1)
462432
q_set = noop_quantizer_set
463433

464-
# FFN bwd sub-step probes (TE_MOE_INSPECT=1 only). Pin down which
465-
# bwd sub-step introduces NaN/Inf for the sigmoid-bias-strong
466-
# config where some EP ranks have ZERO tokens routed to every
467-
# local expert (empty-rank case). For those ranks d_eo_2d should
468-
# be entirely zero; if any downstream tensor is non-finite the
469-
# offending sub-step is the one that turned a clean zero input
470-
# into NaN/Inf.
471-
d_eo_2d = _inspect(d_eo_2d, "ffn_bwd/d_eo_2d_in")
472-
473434
# wo bwd
474435
casted_d_eo = tex.grouped_quantize(d_eo_2d, q_set.dgrad, local_group_sizes, flatten_axis=-1)
475436
_casted_d_eo_lhs = casted_d_eo.get_tensor(usage=TensorUsage.LHS)
@@ -479,13 +440,11 @@ def _ffn_bwd_per_shard(
479440
casted_wo_rhs_trans,
480441
contracting_dims=((1,), (2,)),
481442
)
482-
d_intermediate = _inspect(d_intermediate, "ffn_bwd/d_intermediate_after_wo_dgrad")
483443
d_wo = tex.grouped_gemm(
484444
casted_intermediate_lhs_trans,
485445
_casted_d_eo_rhs,
486446
contracting_dims=((0,), (0,)),
487447
)
488-
d_wo = _inspect(d_wo, "ffn_bwd/d_wo_after_wgrad_pre_psum")
489448
d_wo_bias = tex.grouped_dbias(d_eo_2d, local_group_sizes) if has_bias else None
490449

491450
act_fn = _convert_to_activation_function(activation_type)
@@ -514,8 +473,6 @@ def _ffn_bwd_per_shard(
514473
d_up_proj_out = (d_int_fp32 * act_gp_fp32).astype(up_proj_out.dtype)
515474
(d_gate_proj_fp32,) = dact_pullback_fp32(d_int_fp32 * up_fp32)
516475
d_gate_proj_out = d_gate_proj_fp32.astype(gate_proj_out.dtype)
517-
d_up_proj_out = _inspect(d_up_proj_out, "ffn_bwd/d_up_proj_out_after_act_bwd")
518-
d_gate_proj_out = _inspect(d_gate_proj_out, "ffn_bwd/d_gate_proj_out_after_act_bwd")
519476

520477
# wi bwd (fused gate/up via concat). Mirror the fused fwd: pack the
521478
# gate/up cotangents along the trailing axis, run a single
@@ -531,15 +488,12 @@ def _ffn_bwd_per_shard(
531488
casted_wi_rhs_trans,
532489
contracting_dims=((1,), (2,)),
533490
)
534-
d_sorted_x = _inspect(d_sorted_x, "ffn_bwd/d_sorted_x_after_wi_dgrad_sum")
535491
d_wi_combined = tex.grouped_gemm(
536492
casted_sorted_x_lhs_trans,
537493
casted_d_combined.get_tensor(usage=TensorUsage.RHS),
538494
contracting_dims=((0,), (0,)),
539495
)
540496
d_wi_0, d_wi_1 = jnp.split(d_wi_combined, 2, axis=-1)
541-
d_wi_0 = _inspect(d_wi_0, "ffn_bwd/d_wi_0_after_wgrad_pre_psum")
542-
d_wi_1 = _inspect(d_wi_1, "ffn_bwd/d_wi_1_after_wgrad_pre_psum")
543497
if has_bias:
544498
d_wi_combined_bias = tex.grouped_dbias(d_combined, local_group_sizes)
545499
d_wi_0_bias, d_wi_1_bias = jnp.split(d_wi_combined_bias, 2, axis=-1)
@@ -691,7 +645,6 @@ def _moe_fwd_rule(
691645
expert_bias=eb_arg,
692646
compute_aux_scores=False,
693647
)
694-
sparse_probs = _inspect(sparse_probs, "fwd/sparse_probs_after_fused_topk")
695648
# Sigmoid + K>1 normalises as `weights / (weights.sum + 1e-20)`; for
696649
# tokens whose top-K sigmoid scores all underflow at bf16/fp32 the
697650
# output is NaN at the selected positions. Those NaNs ride
@@ -702,7 +655,6 @@ def _moe_fwd_rule(
702655
# are already zero (routing_map is False there); only the rare
703656
# underflow path emits NaN.
704657
sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype)
705-
sparse_probs = _inspect(sparse_probs, "fwd/sparse_probs_after_sanitize")
706658

707659
# ---------------- Aux loss (global view, replicated) ----------------
708660
# ``fused_moe_aux_loss_fwd`` sums probs and tokens_per_expert across
@@ -771,7 +723,6 @@ def _moe_fwd_rule(
771723
topk_w_3d = jax.lax.with_sharding_constraint(
772724
topk_w_3d, NamedSharding(mesh, ep3_spec)
773725
)
774-
topk_w_3d = _inspect(topk_w_3d, "fwd/topk_w_3d_before_dispatch")
775726

776727
# ---------------- TE EP dispatch (global view) ----------------
777728
handle = _get_or_make_ep_handle(
@@ -785,8 +736,6 @@ def _moe_fwd_rule(
785736
recv_topk_weights = jax.lax.with_sharding_constraint(
786737
recv_topk_weights, NamedSharding(mesh, ep2_spec)
787738
)
788-
recv_tokens = _inspect(recv_tokens, "fwd/recv_tokens_after_dispatch")
789-
recv_topk_weights = _inspect(recv_topk_weights, "fwd/recv_topk_weights_after_dispatch")
790739

791740
# ---------------- FFN (per-shard via shard_map) ----------------
792741
has_bias = wi_0_bias is not None
@@ -822,26 +771,24 @@ def _body(*args):
822771
(r_tok, r_w, w0, w1, w_o) = args
823772
w0b = w1b = wob = None
824773
# Per-rank conditional zero-init of r_tok. Works around a
825-
# narrowly-scoped tex.ep_dispatch_fwd contract gap: the dispatch
826-
# kernel zero-initialises the recv buffer correctly on ranks
827-
# that receive at least one token, but leaves uninitialised
828-
# memory on fully-empty-receiver ranks. ``r_w`` (the dispatch's
829-
# own written-or-not indicator: 0 at padded slots, non-zero at
830-
# real-routed slots) gives us a per-shard predicate for free.
831-
# ``jax.lax.cond`` only executes the selected branch, so loaded
832-
# ranks pay nothing at runtime; only empty ranks do the
833-
# zero-fill. See INTEGRATION_DESIGN.md "FOLLOW-UP" for the bug
834-
# surface details. TODO: remove once tex.ep_dispatch_fwd is
835-
# fixed upstream (or once we adopt tex.tokens_per_expert as
836-
# local_group_sizes to bypass padded slots entirely).
774+
# narrowly-scoped tex.ep_dispatch_fwd contract gap: the NCCL EP
775+
# HT dispatch kernel zero-initialises the recv buffer correctly
776+
# on ranks that receive at least one token, but leaves
777+
# uninitialised memory on fully-empty-receiver ranks. ``r_w``
778+
# (the dispatch's own written-or-not indicator: 0 at padded
779+
# slots, non-zero at real-routed slots) gives us a per-shard
780+
# predicate for free. ``jax.lax.cond`` only executes the
781+
# selected branch, so loaded ranks pay nothing at runtime;
782+
# only empty ranks do the zero-fill.
783+
# TODO: remove once tex.ep_dispatch_fwd zero-inits empty-rank
784+
# recv buffers upstream.
837785
rank_has_tokens = jnp.any(r_w != 0)
838786
r_tok = jax.lax.cond(
839787
rank_has_tokens,
840788
lambda x: x,
841789
lambda x: jnp.zeros_like(x),
842790
r_tok,
843791
)
844-
r_tok = _inspect(r_tok, "fwd/recv_tokens_after_dispatch_sanitize")
845792
return _ffn_fwd_per_shard(
846793
r_tok,
847794
r_w,
@@ -867,7 +814,6 @@ def _body(*args):
867814
expert_outputs = jax.lax.with_sharding_constraint(
868815
expert_outputs, NamedSharding(mesh, ep3_spec)
869816
)
870-
expert_outputs = _inspect(expert_outputs, "fwd/expert_outputs_before_combine")
871817

872818
# ---------------- TE EP combine (global view) ----------------
873819
out_partition_spec = (batch_pspec_axis, None, None)
@@ -884,11 +830,10 @@ def _body(*args):
884830
# IEEE 754: NaN * 0 = NaN, so a multiplicative mask cannot kill
885831
# the NaNs ep_dispatch_fwd leaves at padded slots of recv_tokens
886832
# (they ride through the FFN into expert_outputs at the same
887-
# padded positions). Use jnp.where to overwrite padded positions
888-
# with a literal 0 before combine — confirmed via TE_MOE_INSPECT
889-
# that mean=NaN on expert_outputs[padded] can propagate into the
890-
# combine output when the kernel's read pattern overlaps the
891-
# padded region.
833+
# padded positions): mean=NaN on expert_outputs[padded] then
834+
# propagates into the combine output when the kernel's read
835+
# pattern overlaps the padded region. Use jnp.where to overwrite
836+
# padded positions with a literal 0 before combine.
892837
w = recv_topk_weights[..., None].astype(expert_outputs.dtype)
893838
mask_bool = (recv_topk_weights != 0)[..., None]
894839
weighted = jnp.where(mask_bool, expert_outputs * w, jnp.zeros_like(expert_outputs))
@@ -899,7 +844,6 @@ def _body(*args):
899844
num_local_tokens=(B, S),
900845
out_partition_spec=out_partition_spec,
901846
)
902-
output = _inspect(output, "fwd/output_after_combine")
903847

904848
(
905849
casted_sorted_x_lhs_trans,
@@ -995,12 +939,10 @@ def _moe_bwd_rule(
995939

996940
# ---------------- Combine bwd (global view) ----------------
997941
d_output = jax.lax.with_sharding_constraint(d_output, NamedSharding(mesh, ep3_spec))
998-
d_output = _inspect(d_output, "bwd/d_output_into_combine_bwd")
999942
grad_pre_combine = tex.ep_combine_bwd(ctx.handle, ctx.handle_mem, d_output, recv_pr)
1000943
grad_pre_combine = jax.lax.with_sharding_constraint(
1001944
grad_pre_combine, NamedSharding(mesh, ep3_spec)
1002945
)
1003-
grad_pre_combine = _inspect(grad_pre_combine, "bwd/grad_pre_combine_after_combine_bwd")
1004946

1005947
if apply_topk_weights_early:
1006948
# combine_fwd consumed already-weighted expert_outputs; the recv_w
@@ -1020,23 +962,19 @@ def _moe_bwd_rule(
1020962
# propagates through grad_pre_combine * w * mask into d_expert_outputs
1021963
# and then into every downstream gradient (gate_kernel ends up
1022964
# all-NaN). Sanitize once here.
1023-
recv_w_raw = _inspect(
1024-
ctx.recv_topk_weights, "bwd/ctx.recv_topk_weights_before_sanitize"
1025-
)
1026-
recv_w_clean = jnp.where(jnp.isnan(recv_w_raw), 0, recv_w_raw)
965+
recv_w_clean = jnp.where(jnp.isnan(ctx.recv_topk_weights), 0, ctx.recv_topk_weights)
1027966
# IEEE 754: NaN * 0 = NaN, so multiplying grad_pre_combine by a
1028967
# 0/1 mask cannot kill the NaNs tex.ep_combine_bwd leaves at
1029-
# padded slots of grad_pre_combine. Confirmed via TE_MOE_INSPECT:
1030-
# ctx.recv_topk_weights is clean (after the recv_w_clean
1031-
# sanitize above), but grad_pre_combine[padded] is NaN, so
1032-
# grad_pre_combine * w * mask = NaN. Use jnp.where to overwrite
1033-
# padded positions with literal 0 instead.
968+
# padded slots of grad_pre_combine: ctx.recv_topk_weights is
969+
# clean after the sanitize above, but grad_pre_combine[padded]
970+
# is still NaN, so grad_pre_combine * w * mask = NaN. Use
971+
# jnp.where to overwrite padded positions with literal 0
972+
# instead.
1034973
w = recv_w_clean[..., None].astype(grad_pre_combine.dtype)
1035974
mask_bool = (recv_w_clean != 0)[..., None]
1036975
d_expert_outputs = jnp.where(
1037976
mask_bool, grad_pre_combine * w, jnp.zeros_like(grad_pre_combine)
1038977
)
1039-
d_expert_outputs = _inspect(d_expert_outputs, "bwd/d_expert_outputs_after_w_mask_split")
1040978
# Same masking strategy for the cotangent on recv_topk_weights:
1041979
# grad_pre_combine has NaN at padded slots and ctx.expert_outputs
1042980
# may too, so the per-element product must be jnp.where'd before
@@ -1113,14 +1051,6 @@ def _bwd_body(*args):
11131051
d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=dp)
11141052
d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=dp)
11151053
d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=dp)
1116-
# Post-psum probes (TE_MOE_INSPECT=1 only). The
1117-
# sigmoid-bias-strong test asserts on the final d_wo / d_wi_*
1118-
# after the DP psum. If pre-psum probes (above, in
1119-
# _ffn_bwd_per_shard) are clean but post-psum is NaN, the DP
1120-
# psum across an empty-rank shard is the offender.
1121-
d_wo = _inspect(d_wo, "ffn_bwd/d_wo_post_psum")
1122-
d_wi_0 = _inspect(d_wi_0, "ffn_bwd/d_wi_0_post_psum")
1123-
d_wi_1 = _inspect(d_wi_1, "ffn_bwd/d_wi_1_post_psum")
11241054
return (
11251055
d_sorted_x_3d,
11261056
d_recv_w_3d,
@@ -1176,7 +1106,6 @@ def _bwd_body(*args):
11761106
jnp.arange(ctx.routing_map.shape[0])[:, None], selected_experts
11771107
].set(d_topk_w_flat)
11781108

1179-
d_sparse_probs = _inspect(d_sparse_probs, "bwd/d_sparse_probs_before_topk_bwd")
11801109
d_logits_2d = tex.fused_topk_with_score_function_bwd(
11811110
ctx.routing_map,
11821111
ctx.saved_scores,
@@ -1187,7 +1116,6 @@ def _bwd_body(*args):
11871116
score_function=score_function,
11881117
compute_aux_scores=False,
11891118
)
1190-
d_logits_2d = _inspect(d_logits_2d, "bwd/d_logits_2d_after_topk_bwd")
11911119

11921120
# ---------------- Aux loss bwd (global view, replicated) ----------------
11931121
# Reverse the fwd's all-gather/aux pipeline: aux_loss_bwd produces
@@ -1225,7 +1153,6 @@ def _bwd_body(*args):
12251153
gate_kernel_cast = ctx.gate_kernel.astype(ctx.x.dtype)
12261154
d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast)
12271155
d_gate_kernel = jnp.einsum("bsh,bse->he", ctx.x, d_gate_logits).astype(ctx.gate_kernel.dtype)
1228-
d_gate_kernel = _inspect(d_gate_kernel, "bwd/d_gate_kernel_final")
12291156
d_x = d_x_from_gate + d_x_from_dispatch
12301157

12311158
# Pin output grads to the declared logical axes so downstream

0 commit comments

Comments
 (0)