3737
3838from dataclasses import dataclass
3939from functools import partial
40- import os
4140from typing import Any , Dict , Optional , Tuple , Union
4241import warnings
4342
4746from jax .tree_util import register_pytree_node_class
4847
4948from . import cpp_extensions as tex
50-
51-
52- # Lazy / opt-in: setting TE_MOE_INSPECT=1 in the env wires TE's
53- # inspect_array FFI through the fwd and bwd. When unset (default),
54- # _inspect is the identity, so this has zero runtime cost in normal use.
55- # Dumps land in the process CWD as
56- # my_tensor_gpu{N}_{sanitized_name}.bin + ..._meta.json (one per probe
57- # per rank, since 9cb4cfca threaded `name` through the FFI). Each call
58- # also prints a labelled line ``[gpuN <name>]: ...`` to stdout. We use
59- # the FFI rather than jax.debug.print because jax.debug.print can
60- # deadlock under multi-process (callback ordering across processes is
61- # not synchronised).
62- _INSPECT_ENABLED = os .environ .get ("TE_MOE_INSPECT" , "0" ) == "1"
63- if _INSPECT_ENABLED :
64- from .debug .experimental import inspect_array as _te_inspect_array
65-
66- def _inspect (x : jnp .ndarray , name : str ) -> jnp .ndarray :
67- return _te_inspect_array (x , name )
68-
69- else :
70-
71- def _inspect (x : jnp .ndarray , name : str ) -> jnp .ndarray :
72- del name
73- return x
7449from .quantize import (
7550 TensorUsage ,
7651 noop_quantizer_set ,
@@ -362,16 +337,15 @@ def _ffn_fwd_per_shard(
362337 # An older fused 4D variant built via jnp.stack([wi_0, wi_1], axis=-2)
363338 # put a non-contracting axis in the middle of the RHS, which the
364339 # kernel walked as if it were 3D and read off the end -> NaN.
365- # Confirmed via TE_MOE_INSPECT bisect : the stack-axis variant
340+ # Bisected against a jnp.einsum reference : the stack-axis variant
366341 # produced all-NaN output, while the concat-axis variant (this
367- # path) produces finite outputs matching the jnp.einsum reference.
342+ # path) produces finite outputs matching the reference.
368343 wi_combined = jnp .concatenate ([wi_0 , wi_1 ], axis = - 1 )
369344 wi_combined_bias = (
370345 jnp .concatenate ([wi_0_bias , wi_1_bias ], axis = - 1 ) if wi_0_bias is not None else None
371346 )
372347
373348 q_set = noop_quantizer_set
374- sorted_x = _inspect (sorted_x , "ffn_fwd/sorted_x_in" )
375349 casted_sorted_x = tex .grouped_quantize (sorted_x , q_set .x , local_group_sizes , flatten_axis = - 1 )
376350 casted_wi = tex .grouped_quantize (wi_combined , q_set .kernel , flatten_axis = - 1 )
377351 combined_out = tex .grouped_gemm (
@@ -381,8 +355,6 @@ def _ffn_fwd_per_shard(
381355 bias = wi_combined_bias ,
382356 )
383357 gate_proj_out , up_proj_out = jnp .split (combined_out , 2 , axis = - 1 )
384- gate_proj_out = _inspect (gate_proj_out , "ffn_fwd/gate_proj_out" )
385- up_proj_out = _inspect (up_proj_out , "ffn_fwd/up_proj_out" )
386358 casted_sorted_x_lhs_trans = casted_sorted_x .get_tensor (usage = TensorUsage .LHS_TRANS )
387359 casted_wi_rhs_trans = casted_wi .get_tensor (usage = TensorUsage .RHS_TRANS )
388360
@@ -397,7 +369,6 @@ def _ffn_fwd_per_shard(
397369 act_fn (gate_proj_out .astype (jnp .float32 ))
398370 * up_proj_out .astype (jnp .float32 )
399371 ).astype (sorted_x .dtype )
400- intermediate = _inspect (intermediate , "ffn_fwd/intermediate_after_silu_mul" )
401372
402373 if apply_topk_weights_early :
403374 # Fold the per-token combine weights into the FFN intermediate;
@@ -419,7 +390,6 @@ def _ffn_fwd_per_shard(
419390 contracting_dims = ((1 ,), (1 ,)),
420391 bias = wo_bias ,
421392 )
422- expert_outputs = _inspect (expert_outputs , "ffn_fwd/expert_outputs_after_wo_gemm" )
423393 casted_intermediate_lhs_trans = casted_intermediate .get_tensor (usage = TensorUsage .LHS_TRANS )
424394 casted_wo_rhs_trans = casted_wo .get_tensor (usage = TensorUsage .RHS_TRANS )
425395
@@ -461,15 +431,6 @@ def _ffn_bwd_per_shard(
461431 recv_w_flat = recv_topk_weights_local .reshape (- 1 )
462432 q_set = noop_quantizer_set
463433
464- # FFN bwd sub-step probes (TE_MOE_INSPECT=1 only). Pin down which
465- # bwd sub-step introduces NaN/Inf for the sigmoid-bias-strong
466- # config where some EP ranks have ZERO tokens routed to every
467- # local expert (empty-rank case). For those ranks d_eo_2d should
468- # be entirely zero; if any downstream tensor is non-finite the
469- # offending sub-step is the one that turned a clean zero input
470- # into NaN/Inf.
471- d_eo_2d = _inspect (d_eo_2d , "ffn_bwd/d_eo_2d_in" )
472-
473434 # wo bwd
474435 casted_d_eo = tex .grouped_quantize (d_eo_2d , q_set .dgrad , local_group_sizes , flatten_axis = - 1 )
475436 _casted_d_eo_lhs = casted_d_eo .get_tensor (usage = TensorUsage .LHS )
@@ -479,13 +440,11 @@ def _ffn_bwd_per_shard(
479440 casted_wo_rhs_trans ,
480441 contracting_dims = ((1 ,), (2 ,)),
481442 )
482- d_intermediate = _inspect (d_intermediate , "ffn_bwd/d_intermediate_after_wo_dgrad" )
483443 d_wo = tex .grouped_gemm (
484444 casted_intermediate_lhs_trans ,
485445 _casted_d_eo_rhs ,
486446 contracting_dims = ((0 ,), (0 ,)),
487447 )
488- d_wo = _inspect (d_wo , "ffn_bwd/d_wo_after_wgrad_pre_psum" )
489448 d_wo_bias = tex .grouped_dbias (d_eo_2d , local_group_sizes ) if has_bias else None
490449
491450 act_fn = _convert_to_activation_function (activation_type )
@@ -514,8 +473,6 @@ def _ffn_bwd_per_shard(
514473 d_up_proj_out = (d_int_fp32 * act_gp_fp32 ).astype (up_proj_out .dtype )
515474 (d_gate_proj_fp32 ,) = dact_pullback_fp32 (d_int_fp32 * up_fp32 )
516475 d_gate_proj_out = d_gate_proj_fp32 .astype (gate_proj_out .dtype )
517- d_up_proj_out = _inspect (d_up_proj_out , "ffn_bwd/d_up_proj_out_after_act_bwd" )
518- d_gate_proj_out = _inspect (d_gate_proj_out , "ffn_bwd/d_gate_proj_out_after_act_bwd" )
519476
520477 # wi bwd (fused gate/up via concat). Mirror the fused fwd: pack the
521478 # gate/up cotangents along the trailing axis, run a single
@@ -531,15 +488,12 @@ def _ffn_bwd_per_shard(
531488 casted_wi_rhs_trans ,
532489 contracting_dims = ((1 ,), (2 ,)),
533490 )
534- d_sorted_x = _inspect (d_sorted_x , "ffn_bwd/d_sorted_x_after_wi_dgrad_sum" )
535491 d_wi_combined = tex .grouped_gemm (
536492 casted_sorted_x_lhs_trans ,
537493 casted_d_combined .get_tensor (usage = TensorUsage .RHS ),
538494 contracting_dims = ((0 ,), (0 ,)),
539495 )
540496 d_wi_0 , d_wi_1 = jnp .split (d_wi_combined , 2 , axis = - 1 )
541- d_wi_0 = _inspect (d_wi_0 , "ffn_bwd/d_wi_0_after_wgrad_pre_psum" )
542- d_wi_1 = _inspect (d_wi_1 , "ffn_bwd/d_wi_1_after_wgrad_pre_psum" )
543497 if has_bias :
544498 d_wi_combined_bias = tex .grouped_dbias (d_combined , local_group_sizes )
545499 d_wi_0_bias , d_wi_1_bias = jnp .split (d_wi_combined_bias , 2 , axis = - 1 )
@@ -691,7 +645,6 @@ def _moe_fwd_rule(
691645 expert_bias = eb_arg ,
692646 compute_aux_scores = False ,
693647 )
694- sparse_probs = _inspect (sparse_probs , "fwd/sparse_probs_after_fused_topk" )
695648 # Sigmoid + K>1 normalises as `weights / (weights.sum + 1e-20)`; for
696649 # tokens whose top-K sigmoid scores all underflow at bf16/fp32 the
697650 # output is NaN at the selected positions. Those NaNs ride
@@ -702,7 +655,6 @@ def _moe_fwd_rule(
702655 # are already zero (routing_map is False there); only the rare
703656 # underflow path emits NaN.
704657 sparse_probs = jnp .where (jnp .isnan (sparse_probs ), 0 , sparse_probs ).astype (dtype )
705- sparse_probs = _inspect (sparse_probs , "fwd/sparse_probs_after_sanitize" )
706658
707659 # ---------------- Aux loss (global view, replicated) ----------------
708660 # ``fused_moe_aux_loss_fwd`` sums probs and tokens_per_expert across
@@ -771,7 +723,6 @@ def _moe_fwd_rule(
771723 topk_w_3d = jax .lax .with_sharding_constraint (
772724 topk_w_3d , NamedSharding (mesh , ep3_spec )
773725 )
774- topk_w_3d = _inspect (topk_w_3d , "fwd/topk_w_3d_before_dispatch" )
775726
776727 # ---------------- TE EP dispatch (global view) ----------------
777728 handle = _get_or_make_ep_handle (
@@ -785,8 +736,6 @@ def _moe_fwd_rule(
785736 recv_topk_weights = jax .lax .with_sharding_constraint (
786737 recv_topk_weights , NamedSharding (mesh , ep2_spec )
787738 )
788- recv_tokens = _inspect (recv_tokens , "fwd/recv_tokens_after_dispatch" )
789- recv_topk_weights = _inspect (recv_topk_weights , "fwd/recv_topk_weights_after_dispatch" )
790739
791740 # ---------------- FFN (per-shard via shard_map) ----------------
792741 has_bias = wi_0_bias is not None
@@ -822,26 +771,24 @@ def _body(*args):
822771 (r_tok , r_w , w0 , w1 , w_o ) = args
823772 w0b = w1b = wob = None
824773 # Per-rank conditional zero-init of r_tok. Works around a
825- # narrowly-scoped tex.ep_dispatch_fwd contract gap: the dispatch
826- # kernel zero-initialises the recv buffer correctly on ranks
827- # that receive at least one token, but leaves uninitialised
828- # memory on fully-empty-receiver ranks. ``r_w`` (the dispatch's
829- # own written-or-not indicator: 0 at padded slots, non-zero at
830- # real-routed slots) gives us a per-shard predicate for free.
831- # ``jax.lax.cond`` only executes the selected branch, so loaded
832- # ranks pay nothing at runtime; only empty ranks do the
833- # zero-fill. See INTEGRATION_DESIGN.md "FOLLOW-UP" for the bug
834- # surface details. TODO: remove once tex.ep_dispatch_fwd is
835- # fixed upstream (or once we adopt tex.tokens_per_expert as
836- # local_group_sizes to bypass padded slots entirely).
774+ # narrowly-scoped tex.ep_dispatch_fwd contract gap: the NCCL EP
775+ # HT dispatch kernel zero-initialises the recv buffer correctly
776+ # on ranks that receive at least one token, but leaves
777+ # uninitialised memory on fully-empty-receiver ranks. ``r_w``
778+ # (the dispatch's own written-or-not indicator: 0 at padded
779+ # slots, non-zero at real-routed slots) gives us a per-shard
780+ # predicate for free. ``jax.lax.cond`` only executes the
781+ # selected branch, so loaded ranks pay nothing at runtime;
782+ # only empty ranks do the zero-fill.
783+ # TODO: remove once tex.ep_dispatch_fwd zero-inits empty-rank
784+ # recv buffers upstream.
837785 rank_has_tokens = jnp .any (r_w != 0 )
838786 r_tok = jax .lax .cond (
839787 rank_has_tokens ,
840788 lambda x : x ,
841789 lambda x : jnp .zeros_like (x ),
842790 r_tok ,
843791 )
844- r_tok = _inspect (r_tok , "fwd/recv_tokens_after_dispatch_sanitize" )
845792 return _ffn_fwd_per_shard (
846793 r_tok ,
847794 r_w ,
@@ -867,7 +814,6 @@ def _body(*args):
867814 expert_outputs = jax .lax .with_sharding_constraint (
868815 expert_outputs , NamedSharding (mesh , ep3_spec )
869816 )
870- expert_outputs = _inspect (expert_outputs , "fwd/expert_outputs_before_combine" )
871817
872818 # ---------------- TE EP combine (global view) ----------------
873819 out_partition_spec = (batch_pspec_axis , None , None )
@@ -884,11 +830,10 @@ def _body(*args):
884830 # IEEE 754: NaN * 0 = NaN, so a multiplicative mask cannot kill
885831 # the NaNs ep_dispatch_fwd leaves at padded slots of recv_tokens
886832 # (they ride through the FFN into expert_outputs at the same
887- # padded positions). Use jnp.where to overwrite padded positions
888- # with a literal 0 before combine — confirmed via TE_MOE_INSPECT
889- # that mean=NaN on expert_outputs[padded] can propagate into the
890- # combine output when the kernel's read pattern overlaps the
891- # padded region.
833+ # padded positions): mean=NaN on expert_outputs[padded] then
834+ # propagates into the combine output when the kernel's read
835+ # pattern overlaps the padded region. Use jnp.where to overwrite
836+ # padded positions with a literal 0 before combine.
892837 w = recv_topk_weights [..., None ].astype (expert_outputs .dtype )
893838 mask_bool = (recv_topk_weights != 0 )[..., None ]
894839 weighted = jnp .where (mask_bool , expert_outputs * w , jnp .zeros_like (expert_outputs ))
@@ -899,7 +844,6 @@ def _body(*args):
899844 num_local_tokens = (B , S ),
900845 out_partition_spec = out_partition_spec ,
901846 )
902- output = _inspect (output , "fwd/output_after_combine" )
903847
904848 (
905849 casted_sorted_x_lhs_trans ,
@@ -995,12 +939,10 @@ def _moe_bwd_rule(
995939
996940 # ---------------- Combine bwd (global view) ----------------
997941 d_output = jax .lax .with_sharding_constraint (d_output , NamedSharding (mesh , ep3_spec ))
998- d_output = _inspect (d_output , "bwd/d_output_into_combine_bwd" )
999942 grad_pre_combine = tex .ep_combine_bwd (ctx .handle , ctx .handle_mem , d_output , recv_pr )
1000943 grad_pre_combine = jax .lax .with_sharding_constraint (
1001944 grad_pre_combine , NamedSharding (mesh , ep3_spec )
1002945 )
1003- grad_pre_combine = _inspect (grad_pre_combine , "bwd/grad_pre_combine_after_combine_bwd" )
1004946
1005947 if apply_topk_weights_early :
1006948 # combine_fwd consumed already-weighted expert_outputs; the recv_w
@@ -1020,23 +962,19 @@ def _moe_bwd_rule(
1020962 # propagates through grad_pre_combine * w * mask into d_expert_outputs
1021963 # and then into every downstream gradient (gate_kernel ends up
1022964 # all-NaN). Sanitize once here.
1023- recv_w_raw = _inspect (
1024- ctx .recv_topk_weights , "bwd/ctx.recv_topk_weights_before_sanitize"
1025- )
1026- recv_w_clean = jnp .where (jnp .isnan (recv_w_raw ), 0 , recv_w_raw )
965+ recv_w_clean = jnp .where (jnp .isnan (ctx .recv_topk_weights ), 0 , ctx .recv_topk_weights )
1027966 # IEEE 754: NaN * 0 = NaN, so multiplying grad_pre_combine by a
1028967 # 0/1 mask cannot kill the NaNs tex.ep_combine_bwd leaves at
1029- # padded slots of grad_pre_combine. Confirmed via TE_MOE_INSPECT:
1030- # ctx.recv_topk_weights is clean ( after the recv_w_clean
1031- # sanitize above), but grad_pre_combine[padded] is NaN, so
1032- # grad_pre_combine * w * mask = NaN. Use jnp.where to overwrite
1033- # padded positions with literal 0 instead.
968+ # padded slots of grad_pre_combine: ctx.recv_topk_weights is
969+ # clean after the sanitize above, but grad_pre_combine[padded]
970+ # is still NaN, so grad_pre_combine * w * mask = NaN. Use
971+ # jnp.where to overwrite padded positions with literal 0
972+ # instead.
1034973 w = recv_w_clean [..., None ].astype (grad_pre_combine .dtype )
1035974 mask_bool = (recv_w_clean != 0 )[..., None ]
1036975 d_expert_outputs = jnp .where (
1037976 mask_bool , grad_pre_combine * w , jnp .zeros_like (grad_pre_combine )
1038977 )
1039- d_expert_outputs = _inspect (d_expert_outputs , "bwd/d_expert_outputs_after_w_mask_split" )
1040978 # Same masking strategy for the cotangent on recv_topk_weights:
1041979 # grad_pre_combine has NaN at padded slots and ctx.expert_outputs
1042980 # may too, so the per-element product must be jnp.where'd before
@@ -1113,14 +1051,6 @@ def _bwd_body(*args):
11131051 d_wi_0_bias = jax .lax .psum (d_wi_0_bias , axis_name = dp )
11141052 d_wi_1_bias = jax .lax .psum (d_wi_1_bias , axis_name = dp )
11151053 d_wo_bias = jax .lax .psum (d_wo_bias , axis_name = dp )
1116- # Post-psum probes (TE_MOE_INSPECT=1 only). The
1117- # sigmoid-bias-strong test asserts on the final d_wo / d_wi_*
1118- # after the DP psum. If pre-psum probes (above, in
1119- # _ffn_bwd_per_shard) are clean but post-psum is NaN, the DP
1120- # psum across an empty-rank shard is the offender.
1121- d_wo = _inspect (d_wo , "ffn_bwd/d_wo_post_psum" )
1122- d_wi_0 = _inspect (d_wi_0 , "ffn_bwd/d_wi_0_post_psum" )
1123- d_wi_1 = _inspect (d_wi_1 , "ffn_bwd/d_wi_1_post_psum" )
11241054 return (
11251055 d_sorted_x_3d ,
11261056 d_recv_w_3d ,
@@ -1176,7 +1106,6 @@ def _bwd_body(*args):
11761106 jnp .arange (ctx .routing_map .shape [0 ])[:, None ], selected_experts
11771107 ].set (d_topk_w_flat )
11781108
1179- d_sparse_probs = _inspect (d_sparse_probs , "bwd/d_sparse_probs_before_topk_bwd" )
11801109 d_logits_2d = tex .fused_topk_with_score_function_bwd (
11811110 ctx .routing_map ,
11821111 ctx .saved_scores ,
@@ -1187,7 +1116,6 @@ def _bwd_body(*args):
11871116 score_function = score_function ,
11881117 compute_aux_scores = False ,
11891118 )
1190- d_logits_2d = _inspect (d_logits_2d , "bwd/d_logits_2d_after_topk_bwd" )
11911119
11921120 # ---------------- Aux loss bwd (global view, replicated) ----------------
11931121 # Reverse the fwd's all-gather/aux pipeline: aux_loss_bwd produces
@@ -1225,7 +1153,6 @@ def _bwd_body(*args):
12251153 gate_kernel_cast = ctx .gate_kernel .astype (ctx .x .dtype )
12261154 d_x_from_gate = jnp .einsum ("bse,he->bsh" , d_gate_logits , gate_kernel_cast )
12271155 d_gate_kernel = jnp .einsum ("bsh,bse->he" , ctx .x , d_gate_logits ).astype (ctx .gate_kernel .dtype )
1228- d_gate_kernel = _inspect (d_gate_kernel , "bwd/d_gate_kernel_final" )
12291156 d_x = d_x_from_gate + d_x_from_dispatch
12301157
12311158 # Pin output grads to the declared logical axes so downstream
0 commit comments