Skip to content

Commit bd14fe6

Browse files
committed
[JAX] Resync onto upstream PR #3036, restore TE-EP-only MoE block
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours: * transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg. * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone). * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept). Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync. EOF )
1 parent c34771d commit bd14fe6

3 files changed

Lines changed: 1973 additions & 1841 deletions

File tree

tests/jax/run_te_ep_moe.sh

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#!/usr/bin/env bash
2+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# See LICENSE for license information.
5+
#
6+
# Multiprocess (one-GPU-per-process) launcher for the TE-EP MoE custom_vjp
7+
# test suite. Forks one pytest invocation per visible GPU, passing each
8+
# its own --num-process=N --process-id=i, and waits for all of them. Each
9+
# child calls jax.distributed.initialize(..., local_device_ids=process_id)
10+
# so each Python process only sees its one GPU as a local device and the
11+
# participating processes form a global (ep, fsdp) mesh.
12+
13+
set -euo pipefail
14+
15+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
16+
TE_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
17+
TEST_FILE="$TE_ROOT/tests/jax/test_te_ep_moe.py"
18+
PYTEST_INI="$TE_ROOT/tests/jax/pytest.ini"
19+
20+
NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L | wc -l)}"
21+
if [ "$NUM_GPUS" -lt 4 ]; then
22+
echo "[run_te_ep_moe.sh] need >=4 GPUs (got $NUM_GPUS); aborting" >&2
23+
exit 1
24+
fi
25+
26+
export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}"
27+
export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.5}"
28+
export TE_EP_MOE_COORDINATOR_ADDRESS="${TE_EP_MOE_COORDINATOR_ADDRESS:-127.0.0.1:13457}"
29+
30+
echo "============================================================"
31+
echo "TE-EP MoE MULTIPROCESS test (one process per GPU, ${NUM_GPUS} GPUs)"
32+
echo " test file : $TEST_FILE"
33+
echo " coordinator : $TE_EP_MOE_COORDINATOR_ADDRESS"
34+
echo " XLA_PYTHON_CLIENT_PREALLOCATE: $XLA_PYTHON_CLIENT_PREALLOCATE"
35+
echo " XLA_PYTHON_CLIENT_MEM_FRACTION: $XLA_PYTHON_CLIENT_MEM_FRACTION"
36+
echo "============================================================"
37+
38+
if [ -n "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then
39+
LOG_DIR="$TE_EP_MOE_MP_LOG_DIR"
40+
mkdir -p "$LOG_DIR"
41+
else
42+
LOG_DIR=$(mktemp -d -t te_ep_moe_mp_XXXXXX)
43+
fi
44+
echo "Per-process logs: $LOG_DIR"
45+
46+
PIDS=()
47+
48+
cleanup() {
49+
for pid in "${PIDS[@]:-}"; do
50+
if kill -0 "$pid" 2>/dev/null; then
51+
kill -TERM "$pid" 2>/dev/null || true
52+
fi
53+
done
54+
sleep 1
55+
for pid in "${PIDS[@]:-}"; do
56+
if kill -0 "$pid" 2>/dev/null; then
57+
kill -KILL "$pid" 2>/dev/null || true
58+
fi
59+
done
60+
}
61+
trap cleanup EXIT INT TERM
62+
63+
for i in $(seq 0 $((NUM_GPUS - 1))); do
64+
LOG_FILE="$LOG_DIR/proc_${i}.log"
65+
PYTEST_CMD=(
66+
python3 -m pytest -c "$PYTEST_INI"
67+
"$TEST_FILE"
68+
-p no:typeguard
69+
-v -s
70+
--num-process="$NUM_GPUS"
71+
--process-id="$i"
72+
)
73+
if [ "$i" -eq 0 ]; then
74+
echo "=== Live output from process 0 ==="
75+
"${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" &
76+
else
77+
"${PYTEST_CMD[@]}" > "$LOG_FILE" 2>&1 &
78+
fi
79+
PIDS+=("$!")
80+
done
81+
82+
EXITS=()
83+
for pid in "${PIDS[@]}"; do
84+
if wait "$pid"; then
85+
EXITS+=("0")
86+
else
87+
EXITS+=("$?")
88+
fi
89+
done
90+
91+
echo
92+
echo "============================================================"
93+
echo "Per-process exit codes:"
94+
for i in "${!EXITS[@]}"; do
95+
echo " proc $i -> ${EXITS[$i]}"
96+
done
97+
98+
# Treat exit 0 (pass) and exit 5 (pytest "no tests collected", which the
99+
# file emits via pytest.skip(allow_module_level=True) on pre-Blackwell
100+
# GPUs) as success.
101+
FAILED=0
102+
for e in "${EXITS[@]}"; do
103+
if [ "$e" != "0" ] && [ "$e" != "5" ]; then
104+
FAILED=1
105+
break
106+
fi
107+
done
108+
109+
echo
110+
if [ "$FAILED" -eq 0 ]; then
111+
echo "[run_te_ep_moe.sh] all processes PASSED"
112+
if [ -z "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then
113+
rm -rf "$LOG_DIR"
114+
fi
115+
exit 0
116+
fi
117+
118+
echo "[run_te_ep_moe.sh] at least one process FAILED"
119+
echo " retaining logs at $LOG_DIR for diagnosis"
120+
echo " process 0 tail:"
121+
tail -20 "$LOG_DIR/proc_0.log" 2>/dev/null || true
122+
exit 1

0 commit comments

Comments
 (0)