diff --git a/.gitignore b/.gitignore index 8a627a7e76..7b03d79cc7 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,6 @@ tensor_dumps/ artifacts/ .DS_Store .claude/ + +# NCCL EP shared library staged by setup.py for wheel packaging. +/transformer_engine/libnccl_ep.so* diff --git a/.gitmodules b/.gitmodules index 4b188d6bb1..495d8e3fe7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "3rdparty/nccl"] + path = 3rdparty/nccl + url = https://github.com/NVIDIA/nccl.git diff --git a/3rdparty/nccl b/3rdparty/nccl new file mode 160000 index 0000000000..808d2433dd --- /dev/null +++ b/3rdparty/nccl @@ -0,0 +1 @@ +Subproject commit 808d2433dda3cccc80f8172a94a6b117359e7102 diff --git a/build_tools/jax.py b/build_tools/jax.py index 35ed62f832..4ad0442055 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -120,6 +120,9 @@ def setup_jax_extension( if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))): cxx_flags.append("-DNVTE_WITH_CUBLASMP") + if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))): + cxx_flags.append("-DNVTE_WITH_NCCL_EP") + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension diff --git a/examples/jax/ep/bench/ep_bench.py b/examples/jax/ep/bench/ep_bench.py new file mode 100644 index 0000000000..27ad8ca146 --- /dev/null +++ b/examples/jax/ep/bench/ep_bench.py @@ -0,0 +1,308 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX EP perf bench — dispatch/combine (raw fwd + custom_vjp wrapper) on a 1DP x EP mesh. + +One process per GPU; launch via run_ep_bench.sh. Each stage is jitted and +timed separately with NVTX ranges (prepare runs once outside the loop). +Rank-0 prints mean wall in us; nsys / --xplane attribute kernels per stage. +""" + +import argparse +import os +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.cpp_extensions import ep as tex_ep +from transformer_engine.jax.ep import EpLayerConfig, ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-JAX EP perf bench (dispatch_fwd + combine_fwd)") + p.add_argument("--coordinator-address", required=True) + p.add_argument("--process-id", type=int, required=True) + p.add_argument("--num-processes", type=int, required=True) + p.add_argument("--tokens-per-rank", type=int, default=8192) + p.add_argument("--hidden", type=int, default=7168) + p.add_argument("--top-k", type=int, default=8) + p.add_argument("--num-experts", type=int, default=256) + p.add_argument("--dp-size", type=int, default=1) + p.add_argument("--warmup", type=int, default=2) + p.add_argument("--iters", type=int, default=10) + p.add_argument( + "--max-num-sms", + type=int, + default=0, + help="Max SMs for dispatch / combine / preprocess kernels (0 = auto).", + ) + p.add_argument( + "--mode-label", + default=None, + help="Optional label suffix for NVTX range names so nsys can partition kernels.", + ) + p.add_argument( + "--second-step", + action="store_true", + help=( + "Time only the 2nd step (1 warmup iter, 1 timed iter). Use to isolate " + "JIT-cache-warm-but-no-steady-state-batching overhead from steady-state perf." + ), + ) + p.add_argument( + "--xplane", + default=None, + help="If set, jax.profiler dumps an XPlane trace into this dir (rank 0 only).", + ) + return p.parse_args() + + +def _distributed_init(args): + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=[args.process_id], + ) + + +def _build_mesh(args): + n = args.num_processes + assert n % args.dp_size == 0 + ep = n // args.dp_size + devs = np.asarray(jax.devices()).reshape(args.dp_size, ep) + return Mesh(devs, ("dp", "ep")), ep + + +def _make_inputs(args, ep_size): + """Round-robin routing, uniform top-k weights; each rank sees ``args.tokens_per_rank`` tokens.""" + n = args.num_processes + T = args.tokens_per_rank + H = args.hidden + K = args.top_k + E = args.num_experts + del ep_size + + topk_idx = np.empty((n * T, K), dtype=np.int32) + for t in range(n * T): + for k in range(K): + topk_idx[t, k] = (t * K + k) % E + topk_idx = jnp.asarray(topk_idx) + topk_w = jnp.full((n * T, K), 1.0 / K, dtype=jnp.float32) + tokens = jnp.asarray( + np.random.default_rng(0).standard_normal((n * T, H), dtype=np.float32) * 0.5, + dtype=jnp.bfloat16, + ) + return tokens, topk_idx, topk_w + + +def main(): + args = _parse_args() + _distributed_init(args) + mesh, ep_size = _build_mesh(args) + mr = MeshResource(dp_resource="dp", ep_resource="ep") + rank = args.process_id + + local_experts = args.num_experts // ep_size + recv_capacity_per_rank = args.num_processes * args.tokens_per_rank * args.top_k // 2 + + if rank == 0: + print( + f"[ep_bench] world={args.num_processes} dp={args.dp_size} ep={ep_size}" + f" T={args.tokens_per_rank} H={args.hidden} K={args.top_k}" + f" E={args.num_experts} (local={local_experts}) recv_pr={recv_capacity_per_rank}" + + (f" mode={args.mode_label}" if args.mode_label else ""), + flush=True, + ) + + nvtx_suffix = f"[{args.mode_label}]" if args.mode_label else "" + + in_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + out_spec = (("dp", "ep"), None) + T_global = args.num_processes * args.tokens_per_rank + + with mesh, global_shard_guard(mr): + ep_bootstrap( + world_size=args.num_processes, + rank=rank, + num_experts=args.num_experts, + max_tokens_per_rank=args.tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=args.hidden, + max_num_sms=args.max_num_sms, + ) + + tokens, topk_idx, topk_w = _make_inputs(args, ep_size) + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + + cfg = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) + + @jax.jit + def run_prepare(idx): + tc, hm = tex_ep.ep_prepare(cfg, idx) + return tc, hm + + @jax.jit + def run_dispatch(hm, idx, toks, w): + recv_t, recv_w = tex_ep.ep_dispatch_fwd(cfg, hm, idx, toks, w, recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(mesh, ep_spec_3d)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(mesh, ep_spec_2d)) + return recv_t, recv_w + + @jax.jit + def run_dispatch_vjp(idx, toks, w): + recv_t, recv_w, _hm, _tc = ep_dispatch(cfg, idx, toks, w, recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(mesh, ep_spec_3d)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(mesh, ep_spec_2d)) + return recv_t, recv_w + + @jax.jit + def run_combine(hm, recv_t): + out = tex_ep.ep_combine_fwd( + cfg, + hm, + recv_t, + T_global, + out_partition_spec=out_spec, + ) + return out + + @jax.jit + def run_combine_vjp(hm, tc, recv_t): + # ep_combine is unweighted; bench feeds expert_out directly (caller + # would otherwise pre-multiply by recv_topk_weights + mask). + out = ep_combine(cfg, hm, tc, recv_t, T_global, out_sharding=out_spec) + return out + + tc, handle_mem = run_prepare(idx_s) + tc.block_until_ready() + handle_mem.block_until_ready() + + recv_t0, recv_w0 = run_dispatch(handle_mem, idx_s, tok_s, w_s) + recv_t0.block_until_ready() + recv_w0.block_until_ready() + + warmup_n = 1 if args.second_step else args.warmup + iters_n = 1 if args.second_step else args.iters + + for _ in range(warmup_n): + r, _rw = run_dispatch(handle_mem, idx_s, tok_s, w_s) + r.block_until_ready() + o = run_combine(handle_mem, r) + o.block_until_ready() + run_dispatch_vjp(idx_s, tok_s, w_s)[0].block_until_ready() + run_combine_vjp(handle_mem, tc, recv_t0).block_until_ready() + + if args.xplane and rank == 0: + os.makedirs(args.xplane, exist_ok=True) + jax.profiler.start_trace(args.xplane) + + try: + import nvtx as _nvtx + + def _push(name): + _nvtx.push_range(message=name) + + def _pop(): + _nvtx.pop_range() + + except ImportError: + + def _push(name): + pass + + def _pop(): + pass + + def _time_stage_wall_us(name, fn): + # First timed iter still carries an autotune outlier even after JIT + # warmup; run iters_n + 1, drop iter 0 from the average, and push + # the NVTX range AFTER iter 0 so nsys' nvtx_kern_sum excludes the + # outlier too. + total_ns = 0 + counted = 0 + for i in range(iters_n + 1): + if i == 1: + _push(f"{name}{nvtx_suffix}") + t0 = time.perf_counter_ns() + fn() + dt = time.perf_counter_ns() - t0 + if i == 0: + continue + total_ns += dt + counted += 1 + _pop() + return total_ns / 1e3 / counted + + def _do_dispatch(): + r, _ = run_dispatch(handle_mem, idx_s, tok_s, w_s) + r.block_until_ready() + + def _do_dispatch_vjp(): + r, _ = run_dispatch_vjp(idx_s, tok_s, w_s) + r.block_until_ready() + + def _do_combine(): + o = run_combine(handle_mem, recv_t0) + o.block_until_ready() + + def _do_combine_vjp(): + o = run_combine_vjp(handle_mem, tc, recv_t0) + o.block_until_ready() + + d_wall_us = _time_stage_wall_us("dispatch_fwd", _do_dispatch) + dv_wall_us = _time_stage_wall_us("ep_dispatch_vjp", _do_dispatch_vjp) + c_wall_us = _time_stage_wall_us("combine_fwd", _do_combine) + cv_wall_us = _time_stage_wall_us("ep_combine_vjp", _do_combine_vjp) + + if args.xplane and rank == 0: + jax.profiler.stop_trace() + + if rank == 0: + label = f" [{args.mode_label}]" if args.mode_label else "" + print("", flush=True) + print(f"| stage | mean wall (us){label} |", flush=True) + print("|-------------------|---------------:|", flush=True) + print(f"| dispatch_fwd | {d_wall_us:14.1f} |", flush=True) + print(f"| ep_dispatch_vjp | {dv_wall_us:14.1f} |", flush=True) + print(f"| combine_fwd | {c_wall_us:14.1f} |", flush=True) + print(f"| ep_combine_vjp | {cv_wall_us:14.1f} |", flush=True) + print(f"| (dispatch vjp-fwd)| {dv_wall_us - d_wall_us:14.1f} |", flush=True) + print(f"| (combine vjp-fwd)| {cv_wall_us - c_wall_us:14.1f} |", flush=True) + print("", flush=True) + print( + "[ep_bench] kernel breakout: see nsys nvtx_kern_sum output below " + "(produced by run_ep_bench.sh --nsys).", + flush=True, + ) + + # Under nsys: force cudaDeviceReset() to drain CUPTI's in-process kernel + # records into the .nsys-rep, then os._exit to skip JAX's coord-service + # watchdog. The reset crashes during NCCL EP context teardown, so we only + # take this path when the launcher opts in via EP_BENCH_FLUSH_CUPTI=1. + if os.environ.get("EP_BENCH_FLUSH_CUPTI", "0") == "1": + try: + import ctypes + + cudart = ctypes.CDLL("libcudart.so") + cudart.cudaDeviceSynchronize() + cudart.cudaDeviceReset() + except Exception: + pass + time.sleep(0.5) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +if __name__ == "__main__": + main() diff --git a/examples/jax/ep/bench/run_ep_bench.sh b/examples/jax/ep/bench/run_ep_bench.sh new file mode 100755 index 0000000000..1531dfd5cf --- /dev/null +++ b/examples/jax/ep/bench/run_ep_bench.sh @@ -0,0 +1,352 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# 4-rank launcher for ep_bench.py. +# Examples: +# bash run_ep_bench.sh # plain run, stdout only +# bash run_ep_bench.sh --cuda-graph # enable XLA command-buffer (cudaGraph), min_size=1 +# bash run_ep_bench.sh --nsys # nsys on rank 0 -> results/jax_nsys.nsys-rep +# bash run_ep_bench.sh --xplane # jax.profiler on rank 0 -> results/xplane/ +# +# Notes: +# * nsys + xplane cannot be combined (both attach CUPTI -> MULTIPLE_SUBSCRIBERS). +# * nsys + --cuda-graph is rejected: cudaGraph fires kernels via cuGraphLaunch +# and detaches the host NVTX context, breaking per-stage attribution. +# * stdout per rank lands in results/stdout__rank_.txt. + +set -uo pipefail + +NSYS=0; XPLANE=0; CGRAPH=0; SECOND_STEP=0 +for a in "$@"; do + case "$a" in + --nsys) NSYS=1 ;; + --xplane) XPLANE=1 ;; + --cuda-graph) CGRAPH=1 ;; + --second-step) SECOND_STEP=1 ;; + *) echo "unknown arg: $a" >&2; exit 2 ;; + esac +done +if [ "${NSYS}" -eq 1 ] && [ "${XPLANE}" -eq 1 ]; then + echo "--nsys and --xplane both attach CUPTI; pick one." >&2; exit 2 +fi +if [ "${NSYS}" -eq 1 ] && [ "${CGRAPH}" -eq 1 ]; then + echo "--nsys and --cuda-graph cannot be combined: cudaGraph launches detach the" \ + "host NVTX context, so nvtx_kern_sum cannot attribute kernels to our ranges." >&2 + exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../../../.." && pwd)" +RESULTS="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS}" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +NUM_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "EP bench requires >=4 GPUs (found ${NUM_GPUS}); SKIPPING."; exit 0 +fi +NUM=4 +COORD="${COORD:-127.0.0.1:23457}" +TIMEOUT_S="${TIMEOUT_S:-1800}" + +XLA_BASE="${XLA_BASE:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_graph_min_graph_size=1}" + +if [ "${CGRAPH}" -eq 1 ]; then + TAG="cudagraph" + export XLA_FLAGS="${XLA_BASE} --xla_gpu_enable_command_buffer=FUSION,CUSTOM_CALL --xla_gpu_graph_min_graph_size=1" +else + TAG="vanilla" + export XLA_FLAGS="${XLA_BASE} --xla_gpu_enable_command_buffer=" +fi +[ "${SECOND_STEP}" -eq 1 ] && TAG="${TAG}_step2" + +: "${NCCL_EP_JIT_CACHE_DIR:=${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)}" +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "${NCCL_EP_JIT_CACHE_DIR}" + +# JAX/XLA persistent compilation cache: first run pays full compile cost +# (cudaGraph capture + EP custom_calls is minutes); subsequent runs reuse it. +: "${JAX_COMPILATION_CACHE_DIR:=${TMPDIR:-/tmp}/jax_cache_$(id -u)}" +export JAX_COMPILATION_CACHE_DIR +mkdir -p "${JAX_COMPILATION_CACHE_DIR}" + +export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.2}" +export NVTE_EP_SILENCE_NONSYMM_WARN="${NVTE_EP_SILENCE_NONSYMM_WARN:-1}" + +ALL_RANKS_ARGS=() +R0_ONLY_ARGS=() +NSYS_PREFIX=() +SUFFIX="" +if [ "${SECOND_STEP}" -eq 1 ]; then + ALL_RANKS_ARGS+=(--second-step) +fi +if [ "${XPLANE}" -eq 1 ]; then + R0_ONLY_ARGS+=(--xplane "${RESULTS}/xplane_${TAG}") + SUFFIX="_xplane" +fi +if [ "${NSYS}" -eq 1 ]; then + SUFFIX="_nsys" + export EP_BENCH_FLUSH_CUPTI=1 + NSYS_PREFIX=(nsys profile + --output "${RESULTS}/jax_${TAG}_nsys" + --force-overwrite=true + --trace=cuda,nvtx + --gpu-metrics-devices=none + --cuda-um-cpu-page-faults=false + --cuda-um-gpu-page-faults=false) +fi + +OUT_PREFIX="stdout_${TAG}${SUFFIX}_rank" + +for f in "${RESULTS}/${OUT_PREFIX}_"*.txt \ + "${RESULTS}/jax_${TAG}_nsys.nsys-rep" \ + "${RESULTS}/jax_${TAG}_nsys.sqlite" \ + "${RESULTS}/jax_${TAG}_nsys_nvtx_kern_sum.csv" \ + "${RESULTS}/jax_${TAG}_nsys_kern_sum.csv" \ + "${RESULTS}/summary_${TAG}${SUFFIX}.md"; do + [ -f "$f" ] && mv -f "$f" "$f.prev" +done + +PIDS=() +cleanup() { for pid in "${PIDS[@]}"; do kill -KILL "$pid" 2>/dev/null || true; done; } +trap cleanup EXIT INT TERM + +for ((i=1; i "${RESULTS}/${OUT_PREFIX}_${i}.txt" 2>&1 & + PIDS+=($!) +done + +R0_CMD=(python -u "${SCRIPT_DIR}/ep_bench.py" + --coordinator-address "${COORD}" --process-id 0 --num-processes "${NUM}" + "${ALL_RANKS_ARGS[@]}" "${R0_ONLY_ARGS[@]}") +if [ "${NSYS}" -eq 1 ]; then + R0_CMD=("${NSYS_PREFIX[@]}" "${R0_CMD[@]}") +fi + +WATCHDOG_PID="" +if [ "${NSYS}" -eq 1 ]; then + ( while ! grep -q "kernel breakout" "${RESULTS}/${OUT_PREFIX}_0.txt" 2>/dev/null; do + sleep 2 + done + sleep 20 + pkill -INT -f "nsys profile --output ${RESULTS}/jax_${TAG}_nsys" 2>/dev/null || true + ) & + WATCHDOG_PID=$! +fi + +timeout --foreground --signal=KILL "${TIMEOUT_S}" "${R0_CMD[@]}" 2>&1 | tee "${RESULTS}/${OUT_PREFIX}_0.txt" +if [ -n "${WATCHDOG_PID}" ]; then + kill "${WATCHDOG_PID}" 2>/dev/null || true +fi +wait + +SUMMARY="${RESULTS}/summary_${TAG}${SUFFIX}.md" +RANK0_LOG="${RESULTS}/${OUT_PREFIX}_0.txt" + +{ + echo "# JAX EP bench summary — tag=${TAG}${SUFFIX}" + echo "" + echo "Generated: $(date -Iseconds)" + echo "Rank-0 log: \`${RANK0_LOG}\`" + echo "" + echo "## Per-stage runtime (rank 0)" + echo "" + echo '```' + awk '/^\| stage / {flag=1} flag {print; if (/combine[ ]+vjp-fwd/) {flag=0}}' "${RANK0_LOG}" || true + echo '```' +} > "${SUMMARY}" + +if [ "${NSYS}" -eq 1 ]; then + NSYS_REP="${RESULTS}/jax_${TAG}_nsys.nsys-rep" + NVTX_CSV="${RESULTS}/jax_${TAG}_nsys_nvtx_kern_sum.csv" + KERN_CSV="${RESULTS}/jax_${TAG}_nsys_kern_sum.csv" + if [ -f "${NSYS_REP}" ] && command -v nsys >/dev/null 2>&1; then + PROJ_CSV="${RESULTS}/jax_${TAG}_nsys_nvtx_gpu_proj_sum.csv" + echo "Extracting NVTX-range + kernel summaries from ${NSYS_REP} ..." + nsys stats --report nvtx_kern_sum --format csv \ + --output - "${NSYS_REP}" > "${NVTX_CSV}" 2>&1 || true + nsys stats --report cuda_gpu_kern_sum --format csv \ + --output - "${NSYS_REP}" > "${KERN_CSV}" 2>&1 || true + nsys stats --report nvtx_gpu_proj_sum --format csv \ + --output - "${NSYS_REP}" > "${PROJ_CSV}" 2>&1 || true + + BREAKOUT=$(python3 - "${NVTX_CSV}" "${PROJ_CSV}" <<'PYEOF' +import csv, sys, collections, re +path = sys.argv[1] + +STAGE_PATTERNS = { + "dispatch_fwd": re.compile(r"(^|:)dispatch_fwd(\[[^\]]*\])?$"), + "ep_dispatch_vjp": re.compile(r"(^|:)ep_dispatch_vjp(\[[^\]]*\])?$"), + "combine_fwd": re.compile(r"(^|:)combine_fwd(\[[^\]]*\])?$"), + "ep_combine_vjp": re.compile(r"(^|:)ep_combine_vjp(\[[^\]]*\])?$"), +} +STAGE_ORDER = ("dispatch_fwd", "ep_dispatch_vjp", "combine_fwd", "ep_combine_vjp") + +stages = collections.defaultdict(list) +try: + with open(path) as f: + lines = [ln for ln in f] + header_idx = next((i for i, ln in enumerate(lines) + if ln.lstrip().startswith("NVTX Range,")), -1) + if header_idx < 0: + print("(NVTX header not found)"); sys.exit(0) + reader = csv.reader(lines[header_idx:]) + header = next(reader, None) + def col(name): + for i, h in enumerate(header): + if h.strip().lower() == name.lower(): + return i + return -1 + i_range = col("NVTX Range") + i_total = col("Total Time (ns)") + i_inst = col("Kern Inst") + i_name = col("Kernel Name") + if min(i_range, i_total, i_inst, i_name) < 0: + print(f"(missing expected columns; got {header})"); sys.exit(0) + for row in reader: + if len(row) <= i_name: continue + rname = row[i_range].strip() + try: + total_ns = int(row[i_total].replace(',', '')) + inst = int(row[i_inst].replace(',', '')) + except ValueError: + continue + kname = row[i_name].strip() + for stage, pat in STAGE_PATTERNS.items(): + if pat.search(rname): + stages[stage].append((total_ns, inst, kname)) + break +except FileNotFoundError: + print("(nvtx_kern_sum CSV not found)"); sys.exit(0) + +if not stages: + print("(no kernels matched expected NVTX ranges)") + sys.exit(0) + +proj_csv = sys.argv[2] if len(sys.argv) > 2 else None +proj = {} +if proj_csv: + try: + with open(proj_csv) as f: + plines = list(f) + hidx = next((i for i, ln in enumerate(plines) + if ln.lstrip().startswith("Range,")), -1) + if hidx >= 0: + pr = csv.reader(plines[hidx:]) + ph = next(pr, None) + def pcol(n): + for i, h in enumerate(ph): + if h.strip().lower() == n.lower(): return i + return -1 + pi_range = pcol("Range") + pi_total = pcol("Total Proj Time (ns)") + pi_inst = pcol("Range Instances") + pi_gpuops = pcol("Total GPU Ops") + for row in pr: + if len(row) <= max(pi_range, pi_total, pi_inst): continue + rname = row[pi_range].strip() + for stage, pat in STAGE_PATTERNS.items(): + if pat.search(rname): + try: + t = int(row[pi_total].replace(',', '')) + n = int(row[pi_inst].replace(',', '')) + ops = int(row[pi_gpuops].replace(',', '')) if pi_gpuops >= 0 else 0 + except ValueError: + continue + proj[stage] = (t / 1e3, n) + break + except FileNotFoundError: + pass + +print("### Per-stage GPU activity (kernels + memops, from nvtx_gpu_proj_sum)") +print() +print("| stage | iters | GPU activity total (us) | per-iter (us) | kernel sum (us) | per-iter (us) | gap = memops+idle (us) |") +print("|------|-----:|----------------------:|------------:|--------------:|------------:|---------------------:|") +for stage in STAGE_ORDER: + rows = stages.get(stage, []) + kern_total_us = sum(r[0] for r in rows) / 1e3 + iters = max(rows, key=lambda r: r[0])[1] if rows else 0 + gpu_total_us, _ = proj.get(stage, (0.0, 0)) + per_iter_gpu = gpu_total_us / iters if iters else 0 + per_iter_kern = kern_total_us / iters if iters else 0 + gap = per_iter_gpu - per_iter_kern + print(f"| `{stage}` | {iters} | {gpu_total_us:18.1f} | {per_iter_gpu:11.1f} | {kern_total_us:13.1f} | {per_iter_kern:11.1f} | {gap:20.1f} |") +print() + +def _kern_per_iter(rows, needle): + tot_ns = 0; inst = 0 + for tns, n, kname in rows: + if needle in kname: + tot_ns += tns; inst += n + return (tot_ns / inst / 1e3) if inst else None + +KEY_KERNELS = { + "dispatch_fwd": [("dispatch", "nccl_ep_jit_ht_dispatch_kernel"), + ("permute", "nccl_ep_jit_ht_permute_kernel")], + "ep_dispatch_vjp": [("dispatch", "nccl_ep_jit_ht_dispatch_kernel"), + ("permute", "nccl_ep_jit_ht_permute_kernel")], + "combine_fwd": [("combine", "nccl_ep_jit_ht_combine_kernel"), + ("local_reduce", "nccl_ep_jit_ht_local_reduce_kernel")], + "ep_combine_vjp": [("combine", "nccl_ep_jit_ht_combine_kernel"), + ("local_reduce", "nccl_ep_jit_ht_local_reduce_kernel")], +} + +print("### Key NCCL EP kernel time per iter (us)") +print() +print("| stage | primary kernel (us/iter) | secondary kernel (us/iter) | kernel sum/iter (us) |") +print("|------|--------------------:|-----------------------:|------------------:|") +for stage in STAGE_ORDER: + rows = stages.get(stage, []) + iters = max(rows, key=lambda r: r[0])[1] if rows else 0 + per_iter_kern = (sum(r[0] for r in rows) / 1e3 / iters) if iters else 0.0 + keys = KEY_KERNELS.get(stage, []) + cells = [] + for label, needle in keys: + v = _kern_per_iter(rows, needle) + cells.append(f"{label}: {v:.1f}" if v is not None else f"{label}: -") + while len(cells) < 2: + cells.append("-") + print(f"| `{stage}` | {cells[0]:>20} | {cells[1]:>22} | {per_iter_kern:17.1f} |") +print() + +for stage in STAGE_ORDER: + rows = stages.get(stage, []) + if not rows: + print(f"### Stage `{stage}` top kernels — none"); print(); continue + agg = collections.defaultdict(lambda: [0, 0]) + for tns, inst, kname in rows: + agg[kname][0] += tns + agg[kname][1] += inst + items = sorted(([k, v[0], v[1]] for k, v in agg.items()), key=lambda x: -x[1]) + total_us = sum(v[1] for v in items) / 1e3 + print(f"### Stage `{stage}` — top 20 kernels ({len(items)} distinct; kernel-sum {total_us:.1f} us)") + print() + print("| # | total (us) | inst | avg (us) | kernel |") + print("|--:|-----------:|-----:|---------:|--------|") + for i, (kname, tns, inst) in enumerate(items[:20], 1): + avg_us = (tns / inst) / 1e3 if inst else 0 + short = kname if len(kname) <= 80 else kname[:77] + "..." + print(f"| {i} | {tns/1e3:10.1f} | {inst:4d} | {avg_us:8.2f} | `{short}` |") + print() +PYEOF +) + { + echo "" + echo "## Kernel breakout per NVTX range (rank 0)" + echo "" + echo "${BREAKOUT}" + echo "Full CSVs:" + echo "- per-range: \`${NVTX_CSV}\`" + echo "- overall: \`${KERN_CSV}\`" + } | tee -a "${RANK0_LOG}" >> "${SUMMARY}" + fi +fi + +echo "Done. Logs in ${RESULTS}/${OUT_PREFIX}_*.txt" +echo "Summary: ${SUMMARY}" diff --git a/examples/jax/ep/ep_moe.py b/examples/jax/ep/ep_moe.py new file mode 100644 index 0000000000..a23a0b33c9 --- /dev/null +++ b/examples/jax/ep/ep_moe.py @@ -0,0 +1,396 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""End-to-end MoE example: dispatch -> batched expert linear -> combine, fwd + bwd. + +One process per GPU. Run via run_test_ep.sh. +""" + +import argparse +import sys + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.ep import EpLayerConfig, ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ── Setup ─────────────────────────────────────────────────────────────────── + + +def _parse_args(): + p = argparse.ArgumentParser(description="TE-JAX EP MoE example (fwd + bwd)") + p.add_argument("--coordinator-address", required=True) + p.add_argument("--process-id", type=int, required=True) + p.add_argument("--num-processes", type=int, required=True) + p.add_argument("--num-tokens", type=int, default=8, help="Per-rank token count.") + p.add_argument("--top-k", type=int, default=2) + p.add_argument("--hidden", type=int, default=32) + p.add_argument("--hidden-out", type=int, default=32) + p.add_argument( + "--num-experts", + type=int, + default=None, + help="Total experts across the EP group. Default: num_processes.", + ) + p.add_argument("--dp-size", type=int, default=None, help="Default: num_procs // ep_size.") + p.add_argument( + "--check", + action="store_true", + default=True, + help="Verify fwd+bwd against a single-rank numpy reference.", + ) + p.add_argument( + "--iters", + type=int, + default=3, + help="Number of fwd+bwd iterations to run (same compiled jit, same handle_mem).", + ) + return p.parse_args() + + +def _distributed_init(args): + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=[args.process_id], + ) + assert ( + jax.local_device_count() == 1 + ), f"EP example requires 1 GPU per process; got {jax.local_device_count()}" + + +def _build_mesh_and_resource(args): + """Pick a (2, 2) mesh by default. Override via --dp-size.""" + n = args.num_processes + if n < 4: + raise ValueError(f"num_processes ({n}) must be >= 4 for NCCL EP") + if args.dp_size is None: + if n != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {n}); pass --dp-size to override" + ) + args.dp_size = 2 + assert n % args.dp_size == 0, f"num_processes={n} not divisible by dp_size={args.dp_size}" + args.ep_size = n // args.dp_size + if args.num_experts is None: + args.num_experts = args.num_processes + assert args.num_experts % args.ep_size == 0 + args.num_local_experts = args.num_experts // args.ep_size + args.recv_capacity_per_rank = args.ep_size * args.num_tokens * args.top_k + + devs = np.asarray(jax.devices()).reshape(args.dp_size, args.ep_size) + mesh = Mesh(devs, ("dp", "ep")) + mr = MeshResource(dp_resource="dp", ep_resource="ep") + return mesh, mr + + +def _make_routing(dp_color, num_tokens, top_k, num_experts, num_local_experts): + """Deterministic routing: topk_idx[t, k] = (dp_color*NLE + t*K + k) % E.""" + topk_idx = np.empty((num_tokens, top_k), dtype=np.int32) + for t in range(num_tokens): + for k in range(top_k): + topk_idx[t, k] = (dp_color * num_local_experts + t * top_k + k) % num_experts + return topk_idx + + +def _make_inputs(args): + """Build 3D ``[B, S, H]`` arrays sharded ``(("dp","ep"), None, None)``. + + B = num_processes (sharded across the compound (dp,ep) axis so each rank + holds one slot); S = args.num_tokens. Global numpy views (rank-0 + reference) are kept 2D for the legacy reference implementation. + """ + T, K, H, H_out = args.num_tokens, args.top_k, args.hidden, args.hidden_out + E = args.num_experts + dp_size = args.dp_size + ep_size = args.ep_size + num_procs = args.num_processes + dp_color = args.process_id // ep_size + + rng_dp = np.random.default_rng(seed=42 + dp_color) + tokens_np = (rng_dp.standard_normal((T, H), dtype=np.float32) * 0.5).astype(np.float32) + topk_idx_np = _make_routing(dp_color, T, K, E, args.num_local_experts) + w_np = np.full((T, K), 1.0 / K, dtype=np.float32) + + tokens_global_np = np.concatenate( + [ + ( + np.random.default_rng(seed=42 + c).standard_normal((T, H), dtype=np.float32) * 0.5 + ).astype(np.float32) + for c in range(dp_size) + ], + axis=0, + ) + topk_idx_global_np = np.concatenate( + [_make_routing(c, T, K, E, args.num_local_experts) for c in range(dp_size)], axis=0 + ) + w_global_np = np.full((dp_size * T, K), 1.0 / K, dtype=np.float32) + + # Same seed on every rank → identical kernel array everywhere. + rng = np.random.default_rng(seed=42) + kernels_np = (rng.standard_normal((E, H, H_out), dtype=np.float32) * (1.0 / np.sqrt(H))).astype( + np.float32 + ) + + # Each rank contributes one [1, T, ...] slab; the global shape is + # [num_procs, T, ...] sharded on the first dim across (dp, ep). + mesh = args.mesh + dpep_spec = NamedSharding(mesh, PartitionSpec(("dp", "ep"), None, None)) + tokens = jax.make_array_from_process_local_data( + dpep_spec, tokens_np[None, :, :].astype(np.float32), (num_procs, T, H) + ).astype(jnp.bfloat16) + topk_idx = jax.make_array_from_process_local_data( + dpep_spec, topk_idx_np[None, :, :], (num_procs, T, K) + ) + topk_w = jax.make_array_from_process_local_data(dpep_spec, w_np[None, :, :], (num_procs, T, K)) + kernels = jnp.asarray(kernels_np, dtype=jnp.bfloat16) + return ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) + + +# ── MoE step ──────────────────────────────────────────────────────────────── + + +def _batched_expert_linear(recv_tokens, kernels, num_local_experts, dp_size, ep_size): + """Per-expert linear. ``recv_tokens`` is 3D ``[num_procs, recv_pr, H]`` + (compound (dp,ep) leading); ``kernels`` is 4D ``[ep_size, NLE, H, H_out]``, + broadcast over the dp axis. Output matches ``recv_tokens``' 3D layout + with ``H_out`` in place of ``H``.""" + num_procs, recv_pr, H = recv_tokens.shape + H_out = kernels.shape[-1] + slots_per_expert = recv_pr // num_local_experts + # [num_procs, recv_pr, H] -> [dp, ep, NLE, slots, H] + grouped = recv_tokens.reshape(dp_size, ep_size, num_local_experts, slots_per_expert, H) + # Contract H; batch over (ep, NLE) which are present on both sides. + out = jax.lax.dot_general( + grouped, + kernels.astype(grouped.dtype), + dimension_numbers=(((4,), (2,)), ((1, 2), (0, 1))), + ) + # Output dim order from dot_general: batch dims first, then remaining lhs, rhs. + # batch=(ep,NLE), lhs_remaining=(dp,slots), rhs_remaining=(H_out,) + # → shape [ep, NLE, dp, slots, H_out]. Permute to [dp, ep, NLE, slots, H_out]. + out = jnp.transpose(out, (2, 0, 1, 3, 4)) + return out.reshape(num_procs, recv_pr, H_out) + + +def _moe_step(args, topk_idx, tokens, topk_w, kernels): + """Jit'd MoE step: dispatch -> batched per-expert linear -> combine. + + Inputs are 3D ``[B, S, H]`` with the first dim compound-sharded across + ``("dp","ep")``. Combine returns the same 3D shape. + """ + B = args.num_processes + S = args.num_tokens + NLE = args.num_local_experts + dp_size, ep_size = args.dp_size, args.ep_size + mesh = args.mesh + in_spec = PartitionSpec(("dp", "ep"), None, None) # [B, S, ...] + ep3 = PartitionSpec(("dp", "ep"), None, None) # [num_procs, recv_pr, H] + ep2 = PartitionSpec(("dp", "ep"), None) # [num_procs, recv_pr] + # Kernels are EP-replicated across dp colors; shard only the ep-rank axis. + kernel_spec = PartitionSpec("ep", None, None, None) + + kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:]) + layer_cfg = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16) + + @jax.jit + def step(topk_idx, tokens, topk_w, local_kernels): + topk_idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec)) + tokens = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec)) + topk_w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec)) + local_kernels = jax.lax.with_sharding_constraint( + local_kernels, NamedSharding(mesh, kernel_spec) + ) + recv_tokens, recv_topk_w, handle_mem, _tc = ep_dispatch( + layer_cfg, topk_idx, tokens, topk_w, args.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3)) + recv_topk_w = jax.lax.with_sharding_constraint(recv_topk_w, NamedSharding(mesh, ep2)) + expert_out = _batched_expert_linear(recv_tokens, local_kernels, NLE, dp_size, ep_size) + expert_out = jax.lax.with_sharding_constraint(expert_out, NamedSharding(mesh, ep3)) + # ep_combine is unweighted: pre-multiply by recv_topk_w and zero + # padded slots (recv_topk_w == 0) before the scatter-sum. + mask = (recv_topk_w != 0).astype(jnp.float32)[..., None] + weighted = (expert_out.astype(jnp.float32) * recv_topk_w[..., None] * mask).astype( + expert_out.dtype + ) + weighted = jax.lax.with_sharding_constraint(weighted, NamedSharding(mesh, ep3)) + return ep_combine( + layer_cfg, + handle_mem, + _tc, + weighted, + num_local_tokens=(B, S), + out_sharding=(("dp", "ep"), None, None), + ) + + return step(topk_idx, tokens, topk_w, kernels) + + +# ── Reference (numerical check) ───────────────────────────────────────────── + + +def _reference_moe(tokens, topk_idx, topk_w, kernels): + """Single-rank dense MoE reference. tokens [T, H], output [T, H_out].""" + T, K = topk_idx.shape + H_out = kernels.shape[-1] + out = np.zeros((T, H_out), dtype=np.float32) + for t in range(T): + tok = tokens[t].astype(np.float32) + for k in range(K): + e = int(topk_idx[t, k]) + out[t] += float(topk_w[t, k]) * (tok @ kernels[e].astype(np.float32)) + return out + + +def _reference_grad(tokens, topk_idx, topk_w, kernels): + """d/dtokens of 0.5 * sum(ref_out**2) — used by --check to validate bwd.""" + T, K = topk_idx.shape + H = tokens.shape[-1] + ref_out = _reference_moe(tokens, topk_idx, topk_w, kernels) + grad = np.zeros((T, H), dtype=np.float32) + for t in range(T): + mixed = np.zeros_like(kernels[0]) + for k in range(K): + mixed = mixed + float(topk_w[t, k]) * kernels[int(topk_idx[t, k])] + grad[t] = ref_out[t] @ mixed.T + return ref_out, grad + + +# ── Main ──────────────────────────────────────────────────────────────────── + + +def main(): + args = _parse_args() + _distributed_init(args) + + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is not None: + major, minor = (int(x) for x in str(cap).split(".")) + if major * 10 + minor < 90: + print(f"[ep_moe] SKIPPED: NCCL EP requires SM>=90 (got SM{major}{minor})") + return + + args.mesh, args.mr = _build_mesh_and_resource(args) + + with args.mesh, global_shard_guard(args.mr): + ep_bootstrap( + world_size=args.num_processes, + rank=args.process_id, + num_experts=args.num_experts, + max_tokens_per_rank=args.num_tokens, + recv_capacity_per_rank=args.recv_capacity_per_rank, + hidden_dim=args.hidden, + ) + + ( + tokens_global_np, + topk_idx_global_np, + w_global_np, + kernels_np, + tokens, + topk_idx, + topk_w, + kernels, + ) = _make_inputs(args) + + def loss_fn(toks, idx, w, kern): + out = _moe_step(args, idx, toks, w, kern) + return 0.5 * (out.astype(jnp.float32) ** 2).sum(), out + + step_jit = jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) + + # Same jit + same inputs each iter: handle_mem cache must give identical loss/grad. + for it in range(args.iters): + (loss, out_fwd), grad_tokens = step_jit(tokens, topk_idx, topk_w, kernels) + grad_tokens.block_until_ready() + out_fwd.block_until_ready() + if args.process_id == 0: + print( + f"[ep_moe] iter={it} loss={float(loss):.4f}" + f" grad_tokens.shape={grad_tokens.shape}" + f" dp={args.dp_size} ep={args.ep_size}" + f" num_experts={args.num_experts} recv_pr={args.recv_capacity_per_rank}" + ) + + if args.check: + + def _norm(spec, ndim): + return tuple(spec) + (None,) * (ndim - len(spec)) + + # JAX may collapse a size-1 mesh axis: when dp_size==1 the spec can + # appear as ``(("dp","ep"),...)`` or ``("ep",...)``. Accept both. + if args.dp_size > 1: + acceptable_specs = ((("dp", "ep"), None, None),) + else: + acceptable_specs = ((("dp", "ep"), None, None), ("ep", None, None)) + assert ( + _norm(out_fwd.sharding.spec, out_fwd.ndim) in acceptable_specs + ), f"out_fwd.sharding.spec={out_fwd.sharding.spec} (expected one of {acceptable_specs})" + assert _norm(grad_tokens.sharding.spec, grad_tokens.ndim) in acceptable_specs, ( + f"grad_tokens.sharding.spec={grad_tokens.sharding.spec}" + f" (expected one of {acceptable_specs})" + ) + + replicated = NamedSharding(args.mesh, jax.sharding.PartitionSpec()) + out_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))(out_fwd) + grad_global = jax.jit(lambda x: jax.lax.with_sharding_constraint(x, replicated))( + grad_tokens + ) + out_global.block_until_ready() + grad_global.block_until_ready() + + ref_out, ref_grad = _reference_grad( + tokens_global_np, topk_idx_global_np, w_global_np, kernels_np + ) + # 3D global ``[num_procs, S, H]`` with num_procs = dp * ep. Each EP + # column in a DP color sees identical inputs (and produces identical + # outputs), so collapse the ep dim to one replica before flattening + # to 2D against the dp-only reference. + dp_size, ep_size = args.dp_size, args.ep_size + global_out = ( + np.asarray(out_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_out.shape[-1])[:, 0] + .reshape(-1, ref_out.shape[-1]) + ) + global_grad = ( + np.asarray(grad_global.addressable_shards[0].data.astype(jnp.float32)) + .reshape(dp_size, ep_size, -1, ref_grad.shape[-1])[:, 0] + .reshape(-1, ref_grad.shape[-1]) + ) + np.testing.assert_allclose( + global_out, + ref_out, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: fwd mismatch", + ) + np.testing.assert_allclose( + global_grad, + ref_grad, + rtol=5e-2, + atol=5e-2, + err_msg=f"rank {args.process_id}: bwd mismatch", + ) + if args.process_id == 0: + print(f"[ep_moe] --check PASSED (ref_out.sum()={float(ref_out.sum()):.4f})") + + +if __name__ == "__main__": + main() + sys.exit(0) diff --git a/examples/jax/ep/run_test_ep.sh b/examples/jax/ep/run_test_ep.sh new file mode 100755 index 0000000000..55b958f146 --- /dev/null +++ b/examples/jax/ep/run_test_ep.sh @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +#!/bin/bash + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +if [ "${NUM_GPUS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_GPUS}); SKIPPING." + exit 0 +fi +# Default mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_GPUS="${NVTE_EP_NUM_RANKS:-4}" + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +# NCCL EP requires NVLink P2P among ranks on the node. +echo "*** Checking NVLINK support ***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] \ + || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform — EP example requires NVLINK; SKIPPING" + exit 0 +fi +echo "NVLINK support detected" + +SCRIPT="$TE_PATH/examples/jax/ep/ep_moe.py" +export PYTHONPATH="${TE_PATH}${PYTHONPATH:+:${PYTHONPATH}}" +COORD="${COORD:-127.0.0.1:12345}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-300}" + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +# Stage NCCL EP JIT cubins on tmpfs to keep build/iteration fast. +: ${NCCL_EP_JIT_CACHE_DIR:="${TMPDIR:-/tmp}/nccl_ep_jit_cache_$(id -u)"} +export NCCL_EP_JIT_CACHE_DIR +mkdir -p "$NCCL_EP_JIT_CACHE_DIR" + +echo +echo "*** Executing ep_moe.py across $NUM_GPUS GPUs ***" + +PIDS=() +cleanup() { + for pid in "${PIDS[@]}"; do + kill -0 "$pid" 2>/dev/null && kill -KILL "$pid" 2>/dev/null || true + done +} +trap cleanup EXIT INT TERM + +EXTRA_ARGS=${EXTRA_ARGS:-"--check"} + +for ((i=1; i "stdout_rank_${i}.txt" 2>&1 & + PIDS+=($!) +done +timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python -u "$SCRIPT" \ + --coordinator-address "$COORD" --process-id "0" --num-processes "$NUM_GPUS" \ + $EXTRA_ARGS 2>&1 | tee stdout_rank_0.txt +wait + +HAS_FAILURE=0 +if grep -qE "FAILED|Traceback|ERROR" stdout_rank_0.txt; then + echo "... ep_moe FAILED" + HAS_FAILURE=1 +elif ! grep -qE "\[ep_moe\]" stdout_rank_0.txt; then + echo "... ep_moe INVALID (rank 0 produced no summary line)" + for ((i=1; i/dev/null + done + HAS_FAILURE=1 +else + echo "... ep_moe PASSED" +fi +rm -f stdout_rank_*.txt +exit $HAS_FAILURE diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index 8d767a4efb..7e5ce2cf0d 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -14,4 +14,7 @@ if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then cmake -GNinja -S. -Bbuild cmake --build build mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm + + # EP suites; runner self-skips on pre-Hopper GPUs. + bash ./run_test_ep.sh 4 ./build fi diff --git a/setup.py b/setup.py index ec277b6349..551faf8e83 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,24 @@ def setup_common_extension() -> CMakeExtension: cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # NCCL EP (Hopper+): on by default; auto-skipped when no arch >= 90 is + # targeted. Set NVTE_BUILD_WITH_NCCL_EP=0 to force off. + build_with_nccl_ep = bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))) + if build_with_nccl_ep: + arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()] + has_hopper_or_newer = any( + t.lower() == "native" or (t.rstrip("af").isdigit() and int(t.rstrip("af")) >= 90) + for t in arch_tokens + ) + if not has_hopper_or_newer: + print(f"[NCCL EP] No arch >= 90 in NVTE_CUDA_ARCHS ('{archs}'); skipping build.") + build_with_nccl_ep = False + if build_with_nccl_ep: + nccl_home = build_nccl_ep_submodule() + cmake_flags.append(f"-DNCCL_INCLUDE_DIR={nccl_home}/include") + else: + cmake_flags.append("-DNVTE_WITH_NCCL_EP=OFF") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: @@ -128,6 +146,104 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def _discover_nccl_home() -> str: + """Resolve NCCL_HOME: honor env var, else probe well-known prefixes, else ldconfig.""" + env_home = os.environ.get("NCCL_HOME") + if env_home: + if (Path(env_home) / "include" / "nccl.h").exists(): + return env_home + print( + f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but " + f"'{env_home}/include/nccl.h' was not found; falling back to system probes." + ) + + lib_names = ("libnccl.so", "libnccl.so.2") + # Include Debian/Ubuntu multiarch subdirs (e.g. lib/aarch64-linux-gnu). + lib_subdirs = ("lib", "lib64", "lib/aarch64-linux-gnu", "lib/x86_64-linux-gnu") + for cand in ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr"): + p = Path(cand) + if (p / "include" / "nccl.h").exists() and any( + (p / sub / name).exists() for sub in lib_subdirs for name in lib_names + ): + return str(p) + + try: + out = subprocess.check_output(["ldconfig", "-p"], stderr=subprocess.DEVNULL).decode() + for line in out.splitlines(): + if "libnccl.so" in line and "=>" in line: + lib_path = Path(line.split("=>")[-1].strip()) + # Walk upward so multiarch layouts (.../lib//libnccl.so) + # resolve to the prefix that contains include/nccl.h. + for root in (lib_path.parent.parent, lib_path.parent.parent.parent): + if (root / "include" / "nccl.h").exists(): + return str(root) + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + raise RuntimeError( + "Could not locate NCCL core (nccl.h + libnccl.so). Set NCCL_HOME to the install prefix." + ) + + +def build_nccl_ep_submodule() -> str: + """Build libnccl_ep.so from the 3rdparty/nccl submodule. + + Returns the discovered NCCL core install prefix (the path that contains + include/nccl.h and lib/libnccl.so), which the caller passes to CMake as + NCCL_INCLUDE_DIR for TE's own NCCL link. + """ + nccl_root = current_file_path / "3rdparty" / "nccl" + if not (nccl_root / "Makefile").exists(): + raise RuntimeError( + f"NCCL submodule not found at {nccl_root}. " + "Run `git submodule update --init --recursive`." + ) + + build_dir = nccl_root / "build" + nccl_ep_lib = build_dir / "lib" / "libnccl_ep.so" + + # Caller gates on arch >= 90 or "native"; let nvcc resolve "native". + arch_tokens = [a.strip() for a in str(cuda_archs() or "").split(";") if a.strip()] + if any(t.lower() == "native" for t in arch_tokens): + gencode = "-arch=native" + else: + arch_list = [ + t.rstrip("af") + for t in arch_tokens + if t.rstrip("af").isdigit() and int(t.rstrip("af")) >= 90 + ] + gencode = " ".join(f"-gencode=arch=compute_{a},code=sm_{a}" for a in arch_list) + + nproc = os.cpu_count() or 8 + env = os.environ.copy() + env["NVCC_GENCODE"] = gencode + # NCCL EP needs the core NCCL headers + libnccl.so; write NCCL EP build + # outputs to the submodule's local build/ tree. + nccl_home = _discover_nccl_home() + env["NCCL_HOME"] = nccl_home + env["NCCL_EP_BUILDDIR"] = str(build_dir) + + if not nccl_ep_lib.exists(): + print(f"[NCCL EP] Building libnccl_ep.so (gencode='{gencode}')") + subprocess.check_call( + ["make", "-j", str(nproc), "-C", "contrib/nccl_ep", "lib"], + cwd=str(nccl_root), + env=env, + ) + + # Stage libnccl_ep.so.0 alongside libtransformer_engine.so so $ORIGIN-rpath + # finds it in the installed wheel. + soname = "libnccl_ep.so.0" + src = (build_dir / "lib" / soname).resolve() + dst = current_file_path / "transformer_engine" / soname + if dst.is_symlink() or dst.exists(): + dst.unlink() + shutil.copy2(src, dst) + print(f"[NCCL EP] Bundled {dst} ({src.stat().st_size // (1 << 20)} MB)") + + return nccl_home + + def git_check_submodules() -> None: """ Attempt to checkout git submodules automatically during setup. @@ -208,7 +324,8 @@ def git_check_submodules() -> None: else: install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] - package_data = {"": ["VERSION.txt"]} + # libnccl_ep.so.0 is staged by build_nccl_ep_submodule(); ship it. + package_data = {"": ["VERSION.txt"], "transformer_engine": ["libnccl_ep.so*"]} include_package_data = True extras_require = {"test": test_requires} diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 44ad7c7384..e65c298e15 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -55,9 +55,18 @@ target_include_directories(test_comm_gemm PRIVATE ${test_comm_gemm_INCLUDES}) find_package(CUDAToolkit REQUIRED) find_package(OpenMP REQUIRED) find_package(MPI REQUIRED) + +# -- NCCL core ---------------------------------------------------------------- +# Pass -DNCCL_INCLUDE_DIR=/include; falls back to well-known prefixes. +find_path(NCCL_INCLUDE_DIR nccl.h + HINTS /opt/nvidia/nccl/include /usr/local/nccl/include) +if(NOT NCCL_INCLUDE_DIR) + message(FATAL_ERROR + "nccl.h not found. Pass -DNCCL_INCLUDE_DIR=/include.") +endif() find_library(NCCL_LIB NAMES nccl libnccl - PATH_SUFFIXES lib + PATH_SUFFIXES lib lib64 REQUIRED) list(APPEND test_comm_gemm_LINKER_LIBS CUDA::cuda_driver @@ -74,3 +83,37 @@ target_compile_options(test_comm_gemm PRIVATE -O2 -fopenmp) include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) + +# -- EP distributed tests ------------------------------------------------------ +# Launched via mpirun; ncclUniqueId exchange uses MPI_Bcast (see test_ep_common.h). +# The test binary only uses NCCL core symbols (ncclMemAlloc, ncclCommWindow*); +# all ncclEp* calls live behind TE's public , which +# resolves libnccl_ep.so via dlopen in libtransformer_engine.so itself. +message(STATUS "EP test: NCCL headers: ${NCCL_INCLUDE_DIR}") +set(EP_TEST_COMMON_INCLUDES + ${NCCL_INCLUDE_DIR} + ${MPI_CXX_INCLUDE_PATH} + ../../transformer_engine/common/include + ../../transformer_engine/common + ${CMAKE_CURRENT_SOURCE_DIR}) + +# nvrtc must follow TE_LIB so symbols referenced from libtransformer_engine.so +# (loaded via dlopen in Python; not in its DT_NEEDED) resolve through nvrtc. +set(EP_TEST_COMMON_LIBS + CUDA::cuda_driver + CUDA::cudart + GTest::gtest + ${TE_LIB} + CUDA::nvrtc + ${NCCL_LIB} + MPI::MPI_CXX + OpenMP::OpenMP_CXX) + +# -- EP distributed tests (per-op + full pipeline + zero-copy symm) ----------- +add_executable(test_ep test_ep.cu ../cpp/test_common.cu) +target_include_directories(test_ep PRIVATE ${EP_TEST_COMMON_INCLUDES}) +target_link_libraries(test_ep PUBLIC ${EP_TEST_COMMON_LIBS}) + +# Do NOT use gtest_discover_tests - these binaries require multi-process +# launch via run_test_ep.sh, not direct single-process execution. +message(STATUS "EP distributed tests enabled (TE backend dlopens libnccl_ep.so)") diff --git a/tests/cpp_distributed/run_test_ep.sh b/tests/cpp_distributed/run_test_ep.sh new file mode 100755 index 0000000000..1c4432531c --- /dev/null +++ b/tests/cpp_distributed/run_test_ep.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Run TE EP distributed unit tests via mpirun. Each MPI rank pins to one GPU +# (rank % device_count) and exchanges ncclUniqueId through MPI_Bcast. +# +# Usage: +# bash run_test_ep.sh [num_gpus] [build_dir] +# +# Defaults: +# num_gpus = number of GPUs visible to nvidia-smi +# build_dir = /build +# +# Environment variables: +# GTEST_FILTER - forwarded to all processes (e.g., "EPPipelineTest.*") +# MPIRUN - override the mpirun binary (default: mpirun) +# MPIRUN_EXTRA - extra flags forwarded to mpirun + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${2:-${SCRIPT_DIR}/build}" +NUM_GPUS="${1:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +MPIRUN="${MPIRUN:-mpirun}" + +# Skip cleanly on pre-Hopper: NCCL EP requires SM>=90. +MIN_SM=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | awk -F. 'NR==1 || ($1*10+$2) 0 && MIN_SM < 90 )); then + echo "NCCL EP requires SM>=90 (lowest visible GPU is SM${MIN_SM}); SKIPPING." + exit 0 +fi + +TEST_BIN="${BUILD_DIR}/test_ep" +if [[ ! -x "${TEST_BIN}" ]]; then + echo "ERROR: binary not found: ${TEST_BIN}" + echo "Build: cd ${SCRIPT_DIR} && mkdir -p build && cd build && cmake .. && make" + exit 1 +fi + +if (( NUM_GPUS < 2 )); then + echo "EP Tests: requires at least 2 GPUs, found ${NUM_GPUS}. Skipping." + exit 0 +fi + +GTEST_ARGS="${GTEST_FILTER:+--gtest_filter=${GTEST_FILTER}}" + +echo "=== EP Tests ===" +echo " GPUs: ${NUM_GPUS} Binary: ${TEST_BIN}" +echo + +"${MPIRUN}" -n "${NUM_GPUS}" ${MPIRUN_EXTRA:-} "${TEST_BIN}" ${GTEST_ARGS} diff --git a/tests/cpp_distributed/test_ep.cu b/tests/cpp_distributed/test_ep.cu new file mode 100644 index 0000000000..1a67644d06 --- /dev/null +++ b/tests/cpp_distributed/test_ep.cu @@ -0,0 +1,808 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * EP pipeline tests: smallest-scope first. + * + * EPDispatchTest/PrepareAndDispatch : exact recv values + per-expert counts + * EPCombineTest/Combine : round-trip: out == top_k * tokens + * EPCombineBwdTest/CombineBwdCheck : exact grad_expert values + * EPDispatchBwdTest/DispatchBwdCheck : exact grad_tokens + * EPDispatchBwdGradWeightsTest/RoundTrip : exact per-(t, k) grad_topk_weights + * EPPipelineTest/FullForwardBackward : fwd + bwd NaN/Inf check + * + * Routing: token t on rank r -> expert (r * num_local_experts + t * top_k + k) % num_experts + * Token values: rank r, token t -> all hidden dims = (r+1)*0.01 + t*0.001 + * + * Closed-form expected values: + * dispatch recv: multiset of source-token values routed to this rank's experts + * combine: result[t] == top_k * tokens[t] + * combine_bwd: grad_expert[slot] == d_result[t] (no weighting) + * dispatch_bwd: grad_tokens[t] == top_k * d_result[t] + */ + +#include "test_ep_common.h" + +#include +#include +#include +#include + +// -- Deterministic routing helpers --------------------------------------------- + +// Token value for (rank, t): (rank * num_tokens + t + 1) / 256. Step 1/256 is +// bf16-exact and unique across (rank, t) when rank * num_tokens + t < 256. +static inline float token_value(int rank, int t, int num_tokens) { + return static_cast(rank * num_tokens + t + 1) * (1.0f / 256.0f); +} + +// Per-element host-side conversion helpers used by templated test code. +inline float tok_to_float(nv_bfloat16 v) { return __bfloat162float(v); } +inline float tok_to_float(__half v) { return __half2float(v); } +inline float tok_to_float(float v) { return v; } + +template T tok_from_float(float v); +template <> inline nv_bfloat16 tok_from_float(float v) { return __float2bfloat16(v); } +template <> inline __half tok_from_float<__half> (float v) { return __float2half(v); } +template <> inline float tok_from_float (float v) { return v; } + +template +static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) { + std::vector v(num_tokens * hidden_dim); + for (int t = 0; t < num_tokens; ++t) { + T val = tok_from_float(token_value(rank, t, num_tokens)); + for (int h = 0; h < hidden_dim; ++h) + v[t * hidden_dim + h] = val; + } + return v; +} + +static std::vector expected_token_counts( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector cnt(num_local_experts, 0); + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) ++cnt[e - base]; + } + } + return cnt; +} + +static std::vector expected_recv_values_sorted( + int recv_rank, int num_processes, int num_tokens, int top_k, + int num_experts, int num_local_experts) { + int base = recv_rank * num_local_experts; + std::vector vals; + for (int src = 0; src < num_processes; ++src) { + auto idx = routing_balanced(src, num_tokens, top_k, num_experts, num_local_experts); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) { + int64_t e = idx[t * top_k + k]; + if (e >= base && e < base + num_local_experts) { + float raw = token_value(src, t, num_tokens); + vals.push_back(__bfloat162float(__float2bfloat16(raw))); + } + } + } + std::sort(vals.begin(), vals.end()); + return vals; +} + +// 2^-5 relative tolerance for BF16 (matches mantissa precision with margin), +// plus a small atol floor for near-zero expected values. +static constexpr float kBf16Rtol = 1.0f / 32.0f; +static constexpr float kBf16Atol = 1e-3f; +static float bf16_tol(float magnitude) { + return kBf16Atol + kBf16Rtol * std::fabs(magnitude); +} + +template +static bool check_no_nan_inf(const T* dev, int count, const char* name) { + std::vector h(count); + cudaMemcpy(h.data(), dev, count * sizeof(T), cudaMemcpyDeviceToHost); + for (int i = 0; i < count; ++i) { + float v = tok_to_float(h[i]); + if (std::isnan(v) || std::isinf(v)) { + fprintf(stderr, "Rank %d: %s in %s[%d]\n", + g_process_id, std::isnan(v) ? "NaN" : "Inf", name, i); + return false; + } + } + return true; +} + +// -- Forward buffer set with RAII ---------------------------------------------- + +template +struct EPBuffers { + // Forward + DevBuf topk_idx; + DevBuf topk_weights; + DevBuf tokens; + DevBuf token_counts; + DevBuf handle_mem; + DevBuf recv_tokens; + DevBuf recv_topk_weights; + DevBuf result; + // Backward + DevBuf grad_result; + DevBuf grad_expert; + DevBuf grad_tokens; + DevBuf g_recv_topk_weights; + DevBuf grad_topk_weights; + + size_t handle_mem_size = 0; + size_t recv_capacity = 0; + int top_k_ = 0; + size_t alignment_ = 0; + + void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, + int ep_size, int max_tokens_per_rank, size_t alignment = 0) { + top_k_ = top_k; + alignment_ = alignment; + recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; + + topk_idx.alloc(num_tokens * top_k); + topk_weights.alloc(num_tokens * top_k); + tokens.alloc(num_tokens * hidden_dim); + token_counts.alloc(num_local_experts); + recv_tokens.alloc(recv_capacity * hidden_dim); + recv_topk_weights.alloc(recv_capacity); + result.alloc(num_tokens * hidden_dim); + + handle_mem_size = nvte_ep_handle_mem_size(NVTEEpLayerConfig{top_k, alignment}); + handle_mem.alloc(handle_mem_size); + + grad_result.alloc(num_tokens * hidden_dim); + grad_expert.alloc(recv_capacity * hidden_dim); + grad_tokens.alloc(num_tokens * hidden_dim); + g_recv_topk_weights.alloc(recv_capacity); + grad_topk_weights.alloc(num_tokens * top_k); + } +}; + +// Bundled NVTETensor views over an EPBuffers, with the shapes the EP C API +// expects. +template +struct EPTensors { + TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorWrapper recv_tokens, recv_topk_weights, result; + TensorWrapper grad_result, grad_expert, grad_tokens; + TensorWrapper g_recv_topk_weights, grad_topk_weights; + + int top_k_ = 0; + size_t alignment_ = 0; + + EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, + int num_local_experts) { + top_k_ = top_k; + alignment_ = b.alignment_; + constexpr DType kTokDType = test::TypeInfo::dtype; + using Shape = std::vector; + topk_idx = TensorWrapper(b.topk_idx.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64); + topk_weights = TensorWrapper(b.topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); + token_counts = TensorWrapper(b.token_counts.get(), + Shape{(size_t)num_local_experts}, DType::kInt32); + handle_mem = TensorWrapper(b.handle_mem.get(), + Shape{b.handle_mem_size}, DType::kByte); + tokens = TensorWrapper(b.tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + recv_tokens = TensorWrapper(b.recv_tokens.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + recv_topk_weights = TensorWrapper(b.recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + result = TensorWrapper(b.result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_result = TensorWrapper(b.grad_result.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + grad_expert = TensorWrapper(b.grad_expert.get(), + Shape{b.recv_capacity, (size_t)hidden_dim}, kTokDType); + grad_tokens = TensorWrapper(b.grad_tokens.get(), + Shape{(size_t)num_tokens, (size_t)hidden_dim}, kTokDType); + g_recv_topk_weights = TensorWrapper(b.g_recv_topk_weights.get(), + Shape{b.recv_capacity}, DType::kFloat32); + grad_topk_weights = TensorWrapper(b.grad_topk_weights.get(), + Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); + } +}; + +// -- Shared fixture base ------------------------------------------------------- + +class EpOpTestBase : public ::testing::Test { + protected: + int ep_size_, num_experts_, num_local_experts_, hidden_dim_; + int max_tokens_per_rank_, top_k_, num_tokens_; + + void SetUp() override { + if (g_sm_major < 9) + GTEST_SKIP() << "EP requires SM_90+ (device is SM_" << g_sm_major << "0)"; + ASSERT_GE(g_num_processes, 2); + ASSERT_TRUE(g_ep_initialized); + + ep_size_ = g_ep_size; + num_experts_ = g_num_experts; + num_local_experts_ = num_experts_ / ep_size_; + hidden_dim_ = g_hidden_dim; + max_tokens_per_rank_ = g_max_tokens_per_rank; + top_k_ = 2; + num_tokens_ = 32; + } + + template + void upload_inputs(EPBuffers& buf, int rank = -1) { + if (rank < 0) rank = g_process_id; + auto h_idx = routing_balanced(rank, num_tokens_, top_k_, + num_experts_, num_local_experts_); + std::vector h_w(num_tokens_ * top_k_, 1.0f / top_k_); + auto h_tok = generate_tokens(rank, num_tokens_, hidden_dim_); + + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_idx.get(), h_idx.data(), + h_idx.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(buf.tokens.get(), h_tok.data(), + h_tok.size() * sizeof(T), cudaMemcpyHostToDevice)); + } + + // NVTE_CHECK_CUDA (fprintf+exit) so this non-void helper stays legal. + template + int read_total_recv(const EPBuffers& buf) const { + std::vector cnt(num_local_experts_); + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + int total = 0; + for (int c : cnt) total += c; + return total; + } +}; + +// ============================================================================= +// EPDispatchTest: exact recv values and per-expert counts. +// ============================================================================= + +class EPDispatchTest : public EpOpTestBase {}; + +TEST_F(EPDispatchTest, PrepareAndDispatch) { + EPBuffers<> buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + NVTE_CHECK_CUDA(cudaMemset(buf.recv_tokens.get(), 0, buf.recv_tokens.bytes())); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + // 1. Per-expert counts. + std::vector got_counts(num_local_experts_); + NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + num_experts_, num_local_experts_); + int total_recv = 0; + for (int i = 0; i < num_local_experts_; ++i) { + EXPECT_EQ(got_counts[i], exp_counts[i]) << "local expert " << i; + total_recv += exp_counts[i]; + } + ASSERT_LE(total_recv, static_cast(buf.recv_capacity)) + << "total_recv exceeded recv_capacity; overflow would corrupt downstream memory"; + + // 2. Recv values: read only the filled prefix per local-expert zone, not the + // whole recv buffer; avoids false positives from legitimate-zero token values. + std::vector h_recv(buf.recv_capacity * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_recv.data(), buf.recv_tokens.get(), + h_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + std::vector got_vals; + got_vals.reserve(total_recv); + size_t slot = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < got_counts[e]; ++i) { + got_vals.push_back(__bfloat162float(h_recv[slot * hidden_dim_])); + ++slot; + } + } + std::sort(got_vals.begin(), got_vals.end()); + + auto exp_vals = expected_recv_values_sorted(g_process_id, g_num_processes, num_tokens_, + top_k_, num_experts_, num_local_experts_); + + ASSERT_EQ(got_vals.size(), exp_vals.size()); + for (size_t i = 0; i < exp_vals.size(); ++i) + EXPECT_NEAR(got_vals[i], exp_vals[i], bf16_tol(exp_vals[i])) + << "recv value mismatch at sorted index " << i; + + // 3. recv_topk_weights: every filled slot must equal the per-token weight (1/top_k). + std::vector h_w(buf.recv_capacity); + NVTE_CHECK_CUDA(cudaMemcpy(h_w.data(), buf.recv_topk_weights.get(), + h_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + const float exp_w = 1.0f / static_cast(top_k_); + for (int i = 0; i < total_recv; ++i) + EXPECT_NEAR(h_w[i], exp_w, 1e-6f) << "recv_topk_weights[" << i << "]"; + + if (g_process_id == 0) + printf(" PrepareAndDispatch: passed (recv=%d, values + weights exact)\n", total_recv); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineTest: round-trip identity expert -> result == top_k * tokens. +// ============================================================================= + +class EPCombineTest : public EpOpTestBase {}; + +TEST_F(EPCombineTest, Combine) { + EPBuffers<> buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_result(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_result.data(), buf.result.get(), + h_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + for (int tok = 0; tok < num_tokens_; ++tok) { + float exp = __bfloat162float(h_tok[tok * hidden_dim_]) * static_cast(top_k_); + for (int p = 0; p < hidden_dim_; ++p) { + float got = __bfloat162float(h_result[tok * hidden_dim_ + p]); + EXPECT_NEAR(got, exp, bf16_tol(exp)) + << "token " << tok << " rank " << g_process_id << " hidden " << p; + } + } + + if (g_process_id == 0) + printf(" Combine: passed (result == top_k * tokens)\n"); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPCombineBwdTest: filled slots in grad_expert == d_result (unweighted). +// ============================================================================= + +class EPCombineBwdTest : public EpOpTestBase {}; + +TEST_F(EPCombineBwdTest, CombineBwdCheck) { + EPBuffers<> buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad_r(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad_r.data(), + h_grad_r.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(t.handle_mem.data(), t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + int total_recv = read_total_recv(buf); + + std::vector cnt(num_local_experts_); + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); + std::vector h_ge(buf.recv_capacity * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), + h_ge.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Walk filled slots by per-expert zone (no v != 0 heuristic). + const float kExpGrad = 0.1f; + size_t slot = 0; + int filled = 0; + for (int e = 0; e < num_local_experts_; ++e) { + for (int i = 0; i < cnt[e]; ++i) { + for (int p = 0; p < hidden_dim_; ++p) { + float v = __bfloat162float(h_ge[slot * hidden_dim_ + p]); + EXPECT_NEAR(v, kExpGrad, bf16_tol(kExpGrad)) + << "grad_expert expert " << e << " slot " << i + << " (linear " << slot << ") hidden " << p; + } + ++filled; ++slot; + } + } + EXPECT_EQ(filled, total_recv); + + if (g_process_id == 0) + printf(" CombineBwdCheck: passed (filled=%d)\n", filled); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdTest: grad_tokens == top_k * d_result. +// ============================================================================= + +class EPDispatchBwdTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdTest, DispatchBwdCheck) { + EPBuffers<> buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, __float2bfloat16(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(nv_bfloat16), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(t.handle_mem.data(), t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(t.handle_mem.data(), t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_gt(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(h_gt.data(), buf.grad_tokens.get(), + h_gt.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + const float kExpGrad = static_cast(top_k_) * 0.1f; + for (int tok = 0; tok < num_tokens_; ++tok) + for (int p = 0; p < hidden_dim_; ++p) + EXPECT_NEAR(__bfloat162float(h_gt[tok * hidden_dim_ + p]), kExpGrad, + bf16_tol(kExpGrad)) + << "grad_tokens token " << tok << " hidden " << p; + + if (g_process_id == 0) + printf(" DispatchBwdCheck: passed (grad_tokens == %.2f)\n", kExpGrad); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// EPDispatchBwdGradWeightsTest: round-trip per-(t, k) weights. +// ============================================================================= + +class EPDispatchBwdGradWeightsTest : public EpOpTestBase {}; + +TEST_F(EPDispatchBwdGradWeightsTest, RoundTrip) { + EPBuffers<> buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors<> t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + // Distinct per-(rank, t, k) weights so each slot carries a unique value. + std::vector h_w(num_tokens_ * top_k_); + for (int tok = 0; tok < num_tokens_; ++tok) + for (int k = 0; k < top_k_; ++k) + h_w[tok * top_k_ + k] = 0.1f + 0.01f * tok + 0.001f * k + + 0.0001f * (g_process_id + 1); + NVTE_CHECK_CUDA(cudaMemcpy(buf.topk_weights.get(), h_w.data(), + h_w.size() * sizeof(float), cudaMemcpyHostToDevice)); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, + buf.recv_topk_weights.bytes(), stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + + // Sentinel: NaN so any (t, k) the bwd kernel fails to write is immediately visible. + std::vector h_nan(num_tokens_ * top_k_, + std::numeric_limits::quiet_NaN()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_topk_weights.get(), h_nan.data(), + h_nan.size() * sizeof(float), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + + // g_recv_topk_weights := recv_topk_weights (the round-trip input). + auto g_recv_t = TensorWrapper(buf.recv_topk_weights.get(), + std::vector{buf.recv_capacity}, DType::kFloat32); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(t.handle_mem.data(), t.grad_expert.data(), + NVTECommWindow{}, g_recv_t.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector h_grad_w(num_tokens_ * top_k_); + NVTE_CHECK_CUDA(cudaMemcpy(h_grad_w.data(), buf.grad_topk_weights.get(), + h_grad_w.size() * sizeof(float), cudaMemcpyDeviceToHost)); + + const float kTol = 1e-5f; + int errs = 0, k0_eq_k1 = 0; + for (int tok = 0; tok < num_tokens_; ++tok) { + for (int k = 0; k < top_k_; ++k) { + float got = h_grad_w[tok * top_k_ + k]; + float exp = h_w[tok * top_k_ + k]; + if (std::isnan(got) || std::fabs(got - exp) > kTol) { + if (errs < 8) + fprintf(stderr, "Rank %d: grad_topk_weights[%d, %d]: got %.6f, expected %.6f\n", + g_process_id, tok, k, got, exp); + ++errs; + } + } + if (top_k_ >= 2 && + std::fabs(h_grad_w[tok * top_k_ + 0] - h_grad_w[tok * top_k_ + 1]) < 1e-7f) + ++k0_eq_k1; + } + EXPECT_EQ(errs, 0); + EXPECT_EQ(k0_eq_k1, 0) << "per-token-average regression: grad[t, 0] == grad[t, 1]"; + + if (g_process_id == 0 && errs == 0 && k0_eq_k1 == 0) + printf(" RoundTrip: passed (%d (t, k) gradients)\n", num_tokens_ * top_k_); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + +// ============================================================================= +// Integrated FwdBwd: NaN/Inf check end-to-end. +// ============================================================================= + +class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface { + protected: + template + void run_full_forward_backward() { + EPBuffers buf; + buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(buf); + EPTensors t(buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), + t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), + NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, + t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(t.handle_mem.data(), t.recv_tokens.data(), NVTECommWindow{}, + t.result.data(), stream)); + + std::vector h_grad(num_tokens_ * hidden_dim_, tok_from_float(0.1f)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(buf.grad_result.get(), h_grad.data(), + h_grad.size() * sizeof(Tok), + cudaMemcpyHostToDevice, stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_expert.get(), 0, buf.grad_expert.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.g_recv_topk_weights.get(), 0, buf.g_recv_topk_weights.bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(buf.grad_topk_weights.get(), 0, buf.grad_topk_weights.bytes(), stream)); + + ASSERT_NO_THROW(nvte_ep_combine_bwd(t.handle_mem.data(), t.grad_result.data(), NVTECommWindow{}, + t.grad_expert.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch_bwd(t.handle_mem.data(), t.grad_expert.data(), NVTECommWindow{}, + t.g_recv_topk_weights.data(), NVTECommWindow{}, + t.grad_tokens.data(), t.grad_topk_weights.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + ASSERT_TRUE(check_no_nan_inf(buf.result.get(), num_tokens_ * hidden_dim_, "result")); + ASSERT_TRUE(check_no_nan_inf(buf.grad_tokens.get(), num_tokens_ * hidden_dim_, "grad_tokens")); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + } +}; + +TEST_P(EPPipelineTest, FullForwardBackward) { + const DType dtype = GetParam(); + // NCCL EP backend currently asserts ncclBfloat16 in ncclEpDispatch + // (contrib/nccl_ep/nccl_ep.cc); skip FP16/FP32 until the backend supports them. + if (dtype != DType::kBFloat16) { + GTEST_SKIP() << test::typeName(dtype) << " not yet supported by NCCL EP backend"; + } + switch (dtype) { + case DType::kBFloat16: run_full_forward_backward(); break; + case DType::kFloat16: run_full_forward_backward<__half> (); break; + case DType::kFloat32: run_full_forward_backward (); break; + default: FAIL() << "unsupported token dtype " << static_cast(dtype); + } + if (g_process_id == 0) + printf(" FullForwardBackward[%s]: passed\n", test::typeName(dtype).c_str()); +} + +INSTANTIATE_TEST_SUITE_P( + Dtypes, EPPipelineTest, + ::testing::Values(DType::kBFloat16, DType::kFloat16, DType::kFloat32), + [](const ::testing::TestParamInfo& info) { + return test::typeName(info.param); + }); + +// ============================================================================= +// EPZeroCopyTest: dispatch/combine with NCCL symmetric-memory windows attached +// to payload tensors (zero-copy fast path via ncclEpTensorCreateFromWindow). +// Symm-mem requirements per spec: input&output of Dispatch, input of Combine, +// input&output of Combine bwd, input of Dispatch bwd. +// ============================================================================= + +namespace { + +// Caller-owned ncclMemAlloc'd buffer with a registered symmetric window. +// Frees in destructor (deregister + ncclMemFree). Non-copyable, move-only. +struct SymmBuf { + void* ptr = nullptr; + size_t bytes = 0; + ncclWindow_t win = nullptr; + + SymmBuf() = default; + SymmBuf(const SymmBuf&) = delete; + SymmBuf& operator=(const SymmBuf&) = delete; + SymmBuf(SymmBuf&& o) noexcept : ptr(o.ptr), bytes(o.bytes), win(o.win) { + o.ptr = nullptr; o.win = nullptr; o.bytes = 0; + } + ~SymmBuf() { + if (win) ncclCommWindowDeregister(g_ep_comm, win); + if (ptr) ncclMemFree(ptr); + } + + void alloc(size_t n_bytes) { + bytes = n_bytes; + NVTE_CHECK_NCCL(ncclMemAlloc(&ptr, bytes)); + NVTE_CHECK_CUDA(cudaMemset(ptr, 0, bytes)); + NVTE_CHECK_NCCL(ncclCommWindowRegister(g_ep_comm, ptr, bytes, &win, + NCCL_WIN_COLL_SYMMETRIC)); + } +}; + +// Build an NVTECommWindow descriptor pointing at a SymmBuf's window (offset 0). +static inline NVTECommWindow symm_window(const SymmBuf& b) { + return NVTECommWindow{b.win, /*offset=*/0}; +} + +} // namespace + +// Tests rebootstrap the backend to zero_copy=ON for the symm phase via +// ep_reinitialize(); TearDown restores OFF for the rest of the suite. +class EPZeroCopyTest : public EpOpTestBase { + protected: + void TearDown() override { + if (g_ep_initialized) ep_reinitialize(/*zero_copy=*/0); + } +}; + +// Identity round-trip with symm-mem on dispatch i/o + combine input. Bit-exact +// vs HBM reference (same routing, same input). +TEST_F(EPZeroCopyTest, IdentityAllSymm) { + // HBM reference run. + EPBuffers<> ref_buf; + ref_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(ref_buf); + EPTensors<> ref_t(ref_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.token_counts.data(), NVTEEpLayerConfig{ref_t.top_k_, ref_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(ref_t.handle_mem.data(), ref_t.topk_idx.data(), + ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), + NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(ref_t.handle_mem.data(), ref_t.recv_tokens.data(), NVTECommWindow{}, + ref_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector ref_recv(ref_buf.recv_capacity * hidden_dim_); + std::vector ref_result(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(ref_recv.data(), ref_buf.recv_tokens.get(), + ref_recv.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(ref_result.data(), ref_buf.result.get(), + ref_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Switch backend to zero_copy=ON for the symm phase. + ep_reinitialize(/*zero_copy=*/1); + + // Symm-mem run: tokens, recv_tokens, combine_input (== recv_tokens) all symm. + EPBuffers<> sym_buf; // alloc all buffers except the symm ones. + sym_buf.alloc(num_tokens_, top_k_, hidden_dim_, num_local_experts_, + ep_size_, max_tokens_per_rank_); + upload_inputs(sym_buf); + + SymmBuf sym_tokens, sym_recv; + sym_tokens.alloc(num_tokens_ * hidden_dim_ * sizeof(nv_bfloat16)); + sym_recv .alloc(sym_buf.recv_capacity * hidden_dim_ * sizeof(nv_bfloat16)); + + // Stage same tokens into the symm-mem input. + auto h_tok = generate_tokens(g_process_id, num_tokens_, hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(sym_tokens.ptr, h_tok.data(), + h_tok.size() * sizeof(nv_bfloat16), cudaMemcpyHostToDevice)); + + EPTensors<> sym_t(sym_buf, num_tokens_, top_k_, hidden_dim_, num_local_experts_); + // Replace the tokens/recv_tokens views with ones pointing at the symm buffers. + sym_t.tokens = TensorWrapper(sym_tokens.ptr, + std::vector{(size_t)num_tokens_, (size_t)hidden_dim_}, DType::kBFloat16); + sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, + std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, DType::kBFloat16); + + ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.token_counts.data(), NVTEEpLayerConfig{sym_t.top_k_, sym_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_dispatch(sym_t.handle_mem.data(), sym_t.topk_idx.data(), + sym_t.tokens.data(), symm_window(sym_tokens), + sym_t.topk_weights.data(), NVTECommWindow{}, + sym_t.recv_tokens.data(), symm_window(sym_recv), + sym_t.recv_topk_weights.data(), NVTECommWindow{}, stream)); + ASSERT_NO_THROW(nvte_ep_combine(sym_t.handle_mem.data(), sym_t.recv_tokens.data(), + symm_window(sym_recv), sym_t.result.data(), stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + + std::vector sym_recv_host(sym_buf.recv_capacity * hidden_dim_); + std::vector sym_result(num_tokens_ * hidden_dim_); + NVTE_CHECK_CUDA(cudaMemcpy(sym_recv_host.data(), sym_recv.ptr, + sym_recv_host.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(sym_result.data(), sym_buf.result.get(), + sym_result.size() * sizeof(nv_bfloat16), cudaMemcpyDeviceToHost)); + + // Compare per filled recv slot (HBM ref vs symm) and full result. + int total_recv = read_total_recv(sym_buf); + for (int i = 0; i < total_recv * hidden_dim_; ++i) + ASSERT_EQ(__bfloat162float(sym_recv_host[i]), __bfloat162float(ref_recv[i])) + << "recv mismatch at " << i; + for (size_t i = 0; i < sym_result.size(); ++i) + ASSERT_EQ(__bfloat162float(sym_result[i]), __bfloat162float(ref_result[i])) + << "result mismatch at " << i; + + if (g_process_id == 0) + printf(" IdentityAllSymm: passed (recv_slots=%d, bit-exact vs HBM)\n", total_recv); + + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); +} + + +// -- main ---------------------------------------------------------------------- + +int main(int argc, char* argv[]) { + if (!ep_bootstrap(argc, argv)) return 0; + int ret = RUN_ALL_TESTS(); + ep_teardown(); + return ret; +} diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h new file mode 100644 index 0000000000..b2421ffd10 --- /dev/null +++ b/tests/cpp_distributed/test_ep_common.h @@ -0,0 +1,200 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* + * Shared TE EP test infrastructure. Include once per TU; ep_bootstrap() in + * each test binary's main() populates process-level globals. + * Defaults: 4 experts/rank, hidden_dim=256, max_tokens_per_rank=64. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include "../cpp/test_common.h" +#include "util/logging.h" + +using transformer_engine::DType; +using transformer_engine::TensorWrapper; + +#define CHECK_MPI(expr) \ + do { \ + int _err_mpi = (expr); \ + NVTE_CHECK(_err_mpi == MPI_SUCCESS, "MPI error: ", _err_mpi); \ + } while (false) + +// -- Process-level state ------------------------------------------------------- + +static int g_process_id = -1; +static int g_num_processes = -1; + +static int g_sm_major = -1; // set by ep_bootstrap; -1 until then +static int g_ep_size = -1; +static int g_num_experts = -1; +static int g_hidden_dim = 256; +static int g_max_tokens_per_rank = 64; +static NVTEDType g_max_token_dtype = kNVTEFloat32; // staging-buffer sizing +static bool g_ep_initialized = false; +static ncclComm_t g_ep_comm = nullptr; // owned by harness, destroyed in ep_teardown + +// RAII owner for a cudaMalloc'd device buffer; element-count API on top of +// test::CudaPtr. +template +struct DevBuf { + test::CudaPtr ptr; + size_t count = 0; + + DevBuf() = default; + explicit DevBuf(size_t n) { alloc(n); } + + void alloc(size_t n) { + count = n; + ptr = (n > 0) ? test::cuda_alloc(n * sizeof(T)) : test::CudaPtr{}; + } + void reset() { + ptr.reset(); + count = 0; + } + + T* get() const { return ptr.get(); } + size_t bytes() const { return count * sizeof(T); } +}; + +// -- Shared routing helper ----------------------------------------------------- + +// Balanced round-robin routing: token t on rank r maps top_k experts to +// (r * num_local_experts + t * top_k + k) % num_experts +static inline std::vector routing_balanced( + int rank, int num_tokens, int top_k, int num_experts, int num_local_experts) { + std::vector idx(num_tokens * top_k); + for (int t = 0; t < num_tokens; ++t) + for (int k = 0; k < top_k; ++k) + idx[t * top_k + k] = (rank * num_local_experts + t * top_k + k) % num_experts; + return idx; +} + +// -- ncclUniqueId exchange via MPI --------------------------------------------- + +static void exchange_unique_id(ncclUniqueId* uid) { + if (g_process_id == 0) NVTE_CHECK_NCCL(ncclGetUniqueId(uid)); + CHECK_MPI(MPI_Bcast(uid, sizeof(*uid), MPI_BYTE, 0, MPI_COMM_WORLD)); +} + +// -- CLI parsing --------------------------------------------------------------- + +static void ep_parse_args(int argc, char* argv[]) { + for (int i = 1; i < argc; ++i) { + std::string a(argv[i]); + if (a.rfind("--max-token-dtype=", 0) == 0) + g_max_token_dtype = static_cast(std::stoi(a.substr(18))); + } +} + +// -- Bootstrap / teardown ------------------------------------------------------ + +// Returns false if the binary should exit without running tests (wrong SM, etc.). +static bool ep_bootstrap(int argc, char* argv[]) { + int mpi_initialized = 0; + MPI_Initialized(&mpi_initialized); + if (!mpi_initialized) CHECK_MPI(MPI_Init(&argc, &argv)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &g_process_id)); + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &g_num_processes)); + + ep_parse_args(argc, argv); + ::testing::InitGoogleTest(&argc, argv); + + int device_count; + cudaGetDeviceCount(&device_count); + cudaSetDevice(g_process_id % device_count); + + int device, major; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + g_sm_major = major; + if (major < 9) { + if (g_process_id == 0) + printf("SKIP: EP requires SM_90+ (device is SM_%d0)\n", major); + return false; + } + if (g_num_processes < 2) { + if (g_process_id == 0) + printf("SKIP: at least 2 processes required\n"); + return false; + } + + g_ep_size = g_num_processes; + g_num_experts = g_ep_size * 4; // 4 experts per rank + + ncclUniqueId uid{}; + exchange_unique_id(&uid); + + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + // Worst-case for top_k fan-out: ep_size * max_tokens_per_rank * 2. + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + group_config.max_token_dtype = g_max_token_dtype; + + NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); + nvte_ep_initialize(static_cast(g_ep_comm), group_config); + + if (g_process_id == 0) { + printf("EP initialized: ep_size=%d num_experts=%d " + "hidden_dim=%d max_tokens_per_rank=%d\n", + g_ep_size, g_num_experts, g_hidden_dim, g_max_tokens_per_rank); + } + + g_ep_initialized = true; + return true; +} + +// Re-bootstrap the EP backend on the existing g_ep_comm with a new zero_copy +// setting. +static void ep_reinitialize(int zero_copy) { + if (!g_ep_initialized) return; + nvte_ep_shutdown(); + NVTEEpGroupConfig group_config{}; + group_config.ep_size = g_ep_size; + group_config.num_experts = g_num_experts; + group_config.max_tokens_per_rank = g_max_tokens_per_rank; + group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2; + group_config.hidden_dim = g_hidden_dim; + group_config.max_token_dtype = g_max_token_dtype; + group_config.zero_copy = zero_copy; + nvte_ep_initialize(static_cast(g_ep_comm), group_config); +} + +// Tear down in dependency order: backend's ep_group reads from ep_comm, +// so destroy the group first, then the comm. +static void ep_teardown() { + if (g_ep_initialized) { + nvte_ep_shutdown(); + if (g_ep_comm != nullptr) { + ncclCommDestroy(g_ep_comm); + g_ep_comm = nullptr; + } + g_ep_initialized = false; + } + int finalized = 0; + MPI_Finalized(&finalized); + if (!finalized) MPI_Finalize(); +} diff --git a/tests/jax/multi_process_launch_ep.sh b/tests/jax/multi_process_launch_ep.sh new file mode 100755 index 0000000000..a37ffc2952 --- /dev/null +++ b/tests/jax/multi_process_launch_ep.sh @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +#!/bin/bash + +SCRIPT_NAMES="${SCRIPT_NAMES:-test_multi_process_ep.py}" +TEST_TIMEOUT_S="${TEST_TIMEOUT_S:-180}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_graph_min_graph_size=1" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +export PYTHONPATH="${TE_REPO_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +NUM_RUNS=$(nvidia-smi -L | wc -l) + +if [ "${NUM_RUNS}" -lt 4 ]; then + echo "NCCL EP requires at least 4 GPUs (found ${NUM_RUNS}); SKIPPING." + exit 0 +fi +# Default test mesh is (2, 2); use exactly 4 ranks even on larger boxes. +NUM_RUNS="${NVTE_TEST_EP_NUM_RANKS:-4}" + +OVERALL_RET=0 + +for SCRIPT_NAME in $SCRIPT_NAMES; do + echo "=== Running ${SCRIPT_NAME} ===" + for ((i=1; i stdout_rank_${i}.txt 2>&1 & + done + + timeout --foreground --signal=KILL "${TEST_TIMEOUT_S}" \ + python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS 2>&1 | tee stdout_multi_process.txt + + wait + + RET=0 + if grep -q "FAILED" stdout_multi_process.txt; then + RET=1 + fi + # Treat missing test summary on rank 0 as hang/crash rather than silent success. + if ! grep -qE "Ran [0-9]+ test|^OK$|PASSED" stdout_multi_process.txt; then + echo "ERROR: rank 0 produced no test summary for ${SCRIPT_NAME} — likely a hang or early crash." + echo " NCCL EP requires NVLS multicast; check NCCL_DEBUG=INFO output." + RET=1 + fi + if [ "$RET" -ne 0 ]; then + for ((i=1; i/dev/null || echo "(no log)" + done + fi + + rm -f stdout_multi_process.txt stdout_rank_*.txt + if [ "$RET" -ne 0 ]; then + OVERALL_RET=1 + fi +done + +exit "$OVERALL_RET" diff --git a/tests/jax/run_te_ep_moe.sh b/tests/jax/run_te_ep_moe.sh new file mode 100755 index 0000000000..32d5f21956 --- /dev/null +++ b/tests/jax/run_te_ep_moe.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Multiprocess (one-GPU-per-process) launcher for the TE-EP MoE custom_vjp +# test suite. Forks one pytest invocation per visible GPU, passing each +# its own --num-process=N --process-id=i, and waits for all of them. Each +# child calls jax.distributed.initialize(..., local_device_ids=process_id) +# so each Python process only sees its one GPU as a local device and the +# participating processes form a global (ep, fsdp) mesh. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +TEST_FILE="$TE_ROOT/tests/jax/test_te_ep_moe.py" +PYTEST_INI="$TE_ROOT/tests/jax/pytest.ini" + +NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L | wc -l)}" +if [ "$NUM_GPUS" -lt 4 ]; then + echo "[run_te_ep_moe.sh] need >=4 GPUs (got $NUM_GPUS); aborting" >&2 + exit 1 +fi + +export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}" +export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.5}" +export TE_EP_MOE_COORDINATOR_ADDRESS="${TE_EP_MOE_COORDINATOR_ADDRESS:-127.0.0.1:13457}" + +echo "============================================================" +echo "TE-EP MoE MULTIPROCESS test (one process per GPU, ${NUM_GPUS} GPUs)" +echo " test file : $TEST_FILE" +echo " coordinator : $TE_EP_MOE_COORDINATOR_ADDRESS" +echo " XLA_PYTHON_CLIENT_PREALLOCATE: $XLA_PYTHON_CLIENT_PREALLOCATE" +echo " XLA_PYTHON_CLIENT_MEM_FRACTION: $XLA_PYTHON_CLIENT_MEM_FRACTION" +echo "============================================================" + +if [ -n "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then + LOG_DIR="$TE_EP_MOE_MP_LOG_DIR" + mkdir -p "$LOG_DIR" +else + LOG_DIR=$(mktemp -d -t te_ep_moe_mp_XXXXXX) +fi +echo "Per-process logs: $LOG_DIR" + +PIDS=() + +cleanup() { + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -TERM "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -KILL "$pid" 2>/dev/null || true + fi + done +} +trap cleanup EXIT INT TERM + +for i in $(seq 0 $((NUM_GPUS - 1))); do + LOG_FILE="$LOG_DIR/proc_${i}.log" + PYTEST_CMD=( + python3 -m pytest -c "$PYTEST_INI" + "$TEST_FILE" + -p no:typeguard + -v -s + --num-process="$NUM_GPUS" + --process-id="$i" + ) + if [ "$i" -eq 0 ]; then + echo "=== Live output from process 0 ===" + "${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" & + else + "${PYTEST_CMD[@]}" > "$LOG_FILE" 2>&1 & + fi + PIDS+=("$!") +done + +EXITS=() +for pid in "${PIDS[@]}"; do + if wait "$pid"; then + EXITS+=("0") + else + EXITS+=("$?") + fi +done + +echo +echo "============================================================" +echo "Per-process exit codes:" +for i in "${!EXITS[@]}"; do + echo " proc $i -> ${EXITS[$i]}" +done + +# Treat exit 0 (pass) and exit 5 (pytest "no tests collected", which the +# file emits via pytest.skip(allow_module_level=True) on pre-Blackwell +# GPUs) as success. +FAILED=0 +for e in "${EXITS[@]}"; do + if [ "$e" != "0" ] && [ "$e" != "5" ]; then + FAILED=1 + break + fi +done + +echo +if [ "$FAILED" -eq 0 ]; then + echo "[run_te_ep_moe.sh] all processes PASSED" + if [ -z "${TE_EP_MOE_MP_LOG_DIR:-}" ]; then + rm -rf "$LOG_DIR" + fi + exit 0 +fi + +echo "[run_te_ep_moe.sh] at least one process FAILED" +echo " retaining logs at $LOG_DIR for diagnosis" +echo " process 0 tail:" +tail -20 "$LOG_DIR/proc_0.log" 2>/dev/null || true +exit 1 diff --git a/tests/jax/test_multi_process_ep.py b/tests/jax/test_multi_process_ep.py new file mode 100644 index 0000000000..1f986adbe8 --- /dev/null +++ b/tests/jax/test_multi_process_ep.py @@ -0,0 +1,742 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Multi-process unit tests for the TE-JAX Expert Parallelism (EP) primitives. + +Default mesh is (dp=2, ep=2); override via ``NVTE_TEST_EP_MESH=DPxEP``. +Coverage: + + - ``ep_bootstrap`` rejects when ``ep_resource`` is unset. + - Individual primitives (``ep_prepare``, ``ep_dispatch_fwd``, ``ep_combine_fwd``) + round-trip an identity expert → output ≈ tokens. + - ``ep_dispatch`` custom_vjp: ``grad_tokens ≈ TOP_K · tokens`` (closed form). + - ``ep_combine`` custom_vjp: ``max|grad_eo| ≈ eo_const / TOP_K`` (closed form). + - ``ep_dispatch`` custom_vjp: exact per-(t, k) ``grad_topk_weights`` under + skewed upstream gradients (no k-axis averaging). + - HLO reshard guard: compile-only, no XLA collectives outside the EP FFI. + +Launch via tests/jax/multi_process_launch_ep.sh (one process per GPU). +""" + +import os +import sys +import unittest + +import jax +import jax.experimental.multihost_utils as jmu +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.ep import EpLayerConfig, ep_bootstrap, ep_dispatch, ep_combine +from transformer_engine.jax.cpp_extensions.ep import ( + ep_prepare, + ep_dispatch_fwd, + ep_combine_fwd, + get_ep_config, +) + + +# ── Test config ───────────────────────────────────────────────────────────── +# NCCL EP requires NUM_LOCAL_EXPERTS*ep % 4 == 0 (TMA alignment in +# device/hybridep_adapter.cu:511). With NUM_LOCAL_EXPERTS=2, ep must be even. + +NUM_LOCAL_EXPERTS = 2 # per-rank → num_experts = NLE * EP +HIDDEN_DIM = 32 +TOP_K = 2 +TOKENS_PER_DP_SHARD = 4 # per device along dp + + +def _factor_dp_ep(num_procs): + """Default to a (2, 2) mesh. Override via ``NVTE_TEST_EP_MESH=DPxEP``. + + NUM_LOCAL_EXPERTS*ep must be a multiple of 4 for NCCL EP's TMA alignment. + """ + override = os.environ.get("NVTE_TEST_EP_MESH") + if override: + dp_str, ep_str = override.lower().split("x") + dp, ep = int(dp_str), int(ep_str) + if dp * ep != num_procs: + raise ValueError( + f"NVTE_TEST_EP_MESH={override!r} does not multiply to num_procs={num_procs}" + ) + if (NUM_LOCAL_EXPERTS * ep) % 4 != 0: + raise ValueError( + f"NUM_LOCAL_EXPERTS*ep ({NUM_LOCAL_EXPERTS}*{ep}) must be a multiple of 4 " + "for NCCL EP TMA alignment" + ) + return dp, ep + if num_procs != 4: + raise ValueError( + f"default mesh expects exactly 4 ranks (got {num_procs}); set " + "NVTE_TEST_EP_MESH=DPxEP to override" + ) + return 2, 2 + + +def _build_mesh(dp, ep): + devs = np.asarray(jax.devices()).reshape(dp, ep) + return Mesh(devs, ("dp", "ep")) + + +def _local_device_sm(): + """Return SM major*10+minor of the first local CUDA device, or None.""" + try: + dev = jax.local_devices()[0] + cap = getattr(dev, "compute_capability", None) + if cap is None: + return None + major, minor = (int(x) for x in str(cap).split(".")) + return major * 10 + minor + except Exception: + return None + + +class TestEP(unittest.TestCase): + @classmethod + def setUpClass(cls): + sm = _local_device_sm() + if sm is not None and sm < 90: + raise unittest.SkipTest(f"NCCL EP requires SM>=90 (got SM{sm})") + cls.num_procs = jax.process_count() + cls.rank = jax.process_index() + cls.dp, cls.ep = _factor_dp_ep(cls.num_procs) + cls.num_experts = NUM_LOCAL_EXPERTS * cls.ep + # recv_capacity is per-DP-group (NCCL EP comms isolated per DP color). + # Under PartitionSpec(("dp","ep"), None) each EP group sees + # T_global/dp = TOKENS_PER_DP_SHARD tokens total; pad for routing skew. + T_per_ep_group = TOKENS_PER_DP_SHARD + active_experts = min(cls.num_experts, T_per_ep_group * TOP_K) + overconc = cls.num_experts // active_experts + cls.recv_capacity_per_rank = ( + NUM_LOCAL_EXPERTS * max(T_per_ep_group * TOP_K, 16) * overconc * 2 + ) + cls.mesh = _build_mesh(cls.dp, cls.ep) + cls.mr = MeshResource(dp_resource="dp", ep_resource="ep") + with cls.mesh, global_shard_guard(cls.mr): + ep_bootstrap( + world_size=cls.num_procs, + rank=cls.rank, + num_experts=cls.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=cls.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + # Bootstrap must snapshot ep_size and num_ep_groups onto EpConfig so + # abstract-eval never needs the active mesh. + assert get_ep_config().ep_size == cls.ep + assert get_ep_config().num_ep_groups == cls.dp + # One layer config shared by all single-layer tests below; non-zero + # alignment exercises dispatch_output_per_expert_alignment end-to-end. + cls.hk = EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16) + + # ── Bootstrap precondition ──────────────────────────────────────────── + + def test_bootstrap_rejects_missing_ep_axis(self): + """ep_bootstrap raises when MeshResource has no ep_resource.""" + with self.mesh, global_shard_guard(MeshResource()): + with self.assertRaisesRegex(ValueError, "ep_resource"): + ep_bootstrap( + world_size=self.num_procs, + rank=self.rank, + num_experts=self.num_experts, + max_tokens_per_rank=TOKENS_PER_DP_SHARD, + recv_capacity_per_rank=self.recv_capacity_per_rank, + hidden_dim=HIDDEN_DIM, + ) + + # ── Helpers ─────────────────────────────────────────────────────────── + + def _make_identity_inputs(self, nonuniform=False): + """Identity routing + uniform weights — combined output ≈ tokens. + + ``nonuniform=False``: ``(t*TOP_K+k) % E`` (round-robin, near-balanced). + ``nonuniform=True``: ``top1=0`` for every token, ``top2=1+(t%(E-1))`` — + expert 0 absorbs the entire batch while the others split the second + slot evenly. Exercises a skewed per-expert load. + """ + T_global = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + topk_idx = np.empty((T_global, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_global): + topk_idx[t, 0] = 0 + topk_idx[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_global): + for k in range(TOP_K): + topk_idx[t, k] = (t * TOP_K + k) % E + topk_idx = jnp.asarray(topk_idx) + topk_weights = jnp.full((T_global, TOP_K), 1.0 / TOP_K, dtype=jnp.float32) + tokens = jnp.asarray( + np.linspace(0.1, 0.9, T_global * HIDDEN_DIM, dtype=np.float32).reshape( + T_global, HIDDEN_DIM + ), + dtype=jnp.bfloat16, + ) + return T_global, topk_idx, tokens, topk_weights + + def _make_random_inputs(self, seed=42, nonuniform=True): + """Random tokens + skewed top-2 routing (top1=0 always; top2 varies). + + Non-uniform load by default — guarantees expert 0 receives every token + while the rest of the experts split the second slot. Use + ``nonuniform=False`` for a balanced (t%E, (t+1)%E) pattern. + """ + T_dp = TOKENS_PER_DP_SHARD * self.dp + E = self.num_experts + rng = np.random.default_rng(seed=seed) + tokens = jnp.asarray( + rng.standard_normal((T_dp, HIDDEN_DIM), dtype=np.float32) * 0.5, + dtype=jnp.bfloat16, + ) + topk_idx_np = np.empty((T_dp, TOP_K), dtype=np.int32) + if nonuniform: + assert TOP_K == 2, "non-uniform pattern assumes top_k=2" + for t in range(T_dp): + topk_idx_np[t, 0] = 0 + topk_idx_np[t, 1] = 1 + (t % (E - 1)) + else: + for t in range(T_dp): + a, b = t % E, (t + 1) % E + topk_idx_np[t, 0], topk_idx_np[t, 1] = (a, b) if a < b else (b, a) + topk_idx = jnp.asarray(topk_idx_np) + topk_weights = jnp.asarray(np.full((T_dp, TOP_K), 1.0 / TOP_K, dtype=np.float32)) + return T_dp, tokens, topk_idx, topk_weights + + @staticmethod + def _preweight_expert_out(expert_out, recv_topk_weights): + """ep_combine is unweighted; mirror the caller-side weighting + mask.""" + mask = (recv_topk_weights != 0).astype(jnp.float32)[..., None] + w = recv_topk_weights[..., None] + return (expert_out.astype(jnp.float32) * w * mask).astype(expert_out.dtype) + + # ── Individual primitives (cpp_extensions level) ────────────────────── + + def test_two_handle_mems_no_aliasing(self): + """Two ``ep_prepare`` calls in one jit must produce distinct handle_mem + buffers; the pointer-keyed C++ cache must not alias HandleEntries + across distinct logical layers.""" + _T, topk_idx, _tokens, _w = self._make_identity_inputs() + ka, kb = ( + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + _tc_a, ha = ep_prepare(ka, idx) + _tc_b, hb = ep_prepare(kb, idx) + return ha, hb + + hm_a, hm_b = run(idx_s) + hm_a.block_until_ready() + hm_b.block_until_ready() + self.assertNotEqual(hm_a.unsafe_buffer_pointer(), hm_b.unsafe_buffer_pointer()) + + def test_two_layer_dispatch_no_handle_aliasing(self): + """Two ep_dispatch calls in one jit with distinct ``EpLayerConfig``s must + not clobber each other's routing state. Different inputs per layer with + identity routing + uniform weights => both recv buffers must independently + identity-round-trip via ep_combine.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + tokens_b = (tokens.astype(jnp.float32) * -1.0 + 0.25).astype(tokens.dtype) + ka, kb = ( + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + EpLayerConfig(top_k=TOP_K, dispatch_output_per_expert_alignment=16), + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + ta = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + tb = jax.lax.with_sharding_constraint(tokens_b, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + def one_layer(hk, idx, toks, w_): + recv_t, recv_w, hm, tc = ep_dispatch(hk, idx, toks, w_, self.recv_capacity_per_rank) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + weighted = self._preweight_expert_out(recv_t, recv_w) + return ep_combine(hk, hm, tc, weighted, T_global, out_sharding=(("dp", "ep"), None)) + + @jax.jit + def run(idx, ta_, tb_, w_): + return one_layer(ka, idx, ta_, w_), one_layer(kb, idx, tb_, w_) + + out_a, out_b = run(idx_s, ta, tb, w) + out_a.block_until_ready() + out_b.block_until_ready() + out_a_g = jmu.process_allgather(out_a, tiled=True) + out_b_g = jmu.process_allgather(out_b, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_a_g.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + np.testing.assert_allclose( + np.asarray(out_b_g.astype(jnp.float32)), + np.asarray(tokens_b.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_primitive_prepare(self): + """ep_prepare returns token_counts and handle_mem of the expected shapes.""" + T_global, topk_idx, _tokens, _w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + + @jax.jit + def run(idx): + tc, hm = ep_prepare(self.hk, idx) + return tc, hm + + tc, hm = run(idx_s) + tc.block_until_ready() + self.assertEqual(tc.shape, (self.dp * self.ep, NUM_LOCAL_EXPERTS)) + self.assertEqual(hm.shape[0], self.dp * self.ep) + self.assertGreater(hm.shape[1], 0) + + def _run_identity_round_trip(self, nonuniform): + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=nonuniform) + dp_spec = PartitionSpec(("dp", "ep"), None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + _tc, hm = ep_prepare(self.hk, idx) + recv_t, recv_w = ep_dispatch_fwd( + self.hk, hm, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + # Apply the weighted hadamard inline (combine FFI is unweighted). + mask = (recv_w != 0).astype(jnp.float32)[..., None] + weighted = (recv_t.astype(jnp.float32) * recv_w[..., None] * mask).astype( + recv_t.dtype + ) + weighted = jax.lax.with_sharding_constraint( + weighted, NamedSharding(self.mesh, ep_spec_3d) + ) + out = ep_combine_fwd( + self.hk, + hm, + weighted, + T_global, + out_partition_spec=(("dp", "ep"), None), + ) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + # Allgather so the rank-0 numpy comparison sees the full global tensor. + out_global = jmu.process_allgather(out, tiled=True) + + # Identity expert + uniform weights → out ≈ tokens (rank-0 check). + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_primitive_dispatch_combine_identity_uniform(self): + """Round-robin routing → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=False) + + def test_primitive_dispatch_combine_identity_nonuniform(self): + """Skewed routing (top1=0 always) → identity round-trip via the primitive layer.""" + self._run_identity_round_trip(nonuniform=True) + + def test_primitive_dispatch_combine_identity_bwd_uniform(self): + """Bwd through identity round-trip: ∇(0.5 ||out||²) w.r.t. tokens ≈ tokens. + + Identity routing + uniform top-k weights ⇒ dispatch∘combine is the + identity, so loss = 0.5||tokens||² and ∇_tokens loss = tokens. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + weighted = self._preweight_expert_out(recv_t, recv_w) + out = ep_combine( + self.hk, hm, tc, weighted, T_global, out_sharding=(("dp", "ep"), None) + ) + return 0.5 * (out.astype(jnp.float32) ** 2).sum() + + grad = jax.jit(jax.grad(loss_fn))(tokens) + grad.block_until_ready() + grad_global = jmu.process_allgather(grad, tiled=True) + + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_combine_3d_input_output(self): + """3D input ``[B, S, H]`` sharded on the first dim only — + ``(("dp","ep"), None, None)`` here — dispatch accepts the rank-3 shape + and combine returns a matching 3D ``[B, S, H]`` output. End-to-end + round trip recovers the original tokens under identity routing + + uniform top-k weights.""" + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs(nonuniform=False) + # B is sharded across all (dp*ep) ranks; S held in one piece per rank. + B, S, H = T_global, 1, tokens.shape[-1] + tokens_3d = tokens.reshape(B, S, H) + topk_idx_3d = topk_idx.reshape(B, S, -1) + topk_w_3d = topk_w.reshape(B, S, -1) + spec_3d = PartitionSpec(("dp", "ep"), None, None) + out_spec_3d = (("dp", "ep"), None, None) + with self.mesh, global_shard_guard(self.mr): + idx_s = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(self.mesh, spec_3d)) + tok_s = jax.lax.with_sharding_constraint(tokens_3d, NamedSharding(self.mesh, spec_3d)) + w_s = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(self.mesh, spec_3d)) + + ep_t = PartitionSpec(("dp", "ep"), None, None) + ep_w = PartitionSpec(("dp", "ep"), None) + + @jax.jit + def run(idx, toks, w): + recv_t, recv_w, hm, _tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t)) + recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w)) + weighted = self._preweight_expert_out(recv_t, recv_w) + out = ep_combine( + self.hk, + hm, + _tc, + weighted, + num_local_tokens=(B, S), + out_sharding=out_spec_3d, + ) + return out + + out = run(idx_s, tok_s, w_s) + out.block_until_ready() + out_global = jmu.process_allgather(out, tiled=True) + + if self.rank == 0: + self.assertEqual(out_global.shape, (B, S, H)) + np.testing.assert_allclose( + np.asarray(out_global.astype(jnp.float32)), + np.asarray(tokens_3d.astype(jnp.float32)), + atol=5e-2, + rtol=5e-2, + ) + + # ── Custom-VJP tests ───────────────────────────────────────────────── + + def test_dispatch_vjp_fwd_bwd(self): + """ep_dispatch fwd + jax.grad w.r.t. tokens. + + Identity routing + loss = 0.5||recv_tokens||² ⇒ each token appears + TOP_K times in recv_tokens (all routes fit recv_capacity), so + grad_tokens = TOP_K * tokens (closed form). + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + del T_global + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + align = max(int(self.hk.dispatch_output_per_expert_alignment), 1) + + def loss_fn(toks): + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + recv_tokens, _recv_w, _hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) + recv_tokens = jax.lax.with_sharding_constraint( + recv_tokens, NamedSharding(self.mesh, ep_spec_3d) + ) + # ep_dispatch fills only slots [0, sum(padded_per_expert)); + # the tail is uninitialized. Mask with jnp.where (NaN-safe; + # multiply would propagate NaN*0=NaN). + padded = ((tc + align - 1) // align) * align + total_recv = jnp.sum(padded, axis=-1, keepdims=True).astype(jnp.int32) + slot_idx = jnp.arange(self.recv_capacity_per_rank, dtype=jnp.int32) + mask = slot_idx[None, :] < total_recv + rt32 = jnp.where(mask[..., None], recv_tokens.astype(jnp.float32), 0.0) + return 0.5 * (rt32**2).sum() + + loss, grad_tokens = jax.jit(jax.value_and_grad(loss_fn))(tokens) + grad_tokens.block_until_ready() + grad_global = jmu.process_allgather(grad_tokens, tiled=True) + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_tokens.shape, tokens.shape) + if self.rank == 0: + np.testing.assert_allclose( + np.asarray(grad_global.astype(jnp.float32)), + np.asarray(tokens.astype(jnp.float32)) * float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_combine_vjp_fwd_bwd(self): + """ep_combine fwd + jax.grad w.r.t. expert_out. + + Identity routing + constant eo=c + uniform topk_w ⇒ combined[t] = c + (sum_k topk_w = 1) and grad_eo[e, s, h] = recv_w[e, s] * c at filled + slots — so max|grad_eo| ≈ c / TOP_K. + """ + T_global, topk_idx, tokens, topk_w = self._make_identity_inputs() + eo_const = 0.5 + expert_out = jnp.full( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), + eo_const, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(eo): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(tokens, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec)) + _recv_tokens, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, PartitionSpec(("dp", "ep"), None)) + ) + weighted = self._preweight_expert_out(eo, recv_w) + combined = ep_combine(self.hk, hm, tc, weighted, T_global) + # Pin combined to dp-sharded so autodiff transpose feeds + # ep_combine_bwd a per-shard cotangent. + combined = jax.lax.with_sharding_constraint( + combined, NamedSharding(self.mesh, dp_spec) + ) + return 0.5 * (combined.astype(jnp.float32) ** 2).sum() + + loss, grad_eo = jax.jit(jax.value_and_grad(loss_fn))(expert_out) + grad_eo.block_until_ready() + + self.assertTrue(np.isfinite(float(loss))) + self.assertEqual(grad_eo.shape, expert_out.shape) + for shard in grad_eo.addressable_shards: + arr = np.asarray(shard.data.astype(jnp.float32)) + self.assertTrue(np.all(np.isfinite(arr))) + self.assertGreater(arr.max(), 0.0, "grad_eo has no positive entry on filled slots") + np.testing.assert_allclose( + arr.max(), + eo_const / float(TOP_K), + atol=5e-2, + rtol=5e-2, + ) + + def test_dispatch_bwd_exact_per_k_topk_weights(self): + """Distinct per-(t, k) upstream grads ⇒ grad[t, 0] != grad[t, 1] for all t. + + Guards against a regression where the bwd would average across the k + axis (per-token mean instead of per-slot exact recovery). + """ + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def loss_fn(idx_in, tok_in, w_in): + idx_in = jax.lax.with_sharding_constraint(idx_in, NamedSharding(self.mesh, dp_spec)) + tok_in = jax.lax.with_sharding_constraint(tok_in, NamedSharding(self.mesh, dp_spec)) + w_in = jax.lax.with_sharding_constraint(w_in, NamedSharding(self.mesh, dp_spec)) + _recv_t, recv_w, _h, _tc = ep_dispatch( + self.hk, idx_in, tok_in, w_in, self.recv_capacity_per_rank + ) + # Per-slot index scale ⇒ each slot's contribution differs. + scale = jnp.asarray( + np.arange(recv_w.size, dtype=np.float32).reshape(recv_w.shape) + 1.0 + ) + return jnp.sum(recv_w * scale) + + grad_topk_w = jax.jit(jax.grad(loss_fn, argnums=2))(topk_idx, tokens, topk_w) + grad_topk_w.block_until_ready() + grad_global = jmu.process_allgather(grad_topk_w, tiled=True) + + if self.rank == 0: + grad_np = np.asarray(grad_global).astype(np.float32) + mismatch = sum(int(abs(grad_np[t, 0] - grad_np[t, 1]) < 1e-6) for t in range(T_dp)) + self.assertEqual( + mismatch, + 0, + f"Expected grad[t, 0] != grad[t, 1] for all {T_dp} tokens under skewed " + f"upstream scaling; got {mismatch} tokens with grad[t, 0] == grad[t, 1].", + ) + + # ── HLO reshard guard ──────────────────────────────────────────────── + # Compile-only: assert XLA inserts no cross-device collectives outside + # the EP FFI. EP-axis flux is carried by the FFI itself. + + def test_z_no_unexpected_reshard_in_hlo_fwd(self): + """Compiled fwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + @jax.jit + def run(idx, toks, w): + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + recv_t, recv_w, hm, tc = ep_dispatch( + self.hk, idx, toks, w, self.recv_capacity_per_rank + ) + recv_t = jax.lax.with_sharding_constraint( + recv_t, NamedSharding(self.mesh, ep_spec_3d) + ) + recv_w = jax.lax.with_sharding_constraint( + recv_w, NamedSharding(self.mesh, ep_spec_2d) + ) + weighted = self._preweight_expert_out(recv_t, recv_w) + out = ep_combine(self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None)) + return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec)) + + compiled = run.lower(topk_idx, tokens, topk_w).compile() + hlo = compiled.as_text() + # Match instruction names; "all-gather-start" and "all-gather-done" + # bracket a single async all-gather. + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in fwd HLO:\n{hlo}") + # XLA drops trailing-None entries from the spec; compare as a tuple. + # JAX collapses size-1 mesh axes, so dp=1 reduces ("dp","ep") to "ep". + expected = (("dp", "ep"),) if self.dp > 1 else ("ep",) + self.assertEqual(tuple(compiled.output_shardings.spec), expected) + + def test_z_no_unexpected_reshard_in_hlo_bwd(self): + """Compiled bwd HLO must not insert XLA collectives outside the EP FFI.""" + T_dp, tokens, topk_idx, topk_w = self._make_random_inputs() + rng = np.random.default_rng(seed=44) + expert_out = jnp.asarray( + rng.standard_normal( + (self.dp * self.ep, self.recv_capacity_per_rank, HIDDEN_DIM), dtype=np.float32 + ) + * 0.5, + dtype=jnp.bfloat16, + ) + dp_spec = PartitionSpec(("dp", "ep"), None) + ep_spec_3d = PartitionSpec(("dp", "ep"), None, None) + ep_spec_2d = PartitionSpec(("dp", "ep"), None) + + with self.mesh, global_shard_guard(self.mr): + + def fwd(eo, toks, idx, w): + eo = jax.lax.with_sharding_constraint(eo, NamedSharding(self.mesh, ep_spec_3d)) + toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec)) + idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec)) + w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec)) + _rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank) + rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d)) + weighted = self._preweight_expert_out(eo, rw) + combined = ep_combine( + self.hk, hm, tc, weighted, T_dp, out_sharding=(("dp", "ep"), None) + ) + return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec)) + + # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd + # the expected sharding without relying on XLA-transpose propagation. + def bwd_only(eo, toks, idx, w, g): + _y, vjp_fn = jax.vjp(fwd, eo, toks, idx, w) + g = jax.lax.with_sharding_constraint(g, NamedSharding(self.mesh, dp_spec)) + grads = vjp_fn(g) + return ( + jax.lax.with_sharding_constraint( + grads[0], NamedSharding(self.mesh, ep_spec_3d) + ), + jax.lax.with_sharding_constraint(grads[1], NamedSharding(self.mesh, dp_spec)), + ) + + g_seed = jnp.ones((T_dp, HIDDEN_DIM), dtype=jnp.bfloat16) + compiled = ( + jax.jit(bwd_only).lower(expert_out, tokens, topk_idx, topk_w, g_seed).compile() + ) + hlo = compiled.as_text() + for op in ("all-gather-start", "all-to-all", "collective-permute"): + self.assertEqual(hlo.count(op), 0, f"unexpected XLA {op} in bwd HLO:\n{hlo}") + + +# ── Entry point ────────────────────────────────────────────────────────────── + + +if __name__ == "__main__": + if len(sys.argv) < 4: + print("Usage: python test_multi_process_ep.py ") + sys.exit(1) + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, + num_processes=num_procs, + process_id=proc_id, + local_device_ids=[proc_id], + ) + + loader = unittest.TestLoader() + target = os.environ.get("TARGET_TEST") + if target: + name = target.split(".")[-1] + suite = loader.loadTestsFromName(name, TestEP) + else: + suite = loader.loadTestsFromTestCase(TestEP) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + sys.exit(0 if result.wasSuccessful() else 1) diff --git a/tests/jax/test_te_ep_moe.py b/tests/jax/test_te_ep_moe.py new file mode 100644 index 0000000000..428379d3bd --- /dev/null +++ b/tests/jax/test_te_ep_moe.py @@ -0,0 +1,734 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Multi-process (one-GPU-per-process) tests for the TE-EP MoE custom_vjp. + +The launcher ``tests/jax/run_te_ep_moe.sh`` forks one pytest process per +visible GPU (mirroring ``run_multiprocess_moe_vjp.sh``). Each process binds +to exactly one device via +``jax.distributed.initialize(..., local_device_ids=process_id)``; the +participating processes form a global ``(ep, fsdp)`` mesh through JAX's +distributed runtime. + +How to run +---------- + +You typically do NOT invoke pytest on this file directly -- use the +launcher, which passes ``--num-process=N --process-id=i`` to each +forked process. Driving it directly with only one process will skip +every test because :func:`jax.distributed.initialize` requires +multiple participants, and the TE EP NCCL primitives require at +least four ranks. + + bash tests/jax/run_te_ep_moe.sh + +What this suite covers +---------------------- + +This file is the TE-EP-only successor to ``test_moe_vjp.py`` and +``test_multiprocess_moe_vjp.py``. Each test exercises one MoE-block +run and bundles every check that single run supports — shape, dtype, +finiteness AND numerical parity vs a pure-JAX reference. Variations +on the block are pytest parametrize values rather than separate test +classes: + +* ``test_forward`` covers the forward across a curated set of + configurations (softmax/sigmoid scoring, optional non-zero + expert_bias). Each config asserts shape, dtype, finiteness and + numerical parity vs the reference in one run. +* ``test_backward`` mirrors that for gradients. +* ``TestTeEpMoeAuxLoss`` covers the second return value end-to-end + (returned + parity + aux-only grad propagates to gate + combined + main+aux grads stay finite) in two consolidated tests. + +Intentional non-coverage: + +* No dedicated "Flax wrapper init+apply" smoke test: every config above + already calls ``MoEBlock`` (the Flax wrapper) end-to-end, so a + separate wrapper smoke would just duplicate ``test_forward[softmax]`` + + ``test_backward[softmax]``. +* No re-bootstrap-mismatch test: ``ep_bootstrap`` rejects a mismatched + signature unconditionally and is a one-line guard; covering it from + this suite would taint the per-process NCCL bootstrap cache for the + rest of the file with no real upside. + +FP8 / MXFP8 recipes are deferred — the ``quantizer_sets`` plumbing +has not yet been re-wired across the TE-EP ``shard_map`` boundary +(see ``.pr3036-review/INTEGRATION_DESIGN.md``). +""" + +import os + +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5") + +import sys +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from jax.experimental import mesh_utils +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from flax.linen import partitioning as nn_partitioning + + +def _init_distributed(num_process: int, process_id: int) -> bool: + """Initialize jax.distributed for this pytest process. + + Returns True on a real multi-process launch, False otherwise so + the module can fast-skip when pytest collects it without the + launcher. + """ + if num_process <= 1: + return False + coord = os.environ.get("TE_EP_MOE_COORDINATOR_ADDRESS", "127.0.0.1:13457") + jax.distributed.initialize( + coordinator_address=coord, + num_processes=num_process, + process_id=process_id, + local_device_ids=process_id, + ) + assert jax.local_device_count() == 1, "one GPU per process is required for TE EP" + assert ( + jax.device_count() == num_process + ), f"global device_count {jax.device_count()} != num_process {num_process}" + return True + + +def _read_mp_options(): + num = int(os.environ.get("MP_NUM_PROCESS", "0") or "0") + pid = int(os.environ.get("MP_PROCESS_ID", "0") or "0") + for i, a in enumerate(sys.argv): + if a.startswith("--num-process="): + num = int(a.split("=", 1)[1]) + elif a == "--num-process" and i + 1 < len(sys.argv): + num = int(sys.argv[i + 1]) + elif a.startswith("--process-id="): + pid = int(a.split("=", 1)[1]) + elif a == "--process-id" and i + 1 < len(sys.argv): + pid = int(sys.argv[i + 1]) + return num, pid + + +_MP_NUM_PROCESS, _MP_PROCESS_ID = _read_mp_options() +_MP_ACTIVE = _init_distributed(_MP_NUM_PROCESS, _MP_PROCESS_ID) + +if not _MP_ACTIVE: + pytest.skip( + "test_te_ep_moe.py requires the multiprocess launcher (run_te_ep_moe.sh). Skipping.", + allow_module_level=True, + ) + +from transformer_engine_jax import get_device_compute_capability + +# Grouped GEMM in the MoE custom_vjp requires Blackwell (sm_100+). The +# TE EP NCCL primitives themselves need SM>=90, but the FFN body uses +# grouped_gemm, so the file as a whole gates on sm_100+. +if get_device_compute_capability(0) < 100: + pytest.skip( + "MoE TE EP tests require Blackwell (sm_100+) for grouped GEMM", + allow_module_level=True, + ) + +from transformer_engine.jax.flax import _MoEBlock as MoEBlock +from transformer_engine.jax.moe import moe, record_ep_bootstrap_signature_for_moe +from transformer_engine.jax.ep import ep_bootstrap +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +# ----------------------------------------------------------------------------- +# Mesh / shape config +# ----------------------------------------------------------------------------- + +EP_AXIS = "ep" +FSDP_AXIS = "fsdp" +EP_SIZE = 2 +assert ( + jax.device_count() % EP_SIZE == 0 +), f"device_count {jax.device_count()} must be divisible by EP_SIZE={EP_SIZE}" +FSDP_SIZE = jax.device_count() // EP_SIZE +NUM_DEVICES_REQUIRED = EP_SIZE * FSDP_SIZE + +LOGICAL_AXIS_RULES = ( + ("exp", EP_AXIS), + ("embed", FSDP_AXIS), + ("mlp", None), + ("batch", (EP_AXIS, FSDP_AXIS)), +) + +# Small shapes so the parity tests stay tight on bf16. The block still +# has all four ranks participating in dispatch/combine. +DTYPE = jnp.bfloat16 +BATCH = EP_SIZE * FSDP_SIZE * 2 # 8 on 4-GPU, 16 on 8-GPU +SEQ = 32 +HIDDEN = 64 +INTER = 128 +NUM_EXPERTS = 8 +TOPK = 2 + +# bf16 grouped_gemm + softmax-topk + ep all-to-all stack drifts ~1e-1 vs a +# fp32 numpy reference. Keep these tight enough to catch real bugs but +# loose enough to absorb expected bf16 rounding. +FWD_ATOL = 5e-2 +FWD_RTOL = 5e-2 +GRAD_FFN_ATOL = 1e-1 +GRAD_FFN_RTOL = 1e-1 +GRAD_GATE_ATOL = 5e-1 +GRAD_GATE_RTOL = 5e-1 + +# Two TE EP runs that should be bitwise-equal modulo XLA fusion order +# (slot alignment rounding, etc.). +TE_TO_TE_ATOL = 5e-3 +TE_TO_TE_RTOL = 5e-3 + +# Aux loss is computed in float32 from the SAME logits as the routing +# path. Numerical drift between TE-EP and the reference is dominated by +# the bf16-rounded softmax inside the topk kernel. +AUX_ATOL = 1e-3 +AUX_RTOL = 1e-3 + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +def _compute_worst_case_recv_pr(): + """Per-rank recv buffer the bootstrap must reserve. + + NCCL EP's HT path lays out the per-rank receive buffer as + ``[num_local_experts, ep_size * max_tokens_per_rank, hidden]`` + (per the LL combine assertion at ``nccl_ep.cc:2185`` and the + HT IPC buffer sizing at ``nccl_ep.cc:415``). We must mirror that + flattened total or ``ncclEpDispatch`` aborts with + ``invalid argument`` at ``ep_backend.cpp:414``. The moe block + computes ``recv_pr`` the same way (see ``moe.py``'s + ``natural_spe = num_ep * max_tokens_per_rank``); keeping the + bootstrap formula in lock-step here. + """ + num_procs = jax.device_count() + num_local_experts = NUM_EXPERTS // EP_SIZE + max_tokens_per_rank = (BATCH // num_procs) * SEQ + natural_spe = EP_SIZE * max_tokens_per_rank + return num_local_experts * natural_spe + + +@pytest.fixture(scope="module") +def mesh(): + if jax.device_count() < NUM_DEVICES_REQUIRED: + pytest.skip( + f"Need >={NUM_DEVICES_REQUIRED} devices for ep={EP_SIZE} x fsdp={FSDP_SIZE};" + f" have {jax.device_count()}" + ) + # ``ep`` must be the inner axis: ``ep_bootstrap`` forms NCCL EP groups + # from consecutive global ranks via ``dp_color = rank // ep_size``, so + # only an (outer_fsdp, inner_ep) device layout groups ranks correctly. + devices = mesh_utils.create_device_mesh((FSDP_SIZE, EP_SIZE)) + mesh_obj = Mesh(devices, axis_names=(FSDP_AXIS, EP_AXIS)) + + num_procs = jax.process_count() + max_tokens_per_rank = (BATCH // num_procs) * SEQ + recv_capacity_per_rank = _compute_worst_case_recv_pr() + + # Eager bootstrap: ep_bootstrap does a host-side NCCL UID allgather + # and cannot run from inside jax.jit. Sized to the worst-case recv_pr + # across _CONFIGS so every parametrized config is bootstrap-compatible. + with mesh_obj, global_shard_guard(MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS)): + ep_bootstrap( + world_size=num_procs, + rank=jax.process_index(), + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=HIDDEN, + max_token_dtype=DTYPE, + ) + record_ep_bootstrap_signature_for_moe( + num_experts=NUM_EXPERTS, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=HIDDEN, + ep_size=EP_SIZE, + ) + return mesh_obj + + +# ----------------------------------------------------------------------------- +# Pure-JAX reference MoE (no EP). Mirrors the exact math of TE's fused +# router primitive (see tests/jax/test_fused_router.py for the same +# reference applied to the standalone router kernel): +# +# softmax + post-softmax (use_pre_softmax=False, the default): +# 1. top_k by raw logits +# 2. softmax over just the K selected logits (so weights sum to 1) +# +# sigmoid + optional expert_bias: +# 1. scores = sigmoid(logits) +# 2. top_k by (scores + expert_bias) [bias only steers selection] +# 3. weights = scores at top_k positions, normalized when K > 1 +# +# Then for both: +# * weights *= scaling_factor (we leave scaling_factor=1.0 in this +# suite, matching _make_block's default). +# * per-expert FFN: silu(layer_w0) * layer_w1 → wo. +# ----------------------------------------------------------------------------- + + +@partial( + jax.jit, + static_argnames=( + "num_experts", + "num_experts_per_tok", + "aux_loss_coeff", + "score_function", + ), +) +def _pure_jax_moe_reference( + x, + gate_kernel, + wi_0, + wi_1, + wo, + expert_bias=None, + *, + num_experts, + num_experts_per_tok, + aux_loss_coeff: float = 0.0, + score_function: str = "softmax", +): + B, S, H = x.shape + T = B * S + K = num_experts_per_tok + x_2d = x.reshape(T, H) + + gate_kernel_cast = gate_kernel.astype(x.dtype) + logits = (x_2d @ gate_kernel_cast).astype(jnp.float32) # [T, E] + + if score_function == "softmax": + # use_pre_softmax=False: topk on raw logits, then softmax over K. + top_logits, top_indices = jax.lax.top_k(logits, k=K) + weights = jax.nn.softmax(top_logits, axis=-1) # [T, K], sums to 1 + elif score_function == "sigmoid": + scores = jax.nn.sigmoid(logits) # [T, E] + if expert_bias is not None and expert_bias.shape != (0,): + scores_for_routing = scores + expert_bias.astype(jnp.float32)[None, :] + _, top_indices = jax.lax.top_k(scores_for_routing, k=K) + weights = jnp.take_along_axis(scores, top_indices, axis=-1) + else: + weights, top_indices = jax.lax.top_k(scores, k=K) + # Sigmoid weights are normalized when K > 1 (matches the kernel). + if K > 1: + weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20) + else: + raise ValueError(f"Unsupported score_function={score_function!r}") + + routing_weights_full = jnp.zeros((T, num_experts), dtype=jnp.float32) + routing_weights_full = routing_weights_full.at[jnp.arange(T)[:, None], top_indices].set(weights) + + # FFN. ``apply_topk_weights_early`` is a fusion knob that doesn't + # change the math (wo is linear), so the reference is identical for + # both placements. + layer_w0 = jnp.einsum("th,ehm->tem", x_2d, wi_0) + layer_w1 = jnp.einsum("th,ehm->tem", x_2d, wi_1) + # Activation runs in x.dtype (typically bf16) to mirror the impl -- + # the impl keeps silu+multiply in the wi GEMM output dtype because + # storing higher precision than the consumer (wo) GEMM buys nothing. + intermediate = jax.nn.silu(layer_w0) * layer_w1 + expert_out = jnp.einsum("tem,emh->teh", intermediate, wo) # [T, E, H] + output_2d = jnp.einsum("te,teh->th", routing_weights_full.astype(x.dtype), expert_out) + output = output_2d.reshape(B, S, H).astype(x.dtype) + + if aux_loss_coeff > 0.0: + # tex.fused_moe_aux_loss formula (matches the same + # reference_aux_loss helper from test_fused_router.py). The + # "aux scores" use the same score_function but always with + # K-normalised sigmoid (when sigmoid) / plain softmax (when + # softmax) — see tex.fused_topk_with_score_function_fwd with + # compute_aux_scores=True. + if score_function == "softmax": + aux_scores = jax.nn.softmax(logits, axis=-1) + else: # sigmoid + aux_scores = jax.nn.sigmoid(logits) + if K > 1: + aux_scores = aux_scores / (aux_scores.sum(axis=-1, keepdims=True) + 1e-20) + routing_map = (routing_weights_full > 0).astype(jnp.int32) + tokens_per_expert = jnp.sum(routing_map, axis=0) # [E] + sum_probs_per_expert = jnp.sum(aux_scores, axis=0) # [E] + aux_loss = (num_experts * aux_loss_coeff / (K * (T**2))) * jnp.sum( + sum_probs_per_expert * tokens_per_expert.astype(jnp.float32) + ) + aux_loss = aux_loss.astype(x.dtype) + else: + aux_loss = jnp.zeros((), dtype=x.dtype) + return output, aux_loss + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _make_block( + *, + apply_topk_weights_early=False, + aux_loss_coeff=0.0, + use_expert_routing_bias=False, + score_function="softmax", + bias_init=None, +): + kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + data_parallelism_axes=(FSDP_AXIS,), + apply_topk_weights_early=apply_topk_weights_early, + aux_loss_coeff=aux_loss_coeff, + use_expert_routing_bias=use_expert_routing_bias, + score_function=score_function, + dtype=DTYPE, + ) + # Custom bias_init lets tests inject a non-zero expert_bias without + # poking variables['params'] post-init. + if bias_init is not None: + kwargs["bias_init"] = bias_init + return MoEBlock(**kwargs) + + +def _strong_expert_bias_init(key, shape, dtype): + """Half +5, half -5 — large enough to force topk onto the +ve half.""" + del key + n = shape[0] + return jnp.concatenate( + [ + jnp.full((n // 2,), 5.0, dtype=dtype), + jnp.full((n - n // 2,), -5.0, dtype=dtype), + ] + ) + + +def _shard_inputs(x, mesh): + # Match the layout moe.py re-pins to: outer dp axes, then ep innermost. + return jax.lax.with_sharding_constraint( + x, NamedSharding(mesh, P((FSDP_AXIS, EP_AXIS), None, None)) + ) + + +def _ctx(mesh): + """Combined mesh + global_shard_guard + axis_rules context.""" + + class _Combo: + def __enter__(self_inner): + self_inner._m = mesh.__enter__() + self_inner._gs = global_shard_guard( + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) + ) + self_inner._gs.__enter__() + self_inner._ar = nn_partitioning.axis_rules(LOGICAL_AXIS_RULES) + self_inner._ar.__enter__() + return self_inner._m + + def __exit__(self_inner, *args): + self_inner._ar.__exit__(*args) + self_inner._gs.__exit__(*args) + mesh.__exit__(*args) + + return _Combo() + + +def _init_apply(block, mesh, x, key): + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + variables = jax.jit(block.init)(key, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(variables)[0]) + output, aux = jax.jit(block.apply)(variables, x_sh) + jax.block_until_ready(output) + return variables, output, aux + + +def _grad_step(block, variables, mesh, x, *, include_aux=False): + """Run jax.grad of mean(out^2) [+ aux if include_aux] vs params.""" + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + + def loss_fn(variables, x): + output, aux = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if include_aux and aux is not None: + loss = loss + aux.astype(jnp.float32) + return loss + + grads = jax.jit(jax.grad(loss_fn))(variables, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + return grads + + +def _grad_aux_only(block, variables, mesh, x): + """Jit'd grad of just the aux loss scalar — proves it reaches the + gate even when no main-output contribution is present.""" + with _ctx(mesh): + x_sh = _shard_inputs(x, mesh) + + def aux_only(variables, x): + _, aux = block.apply(variables, x) + return aux.astype(jnp.float32) + + grads = jax.jit(jax.grad(aux_only))(variables, x_sh) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + return grads + + +def _unwrap(x): + return x.value if hasattr(x, "value") else x + + +def _to_global_numpy(arr, mesh): + """Replicate a sharded JAX array onto every rank and return as numpy. + + Triggers an all-gather inside JIT. The resulting addressable_data(0) + contains the full global array on every process, so we can run the + pure-JAX reference and compare against it from any process. + """ + rep = NamedSharding(mesh, P()) + with mesh: + full = jax.jit(lambda a: jax.lax.with_sharding_constraint(a, rep))(arr) + full.block_until_ready() + return np.asarray(jax.device_get(full.addressable_data(0))) + + +def _params_global_numpy(variables, mesh): + """Pull every entry of variables['params'] to a replicated numpy array.""" + params = variables["params"] + return {name: _to_global_numpy(_unwrap(p), mesh) for name, p in params.items()} + + +def _make_inputs(key): + """Generate a globally-identical input tensor on every process.""" + return jax.random.normal(key, (BATCH, SEQ, HIDDEN), dtype=DTYPE) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- +# Parametrize variants exercised by both the forward and the backward +# parity tests. Each config is one MoE-block configuration the suite +# wants covered; the test body checks shape, dtype, finiteness AND +# numerical parity vs the same pure-JAX reference (which understands +# the same set of knobs). +# ----------------------------------------------------------------------------- + +_CONFIGS = [ + pytest.param( + dict(score_function="softmax"), + id="softmax", + ), + # TODO: re-add the apply_topk_weights_early=True config once the + # 0*NaN -> NaN leak from padded recv slots in the early-weighting + # multiply (intermediate * recv_w * mask) is debugged. Late + # weighting (combine-side) is unaffected and stays covered above. + # Note: align_size is no longer a user-facing parameter; it is + # hard-coded to _ALIGN_SIZE = 128 in moe.py. Re-add a distinct + # align-size config only if the constant is loosened, or a + # recipe-driven inference is added that selects a >128 alignment. + pytest.param( + dict(score_function="sigmoid"), + id="sigmoid", + ), + # NOTE: a ``sigmoid-bias-zero`` config (use_expert_routing_bias=True + # with a zero-initialised bias buffer) was previously exercised + # here. It was dropped because the routing math collapses to the + # no-bias case when the buffer is zero -- ``sigmoid`` already + # covers that numerical path. The bias-aware codepath is still + # exercised by ``sigmoid-bias-strong`` below, which uses a + # non-zero bias. + pytest.param( + dict( + score_function="sigmoid", + use_expert_routing_bias=True, + bias_init=_strong_expert_bias_init, + ), + id="sigmoid-bias-strong", + ), +] + + +def _reference_kwargs_from_config(config, params_np): + """Pick out the reference-relevant pieces of a parametrize config.""" + return dict( + score_function=config.get("score_function", "softmax"), + expert_bias=( + jnp.asarray(params_np["expert_bias"]) + if config.get("use_expert_routing_bias", False) + else None + ), + ) + + +class TestTeEpMoeForward: + """Per-config forward correctness in a single run: shape, dtype, + finiteness AND numerical parity vs the pure-JAX reference.""" + + @pytest.mark.parametrize("config", _CONFIGS) + def test_forward(self, mesh, config): + block = _make_block(**config) + x = _make_inputs(jax.random.PRNGKey(0)) + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(1)) + + # Shape / dtype / finiteness (cheap; on the local shard). + assert output.shape == x.shape + assert output.dtype == x.dtype + out_local = np.asarray(jax.device_get(output.addressable_data(0))) + assert np.all(np.isfinite(out_local)), "output has NaN/Inf" + assert aux is None, "aux_loss should be None when aux_loss_coeff == 0" + + # Numerical parity (replicated global view -> single rank's numpy). + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + out_te_np = _to_global_numpy(output, mesh) + + out_ref, _ = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + **_reference_kwargs_from_config(config, params_np), + ) + np.testing.assert_allclose( + out_te_np.astype(np.float32), + np.asarray(jax.device_get(out_ref)).astype(np.float32), + atol=FWD_ATOL, + rtol=FWD_RTOL, + err_msg=f"forward parity breach for config={config}", + ) + + +class TestTeEpMoeBackward: + """Per-config backward correctness in a single run: per-tensor + grads finite, non-zero AND parity vs the pure-JAX reference.""" + + @pytest.mark.parametrize("config", _CONFIGS) + def test_backward(self, mesh, config): + block = _make_block(**config) + x = _make_inputs(jax.random.PRNGKey(2)) + variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(3)) + grads_te = _grad_step(block, variables, mesh, x) + + # Reference grads via jax.grad over the pure-JAX MoE with the + # same config. + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + ref_kwargs = _reference_kwargs_from_config(config, params_np) + ref_expert_bias = ref_kwargs.pop("expert_bias") + + def loss_fn(params, x): + out, _ = _pure_jax_moe_reference( + x, + params["gate_kernel"], + params["wi_0"], + params["wi_1"], + params["wo"], + ref_expert_bias, + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + **ref_kwargs, + ) + return jnp.mean(out.astype(jnp.float32) ** 2) + + grads_ref = jax.jit(jax.grad(loss_fn))( + {k: jnp.asarray(v) for k, v in params_np.items() if k != "expert_bias"}, + jnp.asarray(x_np), + ) + grads_ref_np = {k: np.asarray(jax.device_get(v)) for k, v in grads_ref.items()} + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + # Per-tensor: finite + non-zero + parity in one pass. + g_te = _to_global_numpy(_unwrap(grads_te["params"][name]), mesh) + assert np.all(np.isfinite(g_te)), f"{name} grad has NaN/Inf [config={config}]" + assert np.any(g_te != 0.0), f"{name} grad identically zero [config={config}]" + atol, rtol = ( + (GRAD_GATE_ATOL, GRAD_GATE_RTOL) + if name == "gate_kernel" + else (GRAD_FFN_ATOL, GRAD_FFN_RTOL) + ) + np.testing.assert_allclose( + g_te.astype(np.float32), + grads_ref_np[name].astype(np.float32), + atol=atol, + rtol=rtol, + err_msg=f"grad parity breach on {name} [config={config}]", + ) + + +class TestTeEpMoeAuxLoss: + """Aux-loss path. Consolidated into: + * ``test_aux_loss``: one run that checks the returned scalar's + shape / dtype / finiteness / magnitude AND numerical parity vs the + reference AND that the aux-only bwd propagates to gate_kernel. + * ``test_combined_loss_grads``: one run for joint main+aux bwd + finite + non-zero per tensor. + """ + + def test_aux_loss(self, mesh): + coeff = 1e-2 + block = _make_block(aux_loss_coeff=coeff) + x = _make_inputs(jax.random.PRNGKey(20)) + variables, _, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(21)) + + # Shape / dtype / finiteness / magnitude. + assert aux is not None, "aux_loss should be returned when coeff > 0" + assert aux.shape == (), f"aux_loss must be 0-d scalar, got {aux.shape}" + assert aux.dtype == DTYPE, f"aux_loss dtype {aux.dtype} != {DTYPE}" + aux_np = _to_global_numpy(aux, mesh) + assert np.isfinite(aux_np), "aux_loss is NaN/Inf" + assert abs(float(aux_np)) < 1e2, f"aux_loss looks unreasonable: {aux_np}" + + # Numerical parity vs the reference. + params_np = _params_global_numpy(variables, mesh) + x_np = np.asarray(jax.device_get(x)) + _, aux_ref = _pure_jax_moe_reference( + jnp.asarray(x_np), + jnp.asarray(params_np["gate_kernel"]), + jnp.asarray(params_np["wi_0"]), + jnp.asarray(params_np["wi_1"]), + jnp.asarray(params_np["wo"]), + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + aux_loss_coeff=coeff, + ) + np.testing.assert_allclose( + float(aux_np), + float(jax.device_get(aux_ref)), + atol=AUX_ATOL, + rtol=AUX_RTOL, + ) + + # Aux-only bwd must propagate to gate_kernel — proves the + # fused_moe_aux_loss_bwd → topk(compute_aux_scores)_bwd chain is + # wired. + aux_grads = _grad_aux_only(block, variables, mesh, x) + g_gate = np.asarray( + jax.device_get(_unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0)) + ) + assert np.all(np.isfinite(g_gate)), "gate grad NaN/Inf under aux-only loss" + assert np.any(g_gate != 0.0), "aux bwd should propagate to gate_kernel" + + def test_combined_loss_grads(self, mesh): + """Joint main + aux loss bwd: per-tensor finite + non-zero in + one pass.""" + block = _make_block(aux_loss_coeff=1e-2) + x = _make_inputs(jax.random.PRNGKey(22)) + variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(23)) + grads = _grad_step(block, variables, mesh, x, include_aux=True) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_local = np.asarray(jax.device_get(_unwrap(grads["params"][name]).addressable_data(0))) + assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf under main+aux" + assert np.any(g_local != 0.0), f"{name} grad zero under main+aux" diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8f96432ed8..6f5117ef08 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -437,6 +437,98 @@ if (NVTE_WITH_CUSOLVERMP) message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") endif() +# -- NCCL EP (on by default, HT mode only) --------------------------------- +# Set -DNVTE_WITH_NCCL_EP=OFF (or NVTE_BUILD_WITH_NCCL_EP=0 in setup.py) to +# skip NCCL EP entirely - useful on older images whose system NCCL is below +# the 2.30.4 EP minimum. +option(NVTE_WITH_NCCL_EP "Build NCCL EP into libtransformer_engine.so" ON) +if(NVTE_WITH_NCCL_EP) +# SM>=90 and NCCL>=2.30.4 are gated at runtime in EPBackend::initialize. +# -- NCCL EP headers -------------------------------------------------------- +# Headers + libs are produced by the in-tree 3rdparty/nccl submodule build +# (auto-built by setup.py via build_nccl_ep_submodule). +set(NCCL_EP_SUBMODULE_ROOT + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/nccl") +set(NCCL_EP_INCLUDE_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/include") +if(NOT EXISTS "${NCCL_EP_INCLUDE_DIR}/nccl_ep.h") + message(FATAL_ERROR + "NCCL EP header not found at ${NCCL_EP_INCLUDE_DIR}/nccl_ep.h. " + "Run `git submodule update --init --recursive` and rebuild TE.") +endif() +message(STATUS "NCCL EP headers: ${NCCL_EP_INCLUDE_DIR}") + +# -- libnccl_ep.so ---------------------------------------------------------- +# Resolved at runtime via dlopen in ep/ep_nccl_loader.cpp - NOT link-time bound, +# so libtransformer_engine.so still loads on systems missing libnccl_ep.so or +# with too-old NCCL. Locate the build artifact only to keep it on the rpath +# (so dlopen by SONAME finds the bundled copy via DT_RUNPATH). +set(NCCL_EP_LIB_DIR "${NCCL_EP_SUBMODULE_ROOT}/build/lib") +find_library(NCCL_EP_LIB + NAMES nccl_ep libnccl_ep + HINTS ${NCCL_EP_LIB_DIR} + NO_DEFAULT_PATH + REQUIRED) + +# -- NCCL core: nccl.h + libnccl.so ----------------------------------------- +# setup.py passes -DNCCL_INCLUDE_DIR; standalone CMake falls back to probing +# well-known NCCL install prefixes. +find_path(NCCL_INCLUDE_DIR nccl.h + HINTS /opt/nvidia/nccl/include /usr/local/nccl/include) +if(NOT NCCL_INCLUDE_DIR) + message(FATAL_ERROR + "nccl.h not found. Pass -DNCCL_INCLUDE_DIR=/include.") +endif() +if(NOT NCCL_LIB) + find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib lib64 + REQUIRED) +endif() + +# Diagnostic: log detected NCCL header version (minimum enforced at runtime). +file(READ "${NCCL_INCLUDE_DIR}/nccl.h" _nvte_nccl_h) +string(REGEX MATCH "#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") +set(_nvte_nccl_major "${CMAKE_MATCH_1}") +string(REGEX MATCH "#define[ \t]+NCCL_MINOR[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") +set(_nvte_nccl_minor "${CMAKE_MATCH_1}") +string(REGEX MATCH "#define[ \t]+NCCL_PATCH[ \t]+([0-9]+)" _ "${_nvte_nccl_h}") +set(_nvte_nccl_patch "${CMAKE_MATCH_1}") +if(_nvte_nccl_major AND _nvte_nccl_minor AND _nvte_nccl_patch) + message(STATUS "NCCL header: ${NCCL_INCLUDE_DIR}/nccl.h (version ${_nvte_nccl_major}.${_nvte_nccl_minor}.${_nvte_nccl_patch})") +endif() + +target_include_directories(transformer_engine PRIVATE + ${NCCL_EP_INCLUDE_DIR} + ${NCCL_INCLUDE_DIR}) + +# libnccl_ep.so is dlopen'd from ep_nccl_loader.cpp, so do NOT link it here. +# libnccl.so stays direct-linked: only ancient symbols (ncclGetVersion, +# ncclCommCount, ncclGetErrorString) are referenced from this TU. +target_link_libraries(transformer_engine PUBLIC + ${NCCL_LIB} + ${CMAKE_DL_LIBS}) + +# rpath for dlopen("libnccl_ep.so.0"): in-tree build dir for dev, $ORIGIN for +# the wheel (libnccl_ep.so.0 ships beside libtransformer_engine.so). +# libnccl.so: resolved by the dynamic linker via its configured paths. +set_target_properties(transformer_engine PROPERTIES + BUILD_RPATH "${NCCL_EP_LIB_DIR}" + INSTALL_RPATH "$ORIGIN") + +target_sources(transformer_engine PRIVATE + ep/ep_backend.cpp + ep/ep_api.cpp + ep/ep_nccl_loader.cpp) +target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_NCCL_EP) + +message(STATUS "NCCL EP enabled (dlopen at runtime): ${NCCL_EP_LIB}") +message(STATUS "NCCL EP include: ${NCCL_EP_INCLUDE_DIR}") +else() + # NCCL EP off: ep_api.cpp's #else branch exports throwing nvte_ep_* stubs. + target_sources(transformer_engine PRIVATE ep/ep_api.cpp) + message(STATUS "NCCL EP disabled (NVTE_WITH_NCCL_EP=OFF) - using nvte_ep_* stubs") +endif() + # Number of philox4x32 rounds for stochastic rounding (build-time constant). set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp new file mode 100644 index 0000000000..b8cf04aa4a --- /dev/null +++ b/transformer_engine/common/ep/ep_api.cpp @@ -0,0 +1,130 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_api.cpp + * \brief nvte_ep_* C API: thin delegations to the EPBackend singleton. + * + * When NVTE_WITH_NCCL_EP is undefined, the entry points become throwing + * stubs so framework bindings still link without NCCL EP support. + */ + +#include + +#include "../util/logging.h" + +#if defined(NVTE_WITH_NCCL_EP) + +#include + +#include "../common.h" +#include "ep_backend.h" + +using transformer_engine::ep::EPBackend; + +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + EPBackend::initialize(static_cast(ep_comm), group_config); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg) { + return EPBackend::get().handle_mem_size(layer_cfg); +} + +namespace { +inline void* handle_mem_ptr(NVTETensor mem) { + void* p = nvte_tensor_data(mem); + NVTE_CHECK(p != nullptr, "handle_mem tensor data must not be null"); + return p; +} +} // namespace + +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { + EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, token_counts, layer_cfg, stream); +} + +void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream) { + EPBackend::get().dispatch(handle_mem_ptr(handle_mem), topk_idx, tokens, tokens_win, topk_weights, + topk_weights_win, recv_tokens, recv_tokens_win, recv_topk_weights, + recv_topk_weights_win, stream); +} + +void nvte_ep_combine(NVTETensor handle_mem, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream) { + EPBackend::get().combine(handle_mem_ptr(handle_mem), expert_out, expert_out_win, result, stream); +} + +void nvte_ep_dispatch_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream) { + EPBackend::get().dispatch_bwd(handle_mem_ptr(handle_mem), grad, grad_win, g_recv_topk_weights, + g_recv_topk_weights_win, grad_tokens, grad_topk_weights, stream); +} + +void nvte_ep_combine_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream) { + EPBackend::get().combine_bwd(handle_mem_ptr(handle_mem), grad, grad_win, grad_expert_out, + grad_expert_out_win, stream); +} + +#else // !NVTE_WITH_NCCL_EP - throwing stubs. + +namespace { +[[noreturn]] void ep_not_built() { + NVTE_ERROR( + "NCCL EP is not built into this TransformerEngine. Rebuild TE with " + "NVTE_BUILD_WITH_NCCL_EP=1 and CUDA arch >= 90 (e.g. NVTE_CUDA_ARCHS=\"90\")."); +} +} // namespace + +void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } + +void nvte_ep_shutdown(void) {} + +size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig /*layer_cfg*/) { ep_not_built(); } + +void nvte_ep_prepare(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, + NVTETensor /*token_counts*/, NVTEEpLayerConfig /*layer_cfg*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, NVTETensor /*tokens*/, + NVTECommWindow /*tokens_win*/, NVTETensor /*topk_weights*/, + NVTECommWindow /*topk_weights_win*/, NVTETensor /*recv_tokens*/, + NVTECommWindow /*recv_tokens_win*/, NVTETensor /*recv_topk_weights*/, + NVTECommWindow /*recv_topk_weights_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine(NVTETensor /*handle_mem*/, NVTETensor /*expert_out*/, + NVTECommWindow /*expert_out_win*/, NVTETensor /*result*/, + cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_dispatch_bwd(NVTETensor /*handle_mem*/, NVTETensor /*grad*/, + NVTECommWindow /*grad_win*/, NVTETensor /*g_recv_topk_weights*/, + NVTECommWindow /*g_recv_topk_weights_win*/, NVTETensor /*grad_tokens*/, + NVTETensor /*grad_topk_weights*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +void nvte_ep_combine_bwd(NVTETensor /*handle_mem*/, NVTETensor /*grad*/, + NVTECommWindow /*grad_win*/, NVTETensor /*grad_expert_out*/, + NVTECommWindow /*grad_expert_out_win*/, cudaStream_t /*stream*/) { + ep_not_built(); +} + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp new file mode 100644 index 0000000000..b43b01fa73 --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -0,0 +1,474 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.cpp + * \brief EPBackend implementation. See ep_backend.h for the op flow. + */ + +#include "ep_backend.h" + +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" +#include "ep_nccl_loader.h" + +namespace transformer_engine { +namespace ep { + +namespace { + +ncclDataType_t te_dtype_to_nccl_dtype(NVTEDType dtype) { + switch (dtype) { + case kNVTEFloat32: + return ncclFloat32; + case kNVTEFloat16: + return ncclFloat16; + case kNVTEBFloat16: + return ncclBfloat16; + case kNVTEInt32: + return ncclInt32; + case kNVTEInt64: + return ncclInt64; + case kNVTEByte: + return ncclUint8; + case kNVTEFloat8E4M3: + return ncclFloat8e4m3; + case kNVTEFloat8E5M2: + return ncclFloat8e5m2; + default: + NVTE_ERROR("Unsupported NVTEDType for NCCL dtype conversion: ", static_cast(dtype)); + } + return ncclFloat32; // unreachable +} + +// shape_out is caller-owned; desc.sizes aliases shape_out.data and must +// outlive the NCCL EP call. +inline ncclEpTensor_t make_nccl_ep_tensor(const NVTETensor t, NVTEShape& shape_out, + const NVTECommWindow& win = {}) { + shape_out = nvte_tensor_shape(t); + ncclEpTensor_t desc = NCCL_EP_TENSOR_INIT; + desc.ndim = shape_out.ndim; + desc.sizes = shape_out.data; + desc.datatype = te_dtype_to_nccl_dtype(nvte_tensor_type(t)); + if (win.window != nullptr) { + desc.win_hdl = win.window; + desc.win_offset = win.offset; + } else { + desc.data = nvte_tensor_data(t); + NVTE_CHECK(desc.data != nullptr, "tensor data must not be null"); + } + return desc; +} + +} // namespace + +// --------------------------------------------------------------------------- +// Singleton + bootstrap +// --------------------------------------------------------------------------- + +EPBackend& EPBackend::instance() { + static EPBackend inst; + return inst; +} + +EPBackend& EPBackend::get() { + EPBackend& inst = instance(); + NVTE_CHECK(inst.initialized_, "EPBackend not initialized. Call nvte_ep_initialize() first."); + return inst; +} + +void EPBackend::validate_config(const NVTEEpGroupConfig& config) { + NVTE_CHECK(config.ep_size > 0, "ep_size must be positive, got ", config.ep_size); + NVTE_CHECK(config.num_experts > 0, "num_experts must be positive, got ", config.num_experts); + NVTE_CHECK(config.max_tokens_per_rank > 0, "max_tokens_per_rank must be positive, got ", + config.max_tokens_per_rank); + NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", + config.max_recv_tokens_per_rank); + NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); + NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, + "max_token_dtype out of range, got ", static_cast(config.max_token_dtype)); + const size_t elem_bytes = typeToSize(static_cast(config.max_token_dtype)); + NVTE_CHECK(config.hidden_dim * elem_bytes >= 16, + "hidden_dim * sizeof(max_token_dtype) must be >= 16 (NCCL EP 16B row alignment); " + "got hidden_dim=", + config.hidden_dim, ", element_bytes=", elem_bytes); + NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, + ") must be divisible by ep_size (", config.ep_size, ")"); + NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", + config.max_num_sms); + + int device, major; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + NVTE_CHECK(major >= 9, + "NCCL EP requires SM_90+ (Hopper or later), " + "but current device has compute capability ", + major, ".x"); + + NVTE_CHECK(cuda::supports_multicast(device), "NCCL EP requires CUDA multicast support on device ", + device); +} + +void EPBackend::initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config) { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + NVTE_CHECK(!inst.initialized_, "EP already initialized. Call initialize only once per process."); + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + + // Runtime gate: NCCL >= 2.30.4 (matches the submodule pin). + constexpr int kMinNcclVersion = 23004; + int nccl_version = 0; + NVTE_CHECK_NCCL(ncclGetVersion(&nccl_version)); + NVTE_CHECK(nccl_version >= kMinNcclVersion, "NCCL EP requires NCCL >= 2.30.4, found ", + nccl_version / 10000, ".", (nccl_version / 100) % 100, ".", nccl_version % 100, + " at runtime."); + + validate_config(config); + + int comm_size = 0; + NVTE_CHECK_NCCL(ncclCommCount(ep_comm, &comm_size)); + NVTE_CHECK(comm_size == config.ep_size, "ep_comm size (", comm_size, ") must equal ep_size (", + config.ep_size, "). Pass the EP sub-communicator, not the world comm."); + + inst.init(ep_comm, config); +} + +void EPBackend::shutdown() { + EPBackend& inst = instance(); + std::lock_guard lock(inst.mutex_); + if (!inst.initialized_) return; + const auto& nccl = loader::fns(); + for (auto& e : inst.lru_) { + if (e.handle != nullptr) nccl.HandleDestroy(e.handle); + } + inst.lru_.clear(); + inst.index_.clear(); + inst.fallback_layer_cfg_.reset(); + // ncclEpGroupDestroy reads from ep_comm_; destroy group while comm is still alive. + if (inst.ep_group_ != nullptr) { + nccl.GroupDestroy(inst.ep_group_); + inst.ep_group_ = nullptr; + } + inst.ep_comm_ = nullptr; // borrowed; caller destroys + inst.initialized_ = false; +} + +ncclEpHandle_t EPBackend::open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment) { + size_t hm_sizes[1] = {handle_mem_size}; + ncclEpTensor_t routing_desc = NCCL_EP_TENSOR_INIT; + routing_desc.ndim = 1; + routing_desc.datatype = ncclUint8; + routing_desc.data = handle_mem; + routing_desc.sizes = hm_sizes; + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment; + ncclEpHandle_t handle; + NVTE_CHECK_NCCL(loader::fns().InitHandle(&handle, ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, + num_topk, &routing_desc)); + return handle; +} + +// --------------------------------------------------------------------------- +// Lifecycle +// --------------------------------------------------------------------------- + +// Static-dtor teardown: skip NCCL calls (CUDA context / borrowed ep_comm_ may +// already be gone) and release in-memory state only. +EPBackend::~EPBackend() { + std::lock_guard lock(mutex_); + if (!initialized_) return; + lru_.clear(); + index_.clear(); + fallback_layer_cfg_.reset(); + ep_group_ = nullptr; + ep_comm_ = nullptr; + initialized_ = false; +} + +void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { + NVTE_CHECK(!initialized_, "EPBackend already initialized"); + + group_config_ = group_config; + + ncclEpGroupConfig_t cfg = NCCL_EP_GROUP_CONFIG_INIT; + cfg.algorithm = NCCL_EP_ALGO_HIGH_THROUGHPUT; + cfg.num_experts = static_cast(group_config.num_experts); + cfg.max_dispatch_tokens_per_rank = static_cast(group_config.max_tokens_per_rank); + const size_t elem_bytes = typeToSize(static_cast(group_config.max_token_dtype)); + cfg.max_token_bytes = static_cast(group_config.hidden_dim * elem_bytes); + cfg.rdma_buffer_size = NCCL_EP_AUTO; + cfg.num_qp_per_rank = NCCL_EP_AUTO; + cfg.num_channels = NCCL_EP_AUTO; + cfg.max_num_sms = group_config.max_num_sms > 0 + ? static_cast(group_config.max_num_sms) + : NCCL_EP_AUTO; + // Must be > 0; NCCL EP errors out on 0. + cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank); + cfg.zero_copy = group_config.zero_copy ? NCCL_EP_ZERO_COPY_ON : NCCL_EP_ZERO_COPY_OFF; + + NVTE_CHECK_NCCL(loader::fns().CreateGroup(&ep_group_, ep_comm, &cfg)); + + ep_comm_ = ep_comm; + + initialized_ = true; +} + +// --------------------------------------------------------------------------- +// Pointer-keyed LRU cache +// --------------------------------------------------------------------------- + +size_t EPBackend::cache_cap_locked() { + if (handle_cache_cap_ == 0) { + const char* cap_env = std::getenv("NVTE_EP_HANDLE_CACHE_SIZE"); + if (cap_env != nullptr) { + const int64_t v = static_cast(std::atol(cap_env)); + if (v < 0) { + // Unlimited cache. WAR for JAX until XLA fixes handle_mem + // reloc between runs. + handle_cache_cap_ = SIZE_MAX; + } else { + NVTE_CHECK(v > 0, + "NVTE_EP_HANDLE_CACHE_SIZE=0 is invalid; use -1 for unlimited or a positive " + "cap."); + handle_cache_cap_ = static_cast(v); + } + } else { + handle_cache_cap_ = 4096; + } + } + return handle_cache_cap_; +} + +ncclEpHandle_t EPBackend::prepare_handle_locked(void* handle_mem, NVTEEpLayerConfig layer_cfg) { + // Update the program-wide fallback cfg so dispatch/combine/_bwd can + // reconstruct the handle on a pointer-cache miss (WAR for XLA buffer reloc + // between runs; one cfg per process). Remove this once XLA preserves the + // handle_mem device pointer across runs. + if (fallback_layer_cfg_.has_value()) { + NVTE_CHECK(fallback_layer_cfg_->top_k == layer_cfg.top_k, "EP prepare top_k=", layer_cfg.top_k, + " disagrees with process-wide cached top_k=", fallback_layer_cfg_->top_k); + NVTE_CHECK(fallback_layer_cfg_->dispatch_output_per_expert_alignment == + layer_cfg.dispatch_output_per_expert_alignment, + "EP prepare alignment=", layer_cfg.dispatch_output_per_expert_alignment, + " disagrees with process-wide cached alignment=", + fallback_layer_cfg_->dispatch_output_per_expert_alignment); + } else { + fallback_layer_cfg_ = layer_cfg; + } + + auto it = index_.find(handle_mem); + if (it != index_.end()) { + lru_.splice(lru_.begin(), lru_, it->second); + return it->second->handle; + } + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_cfg.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(loader::fns().HandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, + &hm_size, layer_cfg.top_k)); + ncclEpHandle_t h = open_handle(handle_mem, hm_size, layer_cfg.top_k, + layer_cfg.dispatch_output_per_expert_alignment); + lru_.push_front(HandleEntry{handle_mem, h, layer_cfg, hm_size}); + index_.emplace(handle_mem, lru_.begin()); + while (lru_.size() > cache_cap_locked()) { + HandleEntry& victim = lru_.back(); + if (victim.handle != nullptr) loader::fns().HandleDestroy(victim.handle); + index_.erase(victim.handle_mem); + lru_.pop_back(); + } + return h; +} + +ncclEpHandle_t EPBackend::lookup_handle_locked(void* handle_mem) { + auto it = index_.find(handle_mem); + if (it != index_.end()) { + lru_.splice(lru_.begin(), lru_, it->second); + return it->second->handle; + } + // Miss: reconstruct from the process-wide cached cfg. XLA may relocate + // handle_mem between runs, breaking the pointer key; the fallback cfg lets + // us open a fresh handle on the new buffer. Drop this branch once XLA + // preserves buffer pointers. + const uintptr_t hm_addr = reinterpret_cast(handle_mem); + NVTE_CHECK(fallback_layer_cfg_.has_value(), "ep op on handle_mem=0x", hm_addr, + " with no cached entry and no prior nvte_ep_prepare; call prepare first."); + return prepare_handle_locked(handle_mem, *fallback_layer_cfg_); +} + +// --------------------------------------------------------------------------- +// Per-step operations +// --------------------------------------------------------------------------- + +size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); + ncclEpHandleConfig_t hcfg = NCCL_EP_HANDLE_CONFIG_INIT; + hcfg.dispatch_output_per_expert_alignment = layer_cfg.dispatch_output_per_expert_alignment; + size_t hm_size = 0; + NVTE_CHECK_NCCL(loader::fns().HandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, + &hm_size, layer_cfg.top_k)); + return hm_size; +} + +void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); + + NVTEShape topk_idx_shape; + ncclEpTensor_t nccl_topk_idx = make_nccl_ep_tensor(topk_idx, topk_idx_shape); + + // ncclEpUpdateHandle writes per-expert counts via expert_counters. + NVTEShape token_counts_shape; + ncclEpTensor_t token_counts_desc; + if (token_counts != nullptr) { + token_counts_desc = make_nccl_ep_tensor(token_counts, token_counts_shape); + } + ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; + layout_info.expert_counters = (token_counts != nullptr) ? &token_counts_desc : nullptr; + + std::lock_guard lock(mutex_); + ncclEpHandle_t h = prepare_handle_locked(handle_mem, layer_cfg); + NVTE_CHECK_NCCL(loader::fns().UpdateHandle(h, &nccl_topk_idx, &layout_info, stream)); +} + +void EPBackend::dispatch(void* handle_mem, const NVTETensor topk_idx, const NVTETensor tokens, + const NVTECommWindow& tokens_win, const NVTETensor topk_weights, + const NVTECommWindow& topk_weights_win, NVTETensor recv_tokens, + const NVTECommWindow& recv_tokens_win, NVTETensor recv_topk_weights, + const NVTECommWindow& recv_topk_weights_win, cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEDType tok_dtype = nvte_tensor_type(tokens); + NVTE_CHECK(typeToSize(static_cast(tok_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), + "tokens dtype (", static_cast(tok_dtype), ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); + NVTEDType recv_dtype = nvte_tensor_type(recv_tokens); + NVTE_CHECK(typeToSize(static_cast(recv_dtype)) <= + typeToSize(static_cast(group_config_.max_token_dtype)), + "recv_tokens dtype (", static_cast(recv_dtype), + ") wider than group max_token_dtype (", + static_cast(group_config_.max_token_dtype), ")"); + + NVTEShape tokens_shape, recv_tokens_shape; + ncclEpTensor_t nccl_tokens_in = make_nccl_ep_tensor(tokens, tokens_shape, tokens_win); + ncclEpTensor_t nccl_tokens_out = + make_nccl_ep_tensor(recv_tokens, recv_tokens_shape, recv_tokens_win); + + // Routing is cached in handle_mem by ep_prepare; dispatch only needs + // topk_weights to reconstruct the sparse-to-dense prob map. + const bool is_forward = (topk_weights != nullptr); + NVTEShape topk_weights_shape, recv_topk_weights_shape; + ncclEpTensor_t nccl_topk_weights_in; + ncclEpTensor_t nccl_topk_weights_out; + if (is_forward) { + NVTE_CHECK(topk_idx != nullptr, "topk_idx required in forward dispatch"); + NVTE_CHECK(recv_topk_weights != nullptr, + "recv_topk_weights must not be null in forward dispatch"); + NVTE_CHECK(nvte_tensor_shape(recv_topk_weights).ndim == 1, + "recv_topk_weights must be 1D [recv_capacity]"); + nccl_topk_weights_in = make_nccl_ep_tensor(topk_weights, topk_weights_shape, topk_weights_win); + nccl_topk_weights_out = + make_nccl_ep_tensor(recv_topk_weights, recv_topk_weights_shape, recv_topk_weights_win); + } + + ncclEpDispatchInputs_t in_struct = NCCL_EP_DISPATCH_INPUTS_INIT; + in_struct.tokens = &nccl_tokens_in; + in_struct.topk_weights = is_forward ? &nccl_topk_weights_in : nullptr; + + ncclEpDispatchOutputs_t out_struct = NCCL_EP_DISPATCH_OUTPUTS_INIT; + out_struct.tokens = &nccl_tokens_out; + out_struct.topk_weights = is_forward ? &nccl_topk_weights_out : nullptr; + + ncclEpDispatchConfig_t dispatch_cfg = NCCL_EP_DISPATCH_CONFIG_INIT; + dispatch_cfg.pass_direction = is_forward ? NCCL_EP_FWD_PASS : NCCL_EP_BWD_PASS; + + std::lock_guard lock(mutex_); + ncclEpHandle_t h = lookup_handle_locked(handle_mem); + NVTE_CHECK_NCCL(loader::fns().Dispatch(h, &in_struct, &out_struct, + /*layout_info=*/nullptr, &dispatch_cfg, stream)); +} + +void EPBackend::combine(void* handle_mem, const NVTETensor expert_out, + const NVTECommWindow& expert_out_win, NVTETensor result, + cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + NVTEShape expert_out_shape, result_shape; + ncclEpTensor_t nccl_expert_in = make_nccl_ep_tensor(expert_out, expert_out_shape, expert_out_win); + ncclEpTensor_t nccl_result_out = make_nccl_ep_tensor(result, result_shape); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_expert_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_result_out; + + std::lock_guard lock(mutex_); + ncclEpHandle_t h = lookup_handle_locked(handle_mem); + NVTE_CHECK_NCCL(loader::fns().Combine(h, &in_struct, &out_struct, /*config=*/nullptr, stream)); +} + +void EPBackend::dispatch_bwd(void* handle_mem, const NVTETensor grad, + const NVTECommWindow& grad_win, const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream) { + NVTE_CHECK(initialized_, "EPBackend not initialized"); + NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); + + // g_recv_topk_weights must be 1D [recv_capacity]; caller flattens. + NVTE_CHECK(nvte_tensor_shape(g_recv_topk_weights).ndim == 1, + "g_recv_topk_weights must be 1D [recv_capacity]; caller must flatten leading dims"); + NVTE_CHECK(nvte_tensor_shape(grad_topk_weights).ndim == 2, + "grad_topk_weights must be 2D [T, top_k]"); + + NVTEShape grad_shape, g_recv_w_shape, grad_tokens_shape, grad_w_shape; + ncclEpTensor_t nccl_tok_in = make_nccl_ep_tensor(grad, grad_shape, grad_win); + ncclEpTensor_t nccl_w_in = + make_nccl_ep_tensor(g_recv_topk_weights, g_recv_w_shape, g_recv_topk_weights_win); + ncclEpTensor_t nccl_tok_out = make_nccl_ep_tensor(grad_tokens, grad_tokens_shape); + ncclEpTensor_t nccl_w_out = make_nccl_ep_tensor(grad_topk_weights, grad_w_shape); + + ncclEpCombineInputs_t in_struct = NCCL_EP_COMBINE_INPUTS_INIT; + in_struct.tokens = &nccl_tok_in; + in_struct.topk_weights = &nccl_w_in; + + ncclEpCombineOutputs_t out_struct = NCCL_EP_COMBINE_OUTPUTS_INIT; + out_struct.tokens = &nccl_tok_out; + out_struct.topk_weights = &nccl_w_out; + + ncclEpCombineConfig_t cfg = NCCL_EP_COMBINE_CONFIG_INIT; + cfg.pass_direction = NCCL_EP_BWD_PASS; + + std::lock_guard lock(mutex_); + ncclEpHandle_t h = lookup_handle_locked(handle_mem); + NVTE_CHECK_NCCL(loader::fns().Combine(h, &in_struct, &out_struct, &cfg, stream)); +} + +void EPBackend::combine_bwd(void* handle_mem, const NVTETensor grad, const NVTECommWindow& grad_win, + NVTETensor grad_expert_out, const NVTECommWindow& grad_expert_out_win, + cudaStream_t stream) { + // Backward of combine = reverse-direction dispatch. + dispatch(handle_mem, /*topk_idx=*/nullptr, grad, grad_win, + /*topk_weights=*/nullptr, /*topk_weights_win=*/NVTECommWindow{}, grad_expert_out, + grad_expert_out_win, + /*recv_topk_weights=*/nullptr, /*recv_topk_weights_win=*/NVTECommWindow{}, stream); +} + +} // namespace ep +} // namespace transformer_engine diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h new file mode 100644 index 0000000000..ea9aa019fa --- /dev/null +++ b/transformer_engine/common/ep/ep_backend.h @@ -0,0 +1,120 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_backend.h + * \brief Internal NCCL EP singleton; not part of the public API. + * + * ncclEpHandles are cached by handle_mem device pointer. nvte_ep_prepare + * seeds the entry with the layer_cfg; dispatch/combine/_bwd look up by + * pointer. Cache cap: NVTE_EP_HANDLE_CACHE_SIZE (default 4096; -1 disables + * LRU eviction). + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ +#define TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace transformer_engine { +namespace ep { + +/*! \brief EP backend singleton; owns the NCCL EP group, borrows the comm. */ +class EPBackend { + public: + /*! \brief Access the singleton. Aborts if not initialized. */ + static EPBackend& get(); + + /*! \brief Bootstrap from an existing EP sub-communicator. + * ep_comm is borrowed; the caller keeps it alive until shutdown() returns + * and must span exactly config.ep_size ranks. + */ + static void initialize(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + /*! \brief Tear down the backend. Idempotent. Does not destroy ep_comm_. */ + static void shutdown(); + + // Host-only: report handle_mem byte size for layer_cfg. + size_t handle_mem_size(NVTEEpLayerConfig layer_cfg); + + // Seeds the cache for handle_mem with layer_cfg and runs the routing AllGather. + void prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream); + + // Per-step ops below require a prior prepare(). + void dispatch(void* handle_mem, const NVTETensor topk_idx, const NVTETensor tokens, + const NVTECommWindow& tokens_win, const NVTETensor topk_weights, + const NVTECommWindow& topk_weights_win, NVTETensor recv_tokens, + const NVTECommWindow& recv_tokens_win, NVTETensor recv_topk_weights, + const NVTECommWindow& recv_topk_weights_win, cudaStream_t stream); + + void combine(void* handle_mem, const NVTETensor expert_out, const NVTECommWindow& expert_out_win, + NVTETensor result, cudaStream_t stream); + + // g_recv_topk_weights: 1D [recv_capacity] f32; grad_topk_weights: 2D [T, top_k] f32. + void dispatch_bwd(void* handle_mem, const NVTETensor grad, const NVTECommWindow& grad_win, + const NVTETensor g_recv_topk_weights, + const NVTECommWindow& g_recv_topk_weights_win, NVTETensor grad_tokens, + NVTETensor grad_topk_weights, cudaStream_t stream); + + void combine_bwd(void* handle_mem, const NVTETensor grad, const NVTECommWindow& grad_win, + NVTETensor grad_expert_out, const NVTECommWindow& grad_expert_out_win, + cudaStream_t stream); + + private: + EPBackend() = default; + ~EPBackend(); + EPBackend(const EPBackend&) = delete; + EPBackend& operator=(const EPBackend&) = delete; + + // ep_comm is borrowed; caller retains ownership across the backend lifetime. + void init(ncclComm_t ep_comm, NVTEEpGroupConfig config); + + static EPBackend& instance(); // Meyers singleton accessor + static void validate_config(const NVTEEpGroupConfig& config); + + // Open a fresh ncclEpHandle over handle_mem. num_topk=-1 for paths + // that don't carry per-token weights. + ncclEpHandle_t open_handle(void* handle_mem, size_t handle_mem_size, int num_topk, + size_t dispatch_output_per_expert_alignment); + + // LRU cache: most-recently-used at the front of lru_, evict from the back. + struct HandleEntry { + void* handle_mem; + ncclEpHandle_t handle; + NVTEEpLayerConfig layer_cfg; + size_t handle_mem_size; + }; + + ncclEpGroup_t ep_group_{nullptr}; + ncclComm_t ep_comm_{nullptr}; + NVTEEpGroupConfig group_config_{}; + bool initialized_{false}; + std::mutex mutex_; + std::list lru_; + std::unordered_map::iterator> index_; + size_t handle_cache_cap_{0}; // set lazily from NVTE_EP_HANDLE_CACHE_SIZE + std::optional fallback_layer_cfg_; + + // Caller must hold mutex_. + ncclEpHandle_t prepare_handle_locked(void* handle_mem, NVTEEpLayerConfig layer_cfg); + ncclEpHandle_t lookup_handle_locked(void* handle_mem); + size_t cache_cap_locked(); +}; + +} // namespace ep +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_BACKEND_H_ diff --git a/transformer_engine/common/ep/ep_nccl_loader.cpp b/transformer_engine/common/ep/ep_nccl_loader.cpp new file mode 100644 index 0000000000..8374acd7b3 --- /dev/null +++ b/transformer_engine/common/ep/ep_nccl_loader.cpp @@ -0,0 +1,74 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "ep_nccl_loader.h" + +#include + +#include "../util/logging.h" + +namespace transformer_engine { +namespace ep { +namespace loader { + +namespace { + +constexpr const char* kSonames[] = {"libnccl_ep.so.0", "libnccl_ep.so"}; + +void* try_dlopen(std::string& last_err) { + for (const char* name : kSonames) { + dlerror(); + void* h = dlopen(name, RTLD_LAZY | RTLD_LOCAL); + if (h != nullptr) return h; + if (const char* e = dlerror()) last_err = e; + } + return nullptr; +} + +template +Fn resolve(void* lib, const char* sym) { + dlerror(); + void* p = dlsym(lib, sym); + const char* err = dlerror(); + NVTE_CHECK(err == nullptr && p != nullptr, "libnccl_ep.so is loaded but symbol '", sym, + "' could not be resolved", (err != nullptr ? std::string(": ") + err : std::string{}), + ". The runtime libnccl_ep.so is older than the version TransformerEngine " + "was built against; upgrade NCCL EP or rebuild TE with -DNVTE_WITH_NCCL_EP=OFF."); + return reinterpret_cast(p); +} + +NcclEpFns load_or_throw() { + std::string last_err; + void* lib = try_dlopen(last_err); + NVTE_CHECK(lib != nullptr, "Failed to load libnccl_ep.so (", + (last_err.empty() ? "no error message" : last_err), + "). NCCL EP requires libnccl_ep.so (>= 0.0.1) and NCCL >= 2.30.4 at runtime. " + "Install the NCCL EP shared library, or rebuild TransformerEngine with " + "-DNVTE_WITH_NCCL_EP=OFF to disable EP support."); + NcclEpFns fns{}; + fns.InitHandle = resolve(lib, "ncclEpInitHandle"); + fns.CreateGroup = resolve(lib, "ncclEpCreateGroup"); + fns.GroupDestroy = resolve(lib, "ncclEpGroupDestroy"); + fns.HandleDestroy = resolve(lib, "ncclEpHandleDestroy"); + fns.HandleMemSize = resolve(lib, "ncclEpHandleMemSize"); + fns.UpdateHandle = resolve(lib, "ncclEpUpdateHandle"); + fns.Dispatch = resolve(lib, "ncclEpDispatch"); + fns.Combine = resolve(lib, "ncclEpCombine"); + return fns; +} + +} // namespace + +const NcclEpFns& fns() { + // Function-local static: thread-safe one-shot init; re-throws on every call + // if initialization fails, so a missing library is surfaced consistently. + static const NcclEpFns table = load_or_throw(); + return table; +} + +} // namespace loader +} // namespace ep +} // namespace transformer_engine diff --git a/transformer_engine/common/ep/ep_nccl_loader.h b/transformer_engine/common/ep/ep_nccl_loader.h new file mode 100644 index 0000000000..8ffb437ed8 --- /dev/null +++ b/transformer_engine/common/ep/ep_nccl_loader.h @@ -0,0 +1,48 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep_nccl_loader.h + * \brief Lazy dlopen-based resolver for libnccl_ep.so. + * + * libtransformer_engine.so is not link-time bound to libnccl_ep.so. The first + * call to ep::loader::fns() opens it via dlopen and dlsyms the ncclEp* + * entry points the EP backend uses. If the library or any symbol cannot be + * resolved (e.g. libnccl_ep.so is missing, or system NCCL is older than the + * EP minimum so libnccl_ep.so's own DT_NEEDED chain fails), the call throws + * NVTE_ERROR with remediation instead of preventing libtransformer_engine.so + * from loading. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_EP_EP_NCCL_LOADER_H_ +#define TRANSFORMER_ENGINE_COMMON_EP_EP_NCCL_LOADER_H_ + +#include + +namespace transformer_engine { +namespace ep { +namespace loader { + +struct NcclEpFns { + decltype(&::ncclEpInitHandle) InitHandle; + decltype(&::ncclEpCreateGroup) CreateGroup; + decltype(&::ncclEpGroupDestroy) GroupDestroy; + decltype(&::ncclEpHandleDestroy) HandleDestroy; + decltype(&::ncclEpHandleMemSize) HandleMemSize; + decltype(&::ncclEpUpdateHandle) UpdateHandle; + decltype(&::ncclEpDispatch) Dispatch; + decltype(&::ncclEpCombine) Combine; +}; + +/*! \brief Resolve libnccl_ep.so on first call; cache the table thereafter. + * Thread-safe; throws NVTE_ERROR if the library or any symbol is missing. + */ +const NcclEpFns& fns(); + +} // namespace loader +} // namespace ep +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_EP_EP_NCCL_LOADER_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_window.h b/transformer_engine/common/include/transformer_engine/comm_window.h new file mode 100644 index 0000000000..088ea7f0c3 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_window.h @@ -0,0 +1,32 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_window.h + * \brief Borrowed symmetric-memory window + offset for zero-copy one-sided ops. + * Pass ``{NULL, 0}`` to use the raw-pointer path. + */ + +#ifndef TRANSFORMER_ENGINE_COMM_WINDOW_H_ +#define TRANSFORMER_ENGINE_COMM_WINDOW_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief NCCL window + byte offset for a zero-copy payload tensor. */ +typedef struct { + ncclWindow_t window; /*!< NCCL window, or NULL to use the raw data pointer. */ + uint64_t offset; /*!< Byte offset of the payload within ``window``. */ +} NVTECommWindow; + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_COMM_WINDOW_H_ diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h new file mode 100644 index 0000000000..5682e9fdb6 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -0,0 +1,203 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ep.h + * \brief Public C API for Expert Parallelism. Per-step ops are + * allocation-free and CUDA graph-capturable. + * + * Per layer: call nvte_ep_handle_mem_size(layer_cfg) for the buffer size; + * allocate handle_mem as a kByte NVTETensor. Per step: nvte_ep_prepare seeds + * routing, then nvte_ep_dispatch / nvte_ep_combine / _bwd consume it. + * Cache cap: NVTE_EP_HANDLE_CACHE_SIZE (default 4096; -1 disables eviction). + */ + +#ifndef TRANSFORMER_ENGINE_EP_H_ +#define TRANSFORMER_ENGINE_EP_H_ + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* -- Config structs ------------------------------------------------------- */ +/* TODO: add a struct_size/version field to these configs (and align with other + * TE public structs) once a TE-wide convention for ABI versioning lands. */ + +/*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ +typedef struct { + /*! EP world size. */ + int ep_size; + /*! Total experts across all ranks. */ + int num_experts; + /*! Upper bound on tokens this rank sends per dispatch. */ + int max_tokens_per_rank; + /*! Upper bound on tokens this rank receives per dispatch (must be > 0). */ + int max_recv_tokens_per_rank; + /*! Token hidden dimension. */ + int hidden_dim; + /*! Max SMs for EP kernels. 0 = auto. */ + int max_num_sms; + /*! Widest token dtype the group will dispatch; sizes staging buffers. + * Per-dispatch tensors may use any dtype with element size <= this. */ + NVTEDType max_token_dtype; + /*! Zero-copy dispatch/combine. When nonzero, payload tensors must be backed + * by NVTECommWindow handles and transfer in place (no staging copies); + * 0 (default) = staged. */ + int zero_copy; +} NVTEEpGroupConfig; + +/*! \brief Per-layer configuration consumed by nvte_ep_handle_mem_size and + * nvte_ep_prepare. Reserved for future per-call options (fp8 scale, + * overflow policy, ...). + */ +typedef struct { + /*! Per-token expert fan-out (> 0). */ + int top_k; + /*! Per-expert recv-slab alignment in tokens (power of two; 0/1 disables). + * When > 1, each expert's slab in recv_tokens is zero-padded up to a + * multiple of this for downstream per-expert GEMM alignment. */ + size_t dispatch_output_per_expert_alignment; +} NVTEEpLayerConfig; + +/* -- Bootstrap ------------------------------------------------------------ */ + +/*! \brief Bootstrap the EP backend from an existing NCCL EP sub-communicator. + * Requires SM>=90. + * + * ep_comm is borrowed and must span exactly group_config.ep_size ranks. The + * caller retains ownership and must keep it alive until nvte_ep_shutdown() + * returns. Re-init after shutdown is allowed; double-init throws. One EP + * group per process, bound to the current CUDA device. + * + * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. + * \param[in] group_config Group-level EP configuration. + */ +void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); + +/*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */ +void nvte_ep_shutdown(void); + +/* -- Layer sizing (host-only) --------------------------------------------- */ + +/*! \brief Report the handle_mem byte size required for the given layer config. + * + * handle_mem is a per-layer kByte routing-state buffer; allocate once and + * thread the same pointer through every prepare/dispatch/combine/_bwd call + * for that layer (the backend keys its cache on the pointer). Host-only; + * size is stable for a given (group, layer) pair. + * + * \param[in] layer_cfg Per-call layer configuration. + * \return size in bytes for the handle_mem buffer. + */ +size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg); + +/* -- Per-step ops (all allocation-free, CUDA graph-capturable) ------------ */ + +/*! \brief Seed handle_mem with this step's routing plan. + * + * AllGathers topk_idx across the EP group and stages per-expert offsets and + * counts into handle_mem so the matching dispatch/combine/_bwd can run with + * no further routing computation. Must precede every dispatch/combine/_bwd + * that uses this handle_mem. token_counts becomes host-valid after a stream + * sync. + * + * \param[in] handle_mem uint8 routing-state buffer. + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] token_counts [num_local_experts] int32 counts. + * \param[in] layer_cfg Per-call layer configuration. + * \param[in] stream CUDA stream. + */ +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, + NVTEEpLayerConfig layer_cfg, cudaStream_t stream); + +/*! \brief Dispatch tokens (and routing weights) to expert ranks. + * + * Each local token is sent to its top_k destinations; recv_tokens is laid out + * expert-major (contiguous per-expert slabs, padded per layer_cfg). The + * *_win arguments enable zero-copy via symmem windows; pass NVTECommWindow{} + * when unused. Requires a prior nvte_ep_prepare on this handle_mem. + * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). + * \param[in] topk_idx [T, top_k] int64 sparse routing indices. + * \param[in] tokens [T, hidden_dim] input tokens. + * \param[in] tokens_win Optional symmem window for ``tokens``. + * \param[in] topk_weights [T, top_k] float32 weights, or null in backward. + * \param[in] topk_weights_win Optional symmem window for ``topk_weights``. + * \param[out] recv_tokens [recv_T, hidden_dim] received tokens. + * \param[in] recv_tokens_win Optional symmem window for ``recv_tokens``. + * \param[out] recv_topk_weights [recv_T] float32 per-slot weights, or null in backward. + * \param[in] recv_topk_weights_win Optional symmem window for ``recv_topk_weights``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, + NVTECommWindow tokens_win, NVTETensor topk_weights, + NVTECommWindow topk_weights_win, NVTETensor recv_tokens, + NVTECommWindow recv_tokens_win, NVTETensor recv_topk_weights, + NVTECommWindow recv_topk_weights_win, cudaStream_t stream); + +/*! \brief Scatter-sum expert outputs back to originating ranks. + * + * Inverse of dispatch: the top_k destination slots for token t are summed + * into result[t]. Sums are unweighted; pre-scale expert_out by + * recv_topk_weights (and the valid-slot mask) before calling. Requires a + * prior nvte_ep_prepare on this handle_mem. + * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). + * \param[in] expert_out [recv_T, hidden_dim] pre-weighted expert outputs. + * \param[in] expert_out_win Optional symmem window for ``expert_out``. + * \param[out] result [T, hidden_dim] combined output. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine(NVTETensor handle_mem, NVTETensor expert_out, NVTECommWindow expert_out_win, + NVTETensor result, cudaStream_t stream); + +/*! \brief Backward of dispatch: route per-recv-slot grads back to source. + * + * Sums the top_k recv-slot grads into grad_tokens[t]; scatters per-slot + * recv-weight grads into grad_topk_weights[t, k]. Padded recv slots + * contribute nothing. Requires a prior nvte_ep_prepare on this handle_mem. + * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). + * \param[in] grad [recv_capacity, hidden_dim] grad w.r.t. recv_tokens. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[in] g_recv_topk_weights [recv_capacity] f32 grad w.r.t. recv_topk_weights. + * \param[in] g_recv_topk_weights_win Optional symmem window for ``g_recv_topk_weights``. + * \param[out] grad_tokens [T, hidden_dim] grad w.r.t. tokens. + * \param[out] grad_topk_weights [T, top_k] f32 grad w.r.t. topk_weights. + * \param[in] stream CUDA stream. + */ +void nvte_ep_dispatch_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win, + NVTETensor grad_tokens, NVTETensor grad_topk_weights, + cudaStream_t stream); + +/*! \brief Backward of combine: replicate each source-token grad to its recv + * slots from the forward. + * + * Padded recv slots in grad_expert_out are zeroed. Requires a prior + * nvte_ep_prepare on this handle_mem. + * + * \param[in] handle_mem uint8 routing-state buffer (from prepare). + * \param[in] grad [T, hidden_dim] grad w.r.t. result. + * \param[in] grad_win Optional symmem window for ``grad``. + * \param[out] grad_expert_out [recv_capacity, hidden_dim] grad w.r.t. expert_out. + * \param[in] grad_expert_out_win Optional symmem window for ``grad_expert_out``. + * \param[in] stream CUDA stream. + */ +void nvte_ep_combine_bwd(NVTETensor handle_mem, NVTETensor grad, NVTECommWindow grad_win, + NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win, + cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif // TRANSFORMER_ENGINE_EP_H_ diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index b42c909740..c9647afb82 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -11,4 +11,5 @@ from .softmax import * from .gemm import * from .router import * +from .ep import * from .topk import * diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 6eb588c849..2cdef4bfe7 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -266,6 +266,17 @@ def _gspmd_wrapper(*args, **kwargs): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="CUDA") +# Register EpInstanceState (no-op when TE is built without NCCL EP). +if hasattr(transformer_engine_jax, "get_ep_instance_state_type_id"): + ffi.register_ffi_type( + "EpInstanceState", + { + "type_id": transformer_engine_jax.get_ep_instance_state_type_id(), + "type_info": transformer_engine_jax.get_ep_instance_state_type_info(), + }, + platform="CUDA", + ) + def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): """ diff --git a/transformer_engine/jax/cpp_extensions/ep.py b/transformer_engine/jax/cpp_extensions/ep.py new file mode 100644 index 0000000000..ce2f552f42 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/ep.py @@ -0,0 +1,967 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for Expert Parallelism (EP). + +Sharding model: + - EpPrepare / EpDispatch outputs carry a single leading ``num_procs`` dim. + Sharded compound ``(dp_resource, ep_resource)`` when DP is set, else + ``ep_resource`` alone. + - EpDispatch inputs are 2D ``[T, H]`` or 3D ``[B, S, H]``; only the first + dim may be sharded, with axis in {ep, (dp, ep), dp, None}. Trailing dims + must be replicated. ``dp`` alone gets ``ep`` folded in locally. + - EpCombine output sharding comes from ``out_sharding`` or defaults to the + compound ``(dp, ep)`` axis on the leading dim. +""" + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.sharding import NamedSharding, PartitionSpec + +import transformer_engine_jax +from .base import BasePrimitive, register_primitive +from ..sharding import global_mesh_resource, get_mesh_axis_size + +__all__ = [ + "EpConfig", + "EpLayerConfig", + "set_ep_config", + "get_ep_config", + "ep_handle_mem_size", + "ep_prepare", + "ep_dispatch_fwd", + "ep_combine_fwd", + "ep_dispatch_bwd", + "ep_combine_bwd", +] + + +# ── Module-level EP config ────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class EpConfig: + """Snapshot of the EP bootstrap config (see ep_bootstrap). + + num_ep_groups is the size of the outer dp/fsdp mesh axis (1 if neither + is set), captured at bootstrap so abstract-eval never reads the mesh. + """ + + world_size: int + rank: int + ep_size: int + num_ep_groups: int + num_experts: int + num_local_experts: int + max_tokens_per_rank: int + recv_capacity_per_rank: int + hidden_dim: int + + +_ep_config: EpConfig = None + + +def set_ep_config(config: EpConfig) -> None: + """Cache the EP config for abstract-eval / sharding helpers. Call once.""" + global _ep_config + _ep_config = config + + +def get_ep_config() -> EpConfig: + """Return the process-wide EpConfig set by ep_bootstrap.""" + if _ep_config is None: + raise RuntimeError("EpConfig has not been set. Did you call ep_bootstrap()?") + return _ep_config + + +@dataclass(frozen=True) +class EpLayerConfig: + """Per-layer EP config; mirrors C ``NVTEEpLayerConfig``. + + Threaded through every per-step op so the pointer-keyed C++ cache can + validate consistency across a handle_mem's prepare / dispatch / combine. + Reserved for future per-call fields (fp8 scale, overflow policy, ...). + """ + + top_k: int + dispatch_output_per_expert_alignment: int = 0 + + +def ep_handle_mem_size(cfg: EpLayerConfig) -> int: + """Return the handle_mem byte size for ``cfg``. Host-only; cheap.""" + return int( + transformer_engine_jax.ep_handle_mem_size( + int(cfg.top_k), int(cfg.dispatch_output_per_expert_alignment) + ) + ) + + +def _leading_axis_ok(spec): + """Validate an EP input spec; return ``(ok, ep_axis, outer_axes)``. + + Leading dim is ``ep`` or a tuple ending in ``ep`` (outer dp/fsdp axes + first); all other dims must be replicated. + """ + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer_axes = tuple(a for a in (gsr.dp_resource, gsr.fsdp_resource) if a is not None) + if len(spec) < 2 or ep_axis is None: + return False, ep_axis, outer_axes + if any(ax is not None for ax in spec[1:]): + return False, ep_axis, outer_axes + leading = spec[0] + elts = leading if isinstance(leading, tuple) else (leading,) + if ep_axis not in elts: + return False, ep_axis, outer_axes + allowed = set(outer_axes) | {ep_axis} + return all(a in allowed for a in elts), ep_axis, outer_axes + + +def _ep_outer_axis(): + """The single dp/fsdp axis (if any) sitting outside ep on EP-output tensors. + + When set, EP-output globals carry an extra leading ``dp_size`` dim so SPMD + sees each DP color's slab as distinct (rather than replicated across DP). + + A dp/fsdp axis that is sized 1 in the active mesh is treated as absent so + we don't pin EP-output specs to a degenerate axis that JAX may collapse. + """ + gsr = global_mesh_resource() + if gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1: + return gsr.dp_resource + if gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1: + return gsr.fsdp_resource + return gsr.dp_resource or gsr.fsdp_resource + + +def _ep_leading_dims(is_outer): + """Leading dim of an EP-output tensor: num_ep_groups*ep_size globally, + 1 per shard. Read from EpConfig so abstract-eval needs no active mesh.""" + cfg = get_ep_config() + if not is_outer: + return (1,) + return (cfg.num_ep_groups * cfg.ep_size,) + + +def _ep_output_spec(*trailing): + """PartitionSpec for an EP-output tensor: ``(("dp","ep"), *trailing)`` when + DP is set (compound leading axis on a single dim), else ``("ep",*trailing)``.""" + gsr = global_mesh_resource() + outer = _ep_outer_axis() + if outer is None: + return PartitionSpec(gsr.ep_resource, *trailing) + return PartitionSpec((outer, gsr.ep_resource), *trailing) + + +def _ep_spec_ok(spec, trailing_count): + """Leading dim shards along ep (and outer dp/fsdp when set); trailing dims + are replicated. JAX may collapse size-1 mesh axes to ``None`` or drop them, + so the leading entry is normalized to a set of named axes before comparing. + """ + gsr = global_mesh_resource() + ep_axis = gsr.ep_resource + outer = _ep_outer_axis() + if len(spec) != 1 + trailing_count: + return False + if any(ax is not None for ax in spec[1:]): + return False + leading = spec[0] + elts = leading if isinstance(leading, tuple) else (leading,) + actual = frozenset(a for a in elts if a is not None) + expected = {ep_axis} if outer is None else {ep_axis, outer} + return actual <= expected + + +# ── ep_prepare ────────────────────────────────────────────────────────────── + + +class EpPreparePrimitive(BasePrimitive): + """FFI primitive for nvte_ep_prepare: routing setup and per-expert token counts.""" + + name = "te_ep_prepare_ffi" + multiple_results = True + impl_static_args = (1, 2, 3) # top_k, dispatch_output_per_expert_alignment, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(topk_idx_aval, *, top_k, dispatch_output_per_expert_alignment, is_outer): + # is_outer=True: global leading dim = (dp*ep,) (or (ep,) with no DP); + # False: per-shard = (1,). + cfg = get_ep_config() + num_local_experts = cfg.num_local_experts + assert ( + len(topk_idx_aval.shape) >= 2 + ), f"topk_idx must be at least 2D [..., top_k], got shape {topk_idx_aval.shape}" + handle_mem_size = int( + transformer_engine_jax.ep_handle_mem_size( + int(top_k), int(dispatch_output_per_expert_alignment) + ) + ) + leading = _ep_leading_dims(is_outer) + token_counts_aval = jax.core.ShapedArray(leading + (num_local_experts,), jnp.int32) + handle_mem_aval = jax.core.ShapedArray(leading + (handle_mem_size,), jnp.uint8) + # FFI scratch for the int32 -> int64 topk_idx upcast. int32 with last + # dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + # TODO(phuong): drop once NCCL EP supports int32 topk_idx. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return token_counts_aval, handle_mem_aval, workspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs["is_outer"] = True + return EpPreparePrimitive.abstract(*args, **kwargs)[:2] # pylint: disable=missing-kwoa + + @staticmethod + def lowering(ctx, topk_idx, *, top_k, dispatch_output_per_expert_alignment, is_outer): + del is_outer + return ffi.ffi_lowering(EpPreparePrimitive.name)( + ctx, + topk_idx, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl(topk_idx, top_k, dispatch_output_per_expert_alignment, is_outer): + assert EpPreparePrimitive.inner_primitive is not None + token_counts, handle_mem, _workspace = EpPreparePrimitive.inner_primitive.bind( + topk_idx, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + is_outer=is_outer, + ) + return token_counts, handle_mem + + @staticmethod + def batcher(batched_args, batch_dims, *, top_k, dispatch_output_per_expert_alignment, is_outer): + raise NotImplementedError("EpPreparePrimitive does not support vmap") + + @staticmethod + def partition( + top_k, dispatch_output_per_expert_alignment, is_outer, mesh, arg_infos, result_infos + ): + del is_outer, result_infos + idx_spec = arg_infos[0].sharding.spec + ok, ep_axis, outer_axes = _leading_axis_ok(idx_spec) + if not ok: + raise NotImplementedError( + "EpPrepare: topk_idx leading dim must include ep_resource" + f" ('{ep_axis}'), optionally tupled with {outer_axes}," + f" with the topk dim replicated; got spec={idx_spec}." + ) + arg_shardings = tuple(a.sharding for a in arg_infos) + # token_counts / handle_mem inherit the input's leading axis (trailing dims auto-pad to None). + leading_spec = PartitionSpec(idx_spec[0]) + tc_sharding = NamedSharding(mesh, leading_spec) + hm_sharding = NamedSharding(mesh, leading_spec) + + def sharded_impl(topk_idx): + return EpPreparePrimitive.impl( + topk_idx, top_k, dispatch_output_per_expert_alignment, False + ) + + return mesh, sharded_impl, (tc_sharding, hm_sharding), arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (top_k, dispatch_alignment, is_outer). + value_types = args[-2] + topk_idx_rank = len(value_types[0].shape) + in_axes = " ".join(f"L{i}" for i in range(topk_idx_rank - 1)) + " topk" + return f"{in_axes} -> EPL nle, EPL hm" + + +register_primitive(EpPreparePrimitive) + + +# ── ep_dispatch ───────────────────────────────────────────────────────────── + + +class EpDispatchPrimitive(BasePrimitive): + """FFI primitive for nvte_ep_dispatch (forward).""" + + name = "te_ep_dispatch_ffi" + multiple_results = True + impl_static_args = (4, 5, 6, 7) # top_k, dispatch_output_per_expert_alignment, + # recv_capacity_per_rank, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + topk_idx_aval, + tokens_aval, + topk_weights_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + # is_outer=True: global leading dim = (dp*ep,) (or (ep,) with no DP); + # False: per-shard = (1,). + del topk_weights_aval, top_k, dispatch_output_per_expert_alignment, handle_mem_aval + assert ( + len(tokens_aval.shape) >= 2 + ), f"tokens must be at least 2D [..., H], got shape {tokens_aval.shape}" + recv_pr = recv_capacity_per_rank + tok_dtype = dtypes.canonicalize_dtype(tokens_aval.dtype) + hidden_dim = tokens_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype) + recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32) + # int32 with last dim doubled to keep the int64 byte count without JAX_ENABLE_X64. + workspace_shape = topk_idx_aval.shape[:-1] + (topk_idx_aval.shape[-1] * 2,) + workspace_aval = jax.core.ShapedArray(workspace_shape, jnp.int32) + return (recv_tokens_aval, recv_topk_weights_aval, workspace_aval) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs["is_outer"] = True + return EpDispatchPrimitive.abstract(*args, **kwargs)[:2] # pylint: disable=missing-kwoa + + @staticmethod + def lowering( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpDispatchPrimitive.name)( + ctx, + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + assert EpDispatchPrimitive.inner_primitive is not None + recv_tokens, recv_topk_weights, _workspace = EpDispatchPrimitive.inner_primitive.bind( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=is_outer, + ) + return recv_tokens, recv_topk_weights + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + raise NotImplementedError("EpDispatchPrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + mesh, + arg_infos, + result_infos, + ): + del is_outer, result_infos + tokens_spec = arg_infos[2].sharding.spec + ok, ep_axis, outer_axes = _leading_axis_ok(tokens_spec) + if not ok: + raise NotImplementedError( + "EpDispatch: tokens leading dim must include ep_resource" + f" ('{ep_axis}'), optionally tupled with {outer_axes}," + f" hidden dim replicated; got spec={tokens_spec}." + ) + idx_spec = arg_infos[1].sharding.spec + tw_spec = arg_infos[3].sharding.spec + if idx_spec[0] != tokens_spec[0] or tw_spec[0] != tokens_spec[0]: + raise NotImplementedError( + "EpDispatch: topk_idx, tokens, topk_weights must share the leading" + f" axis; got topk_idx={idx_spec}, tokens={tokens_spec}, topk_weights={tw_spec}." + ) + # Recv outputs share the tokens leading-only spec (trailing dims auto-pad to None). + leading_spec = PartitionSpec(tokens_spec[0]) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_shardings = ( + NamedSharding(mesh, leading_spec), + NamedSharding(mesh, leading_spec), + ) + + def sharded_impl(handle_mem, topk_idx, tokens, topk_weights): + return EpDispatchPrimitive.impl( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + False, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args + # for this primitive are (top_k, dispatch_alignment, recv_capacity_per_rank, is_outer). + value_types = args[-2] + # Inputs: handle_mem, topk_idx, tokens, topk_weights. + idx_rank = len(value_types[1].shape) + tok_rank = len(value_types[2].shape) + tw_rank = len(value_types[3].shape) + idx_axes = " ".join(f"I{i}" for i in range(idx_rank - 1)) + " topk_in" + tok_axes = " ".join(f"T{i}" for i in range(tok_rank - 1)) + " H" + tw_axes = " ".join(f"W{i}" for i in range(tw_rank - 1)) + " topk" + return f"EPL hm, {idx_axes}, {tok_axes}, {tw_axes} -> EPL recv_pr H, EPL recv_pr" + + +register_primitive(EpDispatchPrimitive) + + +# ── ep_combine ────────────────────────────────────────────────────────────── +# `expert_out` here is the post-weight buffer; ep.ep_combine applies the +# hadamard before calling. + + +def _normalize_leading_shape(s): + return s if isinstance(s, tuple) else (int(s),) + + +def _prod(seq): + p = 1 + for x in seq: + p *= int(x) + return p + + +def _leading_per_shard(out_leading_shape, leading_axis, mesh): + """Per-shard leading shape: divide ``out_leading_shape[0]`` by the mesh factor on ``leading_axis``.""" + axes = leading_axis if isinstance(leading_axis, tuple) else (leading_axis,) + factor = 1 + for a in axes: + factor *= mesh.shape[a] + assert ( + out_leading_shape[0] % factor == 0 + ), f"leading dim {out_leading_shape[0]} not divisible by shard factor {factor} on axes {axes}" + return (out_leading_shape[0] // factor,) + tuple(out_leading_shape[1:]) + + +class EpCombinePrimitive(BasePrimitive): + """FFI primitive for nvte_ep_combine (forward).""" + + name = "te_ep_combine_ffi" + multiple_results = False + impl_static_args = (2, 3, 4, 5) # top_k, dispatch_output_per_expert_alignment, + # out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + expert_out_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del top_k, dispatch_output_per_expert_alignment, out_partition_spec, handle_mem_aval + assert ( + len(expert_out_aval.shape) == 3 + ), f"expert_out must be 3D [num_procs, recv_pr, H], got shape {expert_out_aval.shape}" + eo_dtype = dtypes.canonicalize_dtype(expert_out_aval.dtype) + hidden_dim = expert_out_aval.shape[-1] + out_shape = tuple(out_leading_shape) + (hidden_dim,) + return jax.core.ShapedArray(out_shape, eo_dtype) + + @staticmethod + def lowering( + ctx, + handle_mem, + expert_out, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del out_leading_shape, out_partition_spec + return ffi.ffi_lowering(EpCombinePrimitive.name)( + ctx, + handle_mem, + expert_out, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl( + handle_mem, + expert_out, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + assert EpCombinePrimitive.inner_primitive is not None + return EpCombinePrimitive.inner_primitive.bind( + handle_mem, + expert_out, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + raise NotImplementedError("EpCombinePrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del result_infos + eo_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(eo_spec, trailing_count=2): + raise NotImplementedError( + "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={eo_spec}." + ) + per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) + + def sharded_impl(handle_mem, expert_out): + return EpCombinePrimitive.impl( + handle_mem, + expert_out, + top_k, + dispatch_output_per_expert_alignment, + per_shard_leading, + out_partition_spec, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Static args: + # (top_k, dispatch_alignment, out_leading_shape, out_partition_spec). + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + " H" + return f"EPL hm, EPL recv_pr H -> {out_axes}" + + +register_primitive(EpCombinePrimitive) + + +# ── ep_dispatch_bwd ───────────────────────────────────────────────────────── + + +class EpDispatchBwdPrimitive(BasePrimitive): + """FFI primitive for the backward of nvte_ep_dispatch.""" + + name = "te_ep_dispatch_bwd_ffi" + multiple_results = True + impl_static_args = (3, 4, 5, 6) # top_k, dispatch_output_per_expert_alignment, + # out_leading_shape, out_partition_spec + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + grad_aval, + g_recv_topk_weights_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del dispatch_output_per_expert_alignment + del g_recv_topk_weights_aval, out_partition_spec, handle_mem_aval + assert ( + len(grad_aval.shape) == 3 + ), f"grad must be 3D [num_procs, recv_pr, H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + result_aval = jax.core.ShapedArray(tuple(out_leading_shape) + (hidden_dim,), g_dtype) + grad_topk_weights_aval = jax.core.ShapedArray( + tuple(out_leading_shape) + (top_k,), jnp.float32 + ) + return result_aval, grad_topk_weights_aval + + @staticmethod + def lowering( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + del out_leading_shape, out_partition_spec + return ffi.ffi_lowering(EpDispatchBwdPrimitive.name)( + ctx, + handle_mem, + grad, + g_recv_topk_weights, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl( + handle_mem, + grad, + g_recv_topk_weights, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + assert EpDispatchBwdPrimitive.inner_primitive is not None + return EpDispatchBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + g_recv_topk_weights, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + out_leading_shape=out_leading_shape, + out_partition_spec=out_partition_spec, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + ): + raise NotImplementedError("EpDispatchBwdPrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + out_leading_shape, + out_partition_spec, + mesh, + arg_infos, + result_infos, + ): + del result_infos + g_spec = arg_infos[1].sharding.spec + if not _ep_spec_ok(g_spec, trailing_count=2): + raise NotImplementedError( + "EpDispatchBwd: grad must be sharded as PartitionSpec(ep_resource," + " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" + f" over [num_procs, recv_pr, H]; got spec={g_spec}." + ) + gw_spec = arg_infos[2].sharding.spec + if not _ep_spec_ok(gw_spec, trailing_count=1): + raise NotImplementedError( + "EpDispatchBwd: g_recv_topk_weights must be sharded as" + " PartitionSpec(ep_resource, None) (or ((dp, ep), None) when dp/fsdp is set)" + f" over [num_procs, recv_pr]; got spec={gw_spec}." + ) + if gw_spec[0] != g_spec[0]: + raise NotImplementedError( + "EpDispatchBwd: grad and g_recv_topk_weights must share the leading" + f" axis; got grad={g_spec}, g_recv_topk_weights={gw_spec}." + ) + per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) + arg_shardings = tuple(a.sharding for a in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) + out_shardings = [out_sharding, out_sharding] + + def sharded_impl(handle_mem, grad, g_recv_topk_weights): + return EpDispatchBwdPrimitive.impl( + handle_mem, + grad, + g_recv_topk_weights, + top_k, + dispatch_output_per_expert_alignment, + per_shard_leading, + out_partition_spec, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # Signature: (*static_args, mesh, value_types, result_types). Result rank + # follows out_leading_shape (static arg #2): rank = len(out_leading) + 1. + result_types = args[-1] + out_rank = len(result_types[0].shape) + out_axes = " ".join(f"O{i}" for i in range(out_rank - 1)) + return f"EPL hm, EPL recv_pr H, EPL recv_pr -> {out_axes} H, {out_axes} k" + + +register_primitive(EpDispatchBwdPrimitive) + + +# ── ep_combine_bwd ────────────────────────────────────────────────────────── + + +class EpCombineBwdPrimitive(BasePrimitive): + """FFI primitive for the backward of nvte_ep_combine.""" + + name = "te_ep_combine_bwd_ffi" + multiple_results = False + impl_static_args = (2, 3, 4, 5) # top_k, dispatch_output_per_expert_alignment, + # recv_capacity_per_rank, is_outer + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + handle_mem_aval, + grad_aval, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + # is_outer=True: global leading dim = (dp*ep,) (or (ep,) with no DP); + # False: per-shard = (1,). + del top_k, dispatch_output_per_expert_alignment, handle_mem_aval + assert ( + len(grad_aval.shape) >= 2 + ), f"grad must be at least 2D [..., H], got shape {grad_aval.shape}" + g_dtype = dtypes.canonicalize_dtype(grad_aval.dtype) + hidden_dim = grad_aval.shape[-1] + leading = _ep_leading_dims(is_outer) + return jax.core.ShapedArray(leading + (recv_capacity_per_rank, hidden_dim), g_dtype) + + @staticmethod + def outer_abstract(*args, **kwargs): + kwargs["is_outer"] = True + return EpCombineBwdPrimitive.abstract(*args, **kwargs) # pylint: disable=missing-kwoa + + @staticmethod + def lowering( + ctx, + handle_mem, + grad, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + del recv_capacity_per_rank, is_outer + return ffi.ffi_lowering(EpCombineBwdPrimitive.name)( + ctx, + handle_mem, + grad, + top_k=int(top_k), + dispatch_output_per_expert_alignment=int(dispatch_output_per_expert_alignment), + ) + + @staticmethod + def impl( + handle_mem, + grad, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + assert EpCombineBwdPrimitive.inner_primitive is not None + return EpCombineBwdPrimitive.inner_primitive.bind( + handle_mem, + grad, + top_k=top_k, + dispatch_output_per_expert_alignment=dispatch_output_per_expert_alignment, + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=is_outer, + ) + + @staticmethod + def batcher( + batched_args, + batch_dims, + *, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + ): + raise NotImplementedError("EpCombineBwdPrimitive does not support vmap") + + @staticmethod + def partition( + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + is_outer, + mesh, + arg_infos, + result_infos, + ): + del is_outer, result_infos + arg_shardings = tuple(a.sharding for a in arg_infos) + # EP-output leading (trailing dims auto-pad to None). + out_sharding = NamedSharding(mesh, _ep_output_spec()) + + def sharded_impl(handle_mem, grad): + return EpCombineBwdPrimitive.impl( + handle_mem, + grad, + top_k, + dispatch_output_per_expert_alignment, + recv_capacity_per_rank, + False, + ) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(*args): + # T axes are dynamic-rank based on the actual cotangent shape. + value_types = args[-2] + g_rank = len(value_types[1].shape) + g_axes = " ".join(f"T{i}" for i in range(g_rank - 1)) + " H" + return f"EPL hm, {g_axes} -> EPL recv_pr H" + + +register_primitive(EpCombineBwdPrimitive) + + +# ── Public-ish helpers (used by jax/ep.py) ────────────────────────────────── + + +def ep_prepare(cfg: EpLayerConfig, topk_idx): + """Exchange routing metadata for ``cfg``; return ``(token_counts, handle_mem)``.""" + return EpPreparePrimitive.outer_primitive.bind( + topk_idx, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + is_outer=True, + ) + + +def ep_dispatch_fwd( + cfg: EpLayerConfig, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank +): + """Scatter tokens and weights to expert ranks; returns (recv_tokens, recv_topk_weights).""" + return EpDispatchPrimitive.outer_primitive.bind( + handle_mem, + topk_idx, + tokens, + topk_weights, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=True, + ) + + +def ep_combine_fwd( + cfg: EpLayerConfig, handle_mem, expert_out, num_local_tokens, out_partition_spec=None +): + """Gather expert outputs back to home ranks. expert_out is pre-weighted.""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpCombinePrimitive.outer_primitive.bind( + handle_mem, + expert_out, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_dispatch_bwd( + cfg: EpLayerConfig, + handle_mem, + grad, + g_recv_topk_weights, + num_local_tokens, + out_partition_spec=None, +): + """Backward of dispatch; returns (grad_tokens, grad_topk_weights).""" + out_leading = _normalize_leading_shape(num_local_tokens) + return EpDispatchBwdPrimitive.outer_primitive.bind( + handle_mem, + grad, + g_recv_topk_weights, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + out_leading_shape=out_leading, + out_partition_spec=out_partition_spec, + ) + + +def ep_combine_bwd(cfg: EpLayerConfig, handle_mem, grad, recv_capacity_per_rank): + """Backward of combine; returns grad_expert_out [num_procs, recv_capacity_per_rank, H].""" + return EpCombineBwdPrimitive.outer_primitive.bind( + handle_mem, + grad, + top_k=int(cfg.top_k), + dispatch_output_per_expert_alignment=int(cfg.dispatch_output_per_expert_alignment), + recv_capacity_per_rank=recv_capacity_per_rank, + is_outer=True, + ) diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index 3245439689..46f51c9d33 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -412,7 +412,12 @@ def partition( arg_infos, result_infos, ): - del result_infos, routing_map_format + # NOTE: do NOT include ``routing_map_format`` in this ``del``: the + # ``sharded_impl`` closure below resolves it by name at call time + # (when XLA invokes the partitioned impl), so deleting it here + # raises ``NameError: cannot access free variable 'routing_map_format'`` + # at execution time of the bwd custom_partitioning. + del result_infos grad_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding(mesh, PartitionSpec(*grad_spec)) arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding, arg_infos[2].sharding) @@ -645,7 +650,14 @@ def shardy_sharding_rule(*args): # backward reconstructs the full [num_tokens, num_experts] grad_probs from # scalar inputs. Shardy will leave num_tokens unsharded, which matches the # replicated PartitionSpec(None, None) in partition(). - return "const_buf_one, num_experts, grad_one -> i num_experts" + # + # grad_aux_loss is the cotangent of a scalar loss and is therefore + # rank-0; the third operand entry is empty (no factor labels). Declaring + # it with the spurious "grad_one" factor gave it rank-1 and tripped + # JAX's custom_partitioning_sharding_rule check once the MoE block + # lifted its aux-loss path out of shard_map (the rule is skipped under + # shard_map, which is why this surfaces only at global view). + return "const_buf_one, num_experts, -> i num_experts" register_primitive(FusedMoEAuxLossBwdPrimitive) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c0fa3acaeb..b9c7c849f2 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -204,6 +204,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedMoEAuxLossBackwardHandler); +// Bootstrap EP (eager NCCL comm init); anchor released by ReleaseEpResources. +// max_token_dtype is the NVTEDType enum value (int) for the widest token dtype +// the group will dispatch. +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms, int max_token_dtype); +void ReleaseEpResources(); +// Return the handle_mem byte size for a layer config. +size_t EpHandleMemSize(int top_k, size_t dispatch_output_per_expert_alignment); + +// EpInstanceState type_id / type_info capsules for jax.ffi.register_ffi_type. +pybind11::capsule GetEpInstanceStateTypeIdCapsule(); +pybind11::capsule GetEpInstanceStateTypeInfoCapsule(); + +// EP FFI handlers +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpInstantiateHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpPrepareHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpDispatchBwdHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(EpCombineBwdHandler); + // TopK XLA_FFI_DECLARE_HANDLER_SYMBOL(TopkHandler); pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k); diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp new file mode 100644 index 0000000000..ee204e7594 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -0,0 +1,497 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifdef NVTE_WITH_NCCL_EP + +#include "transformer_engine/ep.h" + +#include + +#include +#include +#include +#include +#include + +#include "../extensions.h" +#include "common.h" +#include "transformer_engine/gemm.h" + +namespace transformer_engine { +namespace jax { + +// NCCL comm + EPBackend lifetime tracks live JAX executables via XLA stateful FFI. + +struct EpBootstrapParams { + std::array uid_bytes{}; + int ep_size = 0; + int rank_within_group = 0; + int num_experts = 0; + int max_tokens_per_rank = 0; + int max_recv_tokens_per_rank = 0; + int hidden_dim = 0; + int max_num_sms = 0; + NVTEDType max_token_dtype = kNVTEBFloat16; +}; + +class EpResources { + public: + explicit EpResources(const EpBootstrapParams& p) { + ncclUniqueId uid; + std::memcpy(&uid, p.uid_bytes.data(), sizeof(uid)); + NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); + // zero_copy=0: JAX EP path always stages payloads; the zero-copy fast path + // requires NVTECommWindow-backed tensors, which JAX bindings don't expose. + NVTEEpGroupConfig cfg{.ep_size = p.ep_size, + .num_experts = p.num_experts, + .max_tokens_per_rank = p.max_tokens_per_rank, + .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, + .hidden_dim = p.hidden_dim, + .max_num_sms = p.max_num_sms, + .max_token_dtype = p.max_token_dtype, + .zero_copy = 0}; + try { + nvte_ep_initialize(static_cast(comm_), cfg); + } catch (...) { + ncclCommDestroy(comm_); + comm_ = nullptr; + throw; + } + } + + ~EpResources() { + if (comm_ == nullptr) return; + nvte_ep_shutdown(); + ncclCommDestroy(comm_); + } + + EpResources(const EpResources&) = delete; + EpResources& operator=(const EpResources&) = delete; + + ncclComm_t comm() const { return comm_; } + + private: + ncclComm_t comm_{nullptr}; +}; + +struct EpInstanceState { + static ::xla::ffi::TypeId id; + static ::xla::ffi::TypeInfo info; + std::shared_ptr resources; +}; + +::xla::ffi::TypeId EpInstanceState::id = {}; +::xla::ffi::TypeInfo EpInstanceState::info = ::xla::ffi::MakeTypeInfo(); + +namespace { + +std::mutex g_ep_mu; +EpBootstrapParams g_ep_params; +bool g_ep_params_set = false; +std::weak_ptr g_ep_resources_weak; +// Python-held anchor so trace-time handle_mem allocs find EPBackend ready. +std::shared_ptr g_ep_resources_anchor; + +std::shared_ptr AcquireEpResources() { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(g_ep_params_set, + "EP bootstrap params not set; call transformer_engine_jax." + "set_ep_bootstrap_params() (typically via ep_bootstrap) first."); + auto sp = g_ep_resources_weak.lock(); + if (sp) return sp; + sp = std::make_shared(g_ep_params); + g_ep_resources_weak = sp; + return sp; +} + +} // namespace + +// top_k and dispatch_output_per_expert_alignment are baked as static FFI +// attributes; prepare passes them to the C API as NVTEEpLayerConfig, and the +// per-step ops carry top_k only to validate the topk_idx last dim. + +struct EpConfig { + int64_t top_k; + int64_t dispatch_output_per_expert_alignment; +}; + +// ── Bootstrap helpers ───────────────────────────────────────────────────────── + +// Caches uid + group config and eagerly creates the NCCL comm (ranks +// synchronize via the UID broadcast). +void SetEpBootstrapParams(pybind11::bytes unique_id_bytes_obj, int ep_size, int rank_within_group, + int num_experts, int max_tokens_per_rank, int max_recv_tokens_per_rank, + int hidden_dim, int max_num_sms, int max_token_dtype) { + std::string uid_str = unique_id_bytes_obj; + NVTE_CHECK(static_cast(uid_str.size()) >= 128, + "unique_id_bytes must be at least 128 bytes (ncclUniqueId size)."); + std::shared_ptr anchor; + { + std::lock_guard lock(g_ep_mu); + NVTE_CHECK(!g_ep_resources_anchor, + "EP bootstrap already initialized; call release_ep_resources() before re-init."); + std::memcpy(g_ep_params.uid_bytes.data(), uid_str.data(), 128); + g_ep_params.ep_size = ep_size; + g_ep_params.rank_within_group = rank_within_group; + g_ep_params.num_experts = num_experts; + g_ep_params.max_tokens_per_rank = max_tokens_per_rank; + g_ep_params.max_recv_tokens_per_rank = max_recv_tokens_per_rank; + g_ep_params.hidden_dim = hidden_dim; + g_ep_params.max_num_sms = max_num_sms; + g_ep_params.max_token_dtype = static_cast(max_token_dtype); + g_ep_params_set = true; + } + // Acquire outside the lock: EpResources ctor runs ncclCommInitRank which is + // a collective and may block on peer ranks. + anchor = AcquireEpResources(); + std::lock_guard lock(g_ep_mu); + g_ep_resources_anchor = std::move(anchor); +} + +// Drops the anchor; comm tears down once the last executable also releases. +void ReleaseEpResources() { + std::shared_ptr to_drop; + { + std::lock_guard lock(g_ep_mu); + to_drop = std::move(g_ep_resources_anchor); + } + // to_drop dtor runs outside the lock. +} + +size_t EpHandleMemSize(int top_k, size_t dispatch_output_per_expert_alignment) { + NVTEEpLayerConfig layer_cfg{top_k, dispatch_output_per_expert_alignment}; + return nvte_ep_handle_mem_size(layer_cfg); +} + +pybind11::capsule GetEpInstanceStateTypeIdCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::id), "xla.ffi.type_id"); +} + +pybind11::capsule GetEpInstanceStateTypeInfoCapsule() { + return pybind11::capsule(static_cast(&EpInstanceState::info), "xla.ffi.type_info"); +} + +// ── Instantiate handler ───────────────────────────────────────────────────── + +static ::xla::ffi::ErrorOr> EpInstantiateImpl() { + auto state = std::make_unique(); + try { + state->resources = AcquireEpResources(); + } catch (const std::exception& e) { + return ::xla::ffi::Unexpected( + ::xla::ffi::Error::Internal(std::string("EP instantiate failed: ") + e.what())); + } + return state; +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::BindInstantiate()); + +// ── ep_prepare ──────────────────────────────────────────────────────────────── + +Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, + Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, + EpConfig config) { + (void)ep_state; // lifetime only. + auto topk_dims = topk_idx.dimensions(); + NVTE_CHECK(topk_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", topk_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + + std::vector topk_shape = {product(topk_dims, 0, topk_dims.size() - 1), + static_cast(topk_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = topk_shape[0] * topk_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); + + std::vector tc_shape = {static_cast(token_counts->element_count())}; + auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32); + + std::vector hm_shape = {static_cast(handle_mem->element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); + + NVTEEpLayerConfig layer_cfg{static_cast(config.top_k), + static_cast(config.dispatch_output_per_expert_alignment)}; + nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), token_counts_.data(), layer_cfg, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // topk_idx + .Ret() // token_counts + .Ret() // handle_mem + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch ─────────────────────────────────────────────────────────────── + +Error_Type EpDispatchFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights, + Result_Type recv_tokens, Result_Type recv_topk_weights, + Result_Type workspace, EpConfig config) { + (void)ep_state; + auto token_dims = tokens.dimensions(); + NVTE_CHECK(token_dims.size() >= 2, + "tokens must be at least 2D [..., H], got ndim=", token_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + auto idx_dims = topk_idx.dimensions(); + NVTE_CHECK(idx_dims.size() >= 2, + "topk_idx must be at least 2D [..., top_k], got ndim=", idx_dims.size()); + auto idx_etype = topk_idx.element_type(); + NVTE_CHECK(idx_etype == ::xla::ffi::DataType::S64 || idx_etype == ::xla::ffi::DataType::S32, + "topk_idx must be int32 or int64; got element_type=", static_cast(idx_etype)); + NVTE_CHECK(static_cast(idx_dims.back()) == config.top_k, "top_k attr (", config.top_k, + ") must match topk_idx last dim (", idx_dims.back(), ")"); + std::vector idx_shape = {product(idx_dims, 0, idx_dims.size() - 1), + static_cast(idx_dims.back())}; + // NCCL EP currently requires int64 topk_idx; upcast int32 on-stream. + // TODO(phuong): drop once NCCL EP accepts int32. + void* topk_idx_data = topk_idx.untyped_data(); + if (idx_etype == ::xla::ffi::DataType::S32) { + const size_t n = idx_shape[0] * idx_shape[1]; + NVTE_CHECK(static_cast(workspace->element_count()) >= n, + "workspace too small for int32 → int64 upcast: element_count=", + workspace->element_count(), " < required ", n); + int64_t* ws = reinterpret_cast(workspace->untyped_data()); + nvte_convert_int32_to_int64(reinterpret_cast(topk_idx_data), ws, n, stream); + topk_idx_data = ws; + } + auto topk_idx_ = TensorWrapper(topk_idx_data, idx_shape, DType::kInt64); + + const size_t T_flat = product(token_dims, 0, token_dims.size() - 1); + const size_t H = static_cast(token_dims.back()); + std::vector tok_shape = {T_flat, H}; + auto token_dtype = convert_ffi_datatype_to_te_dtype(tokens.element_type()); + auto tokens_ = TensorWrapper(tokens.untyped_data(), tok_shape, token_dtype); + + auto tw_dims = topk_weights.dimensions(); + NVTE_CHECK(tw_dims.size() >= 2, + "topk_weights must be at least 2D [..., top_k], got ndim=", tw_dims.size()); + std::vector tw_shape = {product(tw_dims, 0, tw_dims.size() - 1), + static_cast(tw_dims.back())}; + auto topk_weights_ = TensorWrapper(topk_weights.untyped_data(), tw_shape, DType::kFloat32); + + // recv_tokens: flatten any leading dims into recv_capacity_per_rank. + auto recv_dims = recv_tokens->dimensions(); + NVTE_CHECK(recv_dims.size() >= 2, + "recv_tokens must be at least 2D [..., recv_pr, H]; got ndim=", recv_dims.size()); + const size_t recv_capacity_per_rank = product(recv_dims, 0, recv_dims.size() - 1); + std::vector recv_shape = {recv_capacity_per_rank, H}; + auto recv_tokens_ = TensorWrapper(recv_tokens->untyped_data(), recv_shape, token_dtype); + + auto recv_w_dims = recv_topk_weights->dimensions(); + NVTE_CHECK(recv_w_dims.size() >= 1, + "recv_topk_weights must be at least 1D; got ndim=", recv_w_dims.size()); + const size_t recv_w_total = product(recv_w_dims, 0, recv_w_dims.size()); + NVTE_CHECK(recv_w_total == recv_capacity_per_rank, "recv_topk_weights total (", recv_w_total, + ") must match recv_tokens recv_pr (", recv_capacity_per_rank, ")"); + std::vector recv_w_shape = {recv_capacity_per_rank}; + auto recv_topk_weights_ = + TensorWrapper(recv_topk_weights->untyped_data(), recv_w_shape, DType::kFloat32); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch(handle_mem_.data(), topk_idx_.data(), tokens_.data(), no_win, + topk_weights_.data(), no_win, recv_tokens_.data(), no_win, + recv_topk_weights_.data(), no_win, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchHandler, EpDispatchFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // topk_idx + .Arg() // tokens + .Arg() // topk_weights + .Ret() // recv_tokens + .Ret() // recv_topk_weights + .Ret() // workspace (FFI scratch) + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine ──────────────────────────────────────────────────────────────── + +Error_Type EpCombineFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type expert_out, Result_Type result, EpConfig config) { + (void)ep_state; + auto eo_dims = expert_out.dimensions(); + NVTE_CHECK(eo_dims.size() >= 2, + "expert_out must be at least 2D [..., recv_pr, H]; got ndim=", eo_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(eo_dims, 0, eo_dims.size() - 1); + const size_t H = static_cast(eo_dims.back()); + std::vector eo_shape = {recv_capacity_per_rank, H}; + auto eo_dtype = convert_ffi_datatype_to_te_dtype(expert_out.element_type()); + auto expert_out_ = TensorWrapper(expert_out.untyped_data(), eo_shape, eo_dtype); + + auto res_dims = result->dimensions(); + NVTE_CHECK(res_dims.size() >= 2, + "result must be at least 2D [..., H]; got ndim=", res_dims.size()); + const size_t res_T_flat = product(res_dims, 0, res_dims.size() - 1); + std::vector res_shape = {res_T_flat, H}; + auto result_ = TensorWrapper(result->untyped_data(), res_shape, eo_dtype); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine(handle_mem_.data(), expert_out_.data(), no_win, result_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineHandler, EpCombineFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // expert_out + .Ret() // result + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_dispatch_bwd ─────────────────────────────────────────────────────────── + +Error_Type EpDispatchBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Buffer_Type g_recv_topk_weights, + Result_Type grad_tokens, Result_Type grad_topk_weights, + EpConfig config) { + (void)ep_state; + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., recv_pr, H]; got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t recv_capacity_per_rank = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {recv_capacity_per_rank, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto gw_dims = g_recv_topk_weights.dimensions(); + NVTE_CHECK( + gw_dims.size() >= 1, + "g_recv_topk_weights rank must flatten to recv_capacity_per_rank; got ndim=", gw_dims.size()); + const size_t gw_total = product(gw_dims, 0, gw_dims.size()); + NVTE_CHECK(gw_total == recv_capacity_per_rank, "g_recv_topk_weights total (", gw_total, + ") must match grad recv_pr (", recv_capacity_per_rank, ")"); + std::vector gw_shape = {recv_capacity_per_rank}; + auto g_recv_topk_weights_ = + TensorWrapper(g_recv_topk_weights.untyped_data(), gw_shape, DType::kFloat32); + + auto out_dims = grad_tokens->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_tokens must be at least 2D [..., H], got ndim=", out_dims.size()); + const size_t T_flat = product(out_dims, 0, out_dims.size() - 1); + std::vector out_shape = {T_flat, H}; + auto grad_tokens_ = TensorWrapper(grad_tokens->untyped_data(), out_shape, g_dtype); + + auto gtw_dims = grad_topk_weights->dimensions(); + NVTE_CHECK(gtw_dims.size() >= 2, + "grad_topk_weights must be at least 2D [..., top_k]; got ndim=", gtw_dims.size()); + const size_t gtw_T_flat = product(gtw_dims, 0, gtw_dims.size() - 1); + NVTE_CHECK(gtw_T_flat == T_flat, "grad_topk_weights leading-dim product (", gtw_T_flat, + ") must equal grad_tokens leading-dim product (", T_flat, ")"); + const size_t top_k = static_cast(gtw_dims.back()); + NVTE_CHECK(static_cast(top_k) == config.top_k, "top_k attr (", config.top_k, + ") must match grad_topk_weights last dim (", top_k, ")"); + std::vector gtw_shape = {T_flat, top_k}; + auto grad_topk_weights_ = + TensorWrapper(grad_topk_weights->untyped_data(), gtw_shape, DType::kFloat32); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_dispatch_bwd(handle_mem_.data(), grad_.data(), no_win, g_recv_topk_weights_.data(), + no_win, grad_tokens_.data(), grad_topk_weights_.data(), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpDispatchBwdHandler, EpDispatchBwdFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. recv_tokens) + .Arg() // g_recv_topk_weights + .Ret() // grad_tokens + .Ret() // grad_topk_weights + .Attrs(), + FFI_CudaGraph_Traits); + +// ── ep_combine_bwd ──────────────────────────────────────────────────────────── + +Error_Type EpCombineBwdFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type handle_mem, + Buffer_Type grad, Result_Type grad_expert_out, EpConfig config) { + (void)ep_state; + auto grad_dims = grad.dimensions(); + NVTE_CHECK(grad_dims.size() >= 2, + "grad must be at least 2D [..., H], got ndim=", grad_dims.size()); + + std::vector hm_shape = {static_cast(handle_mem.element_count())}; + auto handle_mem_ = TensorWrapper(handle_mem.untyped_data(), hm_shape, DType::kByte); + + const size_t T_flat = product(grad_dims, 0, grad_dims.size() - 1); + const size_t H = static_cast(grad_dims.back()); + std::vector g_shape = {T_flat, H}; + auto g_dtype = convert_ffi_datatype_to_te_dtype(grad.element_type()); + auto grad_ = TensorWrapper(grad.untyped_data(), g_shape, g_dtype); + + auto out_dims = grad_expert_out->dimensions(); + NVTE_CHECK(out_dims.size() >= 2, + "grad_expert_out must be at least 2D [..., recv_pr, H]; got ndim=", out_dims.size()); + const size_t recv_capacity_per_rank = product(out_dims, 0, out_dims.size() - 1); + const size_t out_H = static_cast(out_dims.back()); + NVTE_CHECK(out_H == H, "grad_expert_out hidden dim (", out_H, ") must match grad H (", H, ")"); + std::vector out_shape = {recv_capacity_per_rank, H}; + auto grad_expert_out_ = TensorWrapper(grad_expert_out->untyped_data(), out_shape, g_dtype); + + NVTECommWindow no_win{nullptr, 0}; + nvte_ep_combine_bwd(handle_mem_.data(), grad_.data(), no_win, grad_expert_out_.data(), no_win, + stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(EpCombineBwdHandler, EpCombineBwdFFI, + FFI::Bind() + .Ctx() // stream + .Ctx<::xla::ffi::State>() // EP state + .Arg() // handle_mem + .Arg() // grad (w.r.t. result) + .Ret() // grad_expert_out + .Attrs(), + FFI_CudaGraph_Traits); + +} // namespace jax +} // namespace transformer_engine + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::EpConfig, ::xla::ffi::StructMember("top_k"), + ::xla::ffi::StructMember("dispatch_output_per_expert_alignment")); + +#endif // NVTE_WITH_NCCL_EP diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2432f65005..db5468afe6 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -107,6 +107,25 @@ pybind11::dict Registrations() { dict["te_fused_moe_aux_loss_forward_ffi"] = EncapsulateFFI(FusedMoEAuxLossForwardHandler); dict["te_fused_moe_aux_loss_backward_ffi"] = EncapsulateFFI(FusedMoEAuxLossBackwardHandler); +#ifdef NVTE_WITH_NCCL_EP + // Expert Parallelism (instantiate handler pins NCCL comm to executable lifetime). + dict["te_ep_prepare_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpPrepareHandler)); + dict["te_ep_dispatch_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchHandler)); + dict["te_ep_combine_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineHandler)); + dict["te_ep_dispatch_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpDispatchBwdHandler)); + dict["te_ep_combine_bwd_ffi"] = + pybind11::dict(pybind11::arg("instantiate") = EncapsulateFFI(EpInstantiateHandler), + pybind11::arg("execute") = EncapsulateFFI(EpCombineBwdHandler)); +#endif // NVTE_WITH_NCCL_EP + // TopK dict["te_topk_ffi"] = EncapsulateFFI(TopkHandler); @@ -136,6 +155,18 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("is_collective_gemm_with_cublasmp", &IsCollectiveGemmWithCublasmp); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); +#ifdef NVTE_WITH_NCCL_EP + m.def("set_ep_bootstrap_params", &SetEpBootstrapParams, pybind11::arg("unique_id_bytes"), + pybind11::arg("ep_size"), pybind11::arg("rank_within_group"), pybind11::arg("num_experts"), + pybind11::arg("max_tokens_per_rank"), pybind11::arg("max_recv_tokens_per_rank"), + pybind11::arg("hidden_dim"), pybind11::arg("max_num_sms"), + pybind11::arg("max_token_dtype")); + m.def("release_ep_resources", &ReleaseEpResources); + m.def("ep_handle_mem_size", &EpHandleMemSize, pybind11::arg("top_k"), + pybind11::arg("dispatch_output_per_expert_alignment") = 0); + m.def("get_ep_instance_state_type_id", &GetEpInstanceStateTypeIdCapsule); + m.def("get_ep_instance_state_type_info", &GetEpInstanceStateTypeInfoCapsule); +#endif // NVTE_WITH_NCCL_EP pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/ep.py b/transformer_engine/jax/ep.py new file mode 100644 index 0000000000..666b46f95b --- /dev/null +++ b/transformer_engine/jax/ep.py @@ -0,0 +1,311 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX Expert Parallelism (EP) API.""" + +import atexit +import ctypes +from functools import partial + +import jax +import jax.numpy as jnp +import jax.experimental.multihost_utils as jmu +import numpy as np + +import transformer_engine_jax +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.cpp_extensions.ep import _ep_outer_axis +from transformer_engine.jax.cpp_extensions.misc import jax_dtype_to_te_dtype +from transformer_engine.jax.sharding import ( + global_mesh_resource, + get_mesh_axis_size, + with_sharding_constraint, +) + +ep_prepare = tex.ep_prepare +EpLayerConfig = tex.EpLayerConfig +ep_handle_mem_size = tex.ep_handle_mem_size + +__all__ = [ + "EpLayerConfig", + "ep_bootstrap", + "ep_handle_mem_size", + "ep_prepare", + "ep_dispatch", + "ep_combine", +] + +_atexit_registered = False + + +def _allgather_uid(uid_arr, world_size, uid_size): + """Allgather UID bytes across all processes. + + Tries ``jax.experimental.multihost_utils.process_allgather`` first; + falls back to an XLA collective (process-local sharded global array + replicated via ``jax.jit``) when the multihost helper returns a + short buffer, which has been observed under some launchers. + """ + try: + gathered = jmu.process_allgather(uid_arr, tiled=True) + if gathered.size == world_size * uid_size: + return np.asarray(gathered).reshape(world_size, uid_size) + except Exception: # pylint: disable=broad-except + pass + devices = np.asarray(jax.devices()) + if devices.size != world_size: + raise RuntimeError( + f"_allgather_uid fallback expected {world_size} global devices, got {devices.size}." + ) + mesh = jax.sharding.Mesh(devices, ("_uid_all",)) + sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("_uid_all", None)) + replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + local = np.asarray(uid_arr).reshape(1, uid_size) + g_in = jax.make_array_from_process_local_data(sharded, local, (world_size, uid_size)) + g_out = jax.jit(lambda x: x, out_shardings=replicated)(g_in) + return np.asarray(g_out).reshape(world_size, uid_size) + + +# ── Bootstrap ──────────────────────────────────────────────────────────────── + + +def ep_bootstrap( + world_size, + rank, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_token_dtype=jnp.bfloat16, + max_num_sms=0, +): + """Initialize the EP communicator. Call once per process before any EP op. + + Must run inside the active JAX Mesh and a global_shard_guard; ep_size and + num_ep_groups are read from the mesh axes named by MeshResource.ep_resource + and MeshResource.dp_resource/fsdp_resource. + + max_token_dtype is the widest jnp dtype the group will dispatch; tensors + passed to ep_dispatch may use any narrower dtype. + max_num_sms caps the SMs allotted to EP kernels (0 = auto). + """ + if jnp.dtype(max_token_dtype) != jnp.bfloat16: + raise NotImplementedError( + "ep_bootstrap: only max_token_dtype=jnp.bfloat16 is supported today, got" + f" {jnp.dtype(max_token_dtype)}." + ) + if world_size < 2: + raise ValueError( + f"ep_bootstrap requires world_size >= 2 (got {world_size}); NCCL EP needs" + " at least 2 ranks to form a group." + ) + if jax.local_device_count() != 1: + raise ValueError( + "ep_bootstrap requires one local device per process (got" + f" jax.local_device_count() = {jax.local_device_count()}); NCCL EP does not" + " support single-process multi-device setups." + ) + + gsr = global_mesh_resource() + ep_resource = gsr.ep_resource + if ep_resource is None: + raise ValueError( + "ep_bootstrap requires MeshResource.ep_resource to be set; enter a" + " global_shard_guard(MeshResource(..., ep_resource=)) before bootstrap." + ) + ep_size = get_mesh_axis_size(ep_resource) + outer_axis = _ep_outer_axis() + if outer_axis is None: + if world_size != ep_size: + raise ValueError( + f"ep_bootstrap: world_size ({world_size}) > ep_size ({ep_size}) but neither" + " MeshResource.dp_resource nor fsdp_resource is set; name the outer axis so" + " EP-output tensors can shard across EP groups." + ) + num_ep_groups = 1 + else: + num_ep_groups = get_mesh_axis_size(outer_axis) + if num_ep_groups * ep_size != world_size: + raise ValueError( + f"ep_bootstrap: num_ep_groups*ep_size ({num_ep_groups}*{ep_size}=" + f"{num_ep_groups * ep_size}) must equal world_size ({world_size}); check that" + f" the '{outer_axis}' and '{ep_resource}' mesh axes cover all ranks." + ) + if num_experts % ep_size != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}).") + + UID_SIZE = 128 + dp_color = rank // ep_size + rank_within_group = rank % ep_size + is_color_root = rank_within_group == 0 + if is_color_root: + libnccl = ctypes.CDLL("libnccl.so.2", use_errno=True) + uid_arr = (ctypes.c_uint8 * UID_SIZE)() + ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) + assert ret == 0, f"ncclGetUniqueId failed with code {ret}" + uid_bytes = bytes(uid_arr) + else: + uid_bytes = bytes(UID_SIZE) + + uid_arr = jnp.frombuffer(uid_bytes, dtype=jnp.uint8) + all_uids = _allgather_uid(uid_arr, world_size, UID_SIZE) + uid_bytes = bytes(np.asarray(all_uids[dp_color * ep_size]).tolist()) + + # Eager NCCL init while ranks are barrier-synced by the UID broadcast above. + transformer_engine_jax.set_ep_bootstrap_params( + uid_bytes, + ep_size, + rank_within_group, + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + max_num_sms=int(max_num_sms), + max_token_dtype=int(jax_dtype_to_te_dtype(max_token_dtype)), + ) + + # Release the C++ anchor at interpreter shutdown so RAII can tear down NCCL. + global _atexit_registered + if not _atexit_registered: + atexit.register(transformer_engine_jax.release_ep_resources) + _atexit_registered = True + + tex.ep.set_ep_config( + tex.ep.EpConfig( + world_size=world_size, + rank=rank, + ep_size=ep_size, + num_ep_groups=num_ep_groups, + num_experts=num_experts, + num_local_experts=num_experts // ep_size, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_capacity_per_rank, + hidden_dim=hidden_dim, + ) + ) + + +def _default_out_partition_spec(): + """Leading-axis default: ``(("dp","ep"),)`` if DP/FSDP is set, else ``("ep",)``.""" + gsr = global_mesh_resource() + if gsr.ep_resource is None: + raise ValueError( + "ep_resource is not set on the active MeshResource; pass out_sharding=... explicitly." + ) + outer = _ep_outer_axis() + leading = (outer, gsr.ep_resource) if outer is not None else gsr.ep_resource + return (leading,) + + +# ── ep_dispatch (custom_vjp) ───────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(0, 4)) +def ep_dispatch(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + """Scatter tokens and weights to expert ranks. + + ``cfg`` is a per-layer ``EpLayerConfig``; distinct layers may share a + ``cfg`` (the pointer-keyed C++ cache keys on handle_mem, not on cfg). + Inputs are ``[..., H]`` with only the leading dim sharded as ``ep`` or + ``(dp, ep)``. Returns + ``(recv_tokens, recv_topk_weights, handle_mem, token_counts)``; pass + ``handle_mem`` and ``token_counts`` to the matching ``ep_combine``. + """ + return _dispatch_fwd(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank)[0] + + +def _dispatch_fwd(cfg, topk_idx, tokens, topk_weights, recv_capacity_per_rank): + if not jnp.issubdtype(topk_weights.dtype, jnp.floating): + raise TypeError( + f"ep_dispatch: topk_weights must be a floating dtype; got {topk_weights.dtype}." + ) + token_counts, handle_mem = tex.ep_prepare(cfg, topk_idx) + recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( + cfg, handle_mem, topk_idx, tokens, topk_weights, recv_capacity_per_rank + ) + out_leading = tuple(tokens.shape[:-1]) + primal = (recv_tokens, recv_topk_weights, handle_mem, token_counts) + return primal, (handle_mem, out_leading) + + +def _dispatch_bwd(cfg, recv_capacity_per_rank, res, g_outputs): + del recv_capacity_per_rank + handle_mem, out_leading = res + # Re-pin cotangent: XLA transpose can drop the EP axis and feed the FFI a global tensor. + out_spec = _default_out_partition_spec() + spec = jax.sharding.PartitionSpec(*out_spec) + g_recv_tokens = with_sharding_constraint(g_outputs[0], spec) + g_recv_topk_weights = with_sharding_constraint(g_outputs[1], spec) + grad_tokens, grad_topk_weights = tex.ep_dispatch_bwd( + cfg, + handle_mem, + g_recv_tokens, + g_recv_topk_weights, + out_leading, + out_partition_spec=out_spec, + ) + return (None, grad_tokens, grad_topk_weights) + + +ep_dispatch.defvjp(_dispatch_fwd, _dispatch_bwd) + + +# ── ep_combine (custom_vjp) ────────────────────────────────────────────────── + + +@partial(jax.custom_vjp, nondiff_argnums=(0, 4, 5)) +def ep_combine( + cfg, + handle_mem, + token_counts, + expert_out, + num_local_tokens, + out_sharding=None, +): + """Scatter-sum expert outputs back to source ranks. **Unweighted.** + + Caller must pre-multiply ``expert_out`` by ``recv_topk_weights`` (and + zero padded slots); gradients w.r.t. weights flow through that hadamard, + not through this op. ``num_local_tokens`` is STATIC: int -> ``[T, H]``, + tuple -> ``[*tuple, H]``. ``out_sharding`` defaults via + ``_default_out_partition_spec``; only the leading dim may be sharded. + """ + return _combine_fwd( + cfg, + handle_mem, + token_counts, + expert_out, + num_local_tokens, + out_sharding, + )[0] + + +def _combine_fwd( + cfg, + handle_mem, + token_counts, + expert_out, + num_local_tokens, + out_sharding, +): + del token_counts + if out_sharding is None: + out_sharding = _default_out_partition_spec() + result = tex.ep_combine_fwd( + cfg, handle_mem, expert_out, num_local_tokens, out_partition_spec=out_sharding + ) + return result, (handle_mem, expert_out.shape[-2]) + + +def _combine_bwd(cfg, _num_local_tokens, _out_sharding, res, g_result): + handle_mem, recv_capacity_per_rank = res + # Re-pin cotangent (same XLA-transpose workaround as _dispatch_bwd). + if _out_sharding is None: + _out_sharding = _default_out_partition_spec() + spec = jax.sharding.PartitionSpec(*_out_sharding) + g_result = with_sharding_constraint(g_result, spec) + grad_expert_out = tex.ep_combine_bwd(cfg, handle_mem, g_result, recv_capacity_per_rank) + return (None, None, grad_expert_out) + + +ep_combine.defvjp(_combine_fwd, _combine_bwd) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 91346a7a48..640db29534 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -37,8 +37,7 @@ # import P`` without a second jax.sharding import. from jax.sharding import PartitionSpec as P # noqa: F401 # pylint: disable=unused-import -from ..moe import PermutationBackend, moe -from ..quantize import noop_quantizer_set +from ..moe import moe from ..router import ScoreFunction from ..sharding import get_active_resource_axis from .module import TransformerEngineBase @@ -50,7 +49,7 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] -__all__ = ["PermutationBackend", "_MoEBlock"] +__all__ = ["_MoEBlock"] class _MoEBlock(TransformerEngineBase): @@ -82,10 +81,11 @@ class _MoEBlock(TransformerEngineBase): Grouped top-k knobs (DeepSeek-style). ``None`` disables grouping. scaling_factor : float Multiplier on the routing weights. - use_expert_bias : bool - If ``True``, registers a per-expert routing bias (shape ``[E]``). - Only meaningful with ``score_function="sigmoid"``; the underlying - primitive validates the pairing. + use_expert_routing_bias : bool + If ``True``, registers a per-expert routing bias (shape ``[E]``) + used by the topk selection. Only meaningful with + ``score_function="sigmoid"``; the underlying primitive validates + the pairing. aux_loss_coeff : float If ``> 0``, return the MoE auxiliary load-balancing loss scalar in addition to the main output. @@ -100,23 +100,27 @@ class _MoEBlock(TransformerEngineBase): replicated across non-EP axes within an EP group; set e.g. ``("fsdp",)`` for true FSDP-of-batch where each device owns a unique slice of the batch. - permutation_backend : PermutationBackend - ``PURE_JAX`` (default) or ``TRITON``. - _align_size : int - Per-expert group-size alignment (``0`` disables; required > 0 - for quantized grouped GEMM). Internal knob; will be inferred - from the active quantization recipe in a follow-up PR. + apply_topk_weights_early : bool + If ``True``, multiply expert outputs by their top-k weights + *inside* each shard before ``ep_combine`` (saves one global + reduction at the cost of an extra broadcast). Default ``False``. + + The per-expert dispatch-slot alignment is fixed internally at 128 + tokens (see ``moe._ALIGN_SIZE``) -- the value required by NCCL EP + HT and satisfied by every current TE grouped-GEMM recipe -- and is + therefore not exposed as a per-instance knob. dtype : jnp.dtype Compute / parameter dtype. kernel_init, bias_init, expert_bias_init : Initializers. - use_bias : bool - Register per-expert FFN biases. + use_ffn_bias : bool + Register per-expert FFN biases (``wi_0_bias``, ``wi_1_bias``, + ``wo_bias``). Quantization is currently configured via the standard TE autocast - context (``fp8_autocast``/``with_quantizer_set``); per-call - quantizer sets can also be passed through ``__call__``'s - ``quantizer_sets`` keyword once we stabilise the recipe pipeline. + context (``fp8_autocast``/``with_quantizer_set``) and threaded + through ``moe()`` internally; this wrapper does not expose a + per-call ``quantizer_sets`` knob yet. """ # Architecture @@ -131,7 +135,7 @@ class _MoEBlock(TransformerEngineBase): num_groups: Optional[int] = None group_topk: Optional[int] = None scaling_factor: float = 1.0 - use_expert_bias: bool = False + use_expert_routing_bias: bool = False aux_loss_coeff: float = 0.0 # Sharding (logical axes) @@ -143,16 +147,15 @@ class _MoEBlock(TransformerEngineBase): # Parallelism data_parallelism_axes: Tuple[str, ...] = () - # Permutation - permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX - _align_size: int = 0 + # MoE knobs forwarded to ``moe()`` + apply_topk_weights_early: bool = False # Dtypes / init / misc dtype: DType = jnp.float32 kernel_init: Optional[Initializer] = None bias_init: Initializer = nn.initializers.zeros expert_bias_init: Initializer = nn.initializers.zeros - use_bias: bool = False + use_ffn_bias: bool = False def __post_init__(self): if self.kernel_init is None: @@ -163,11 +166,6 @@ def __post_init__(self): 1.0, "fan_in", "truncated_normal", dtype=self.dtype ), ) - if not isinstance(self.permutation_backend, PermutationBackend): - raise TypeError( - "permutation_backend must be a PermutationBackend, got" - f" {self.permutation_backend!r}" - ) super().__post_init__() @nn.compact @@ -221,7 +219,7 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: self.dtype, ) wi_0_bias = wi_1_bias = wo_bias = None - if self.use_bias: + if self.use_ffn_bias: wi_0_bias = self.param( "wi_0_bias", nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), @@ -241,7 +239,7 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: self.dtype, ) expert_bias = None - if self.use_expert_bias: + if self.use_expert_routing_bias: expert_bias = self.param( "expert_bias", nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), @@ -270,15 +268,12 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: group_topk=self.group_topk, scaling_factor=self.scaling_factor, aux_loss_coeff=self.aux_loss_coeff, - permutation_backend=self.permutation_backend, - align_size=self._align_size, - gate_inside_vjp=True, + apply_topk_weights_early=self.apply_topk_weights_early, ep_axis=ep_axis, data_parallelism_axes=self.data_parallelism_axes, input_axes=self.input_axes, gate_kernel_axes=self.gate_kernel_axes, wi_kernel_axes=self.wi_kernel_axes, wo_kernel_axes=self.wo_kernel_axes, - quantizer_sets=(noop_quantizer_set, noop_quantizer_set, noop_quantizer_set), dtype=self.dtype, ) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index 2a1c818cb3..ee61540801 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -1,76 +1,52 @@ # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Functional Mixture-of-Experts (MoE) entry point with a single fused VJP. - -This module exposes :func:`moe`, the framework-agnostic flat function that -implements an entire MoE block (gate -> top-k routing -> token dispatch -> -per-expert FFN -> token combine, plus optional expert parallelism via a -shard_map / ragged_all_to_all collective) under a *single* -``jax.custom_vjp``. It is the moral analog of -:func:`transformer_engine.jax.layernorm_mlp.layernorm_mlp` for MoE: one -custom_vjp boundary covers the whole block so future fusions (FP8 over the -EP wire, fused ``ragged_all_to_all + grouped_gemm``, gate+route+dispatch -fusion) can land without re-architecting the call site. - -Design rationale ----------------- - -The earlier MoE block (:class:`transformer_engine.jax.flax.moe._MoEBlock`) -composed many narrower custom_vjps -- one per :func:`grouped_dense`, one -per :func:`token_dispatch`, etc. Every nested custom_vjp is a place where -a quantized :class:`ScaledTensor` cannot survive (JAX requires custom_vjp -inputs / outputs to be plain ``jnp.ndarray`` ish pytrees). To enable -end-to-end FP8 flow -- in particular FP8 carried over the EP -ragged_all_to_all -- the dispatch's quantize, the a2a, the per-expert -FFN, the inverse a2a, and the combine all have to live inside the same -VJP. This file collapses them into one. - -Implementation conventions --------------------------- - -* No nested ``custom_vjp``. Every primitive's ``_fwd`` and ``_bwd`` is - called directly (e.g. :func:`tex.fused_topk_with_score_function_fwd` / - ``_bwd``, :func:`unpermute_with_mask_map`, - :func:`unpermute_bwd_with_merging_probs`, - :func:`sort_chunks_by_map(is_forward=False)`, - forward + reverse :func:`jax.lax.ragged_all_to_all`) so the outer - ``_moe_bwd_rule`` controls the bwd graph end-to-end without invoking - ``jax.vjp`` for re-linearization. -* The fwd/bwd context (``ctx``) is a plain ``dict`` whose keys depend on - the static configuration (permutation backend, EP active or not, - presence of biases, aux loss enabled). The ``_moe_fwd_rule`` builds a - matching ``ctx_specs`` dict in lockstep when opening the EP shard_map - so ``out_specs`` structurally matches the body's return. -* :func:`_dispatch` is the helper that wraps - ``permute -> a2a -> local_permute`` (forward); :func:`_combine` is its - inverse. Their ``_bwd`` siblings drive the inverse collectives in the - bwd rule. None of these helpers form a custom_vjp boundary. +"""Mixture-of-Experts (MoE) layer for TransformerEngine JAX. + +This module exposes :func:`moe`, a single fused MoE forward pass + bwd +built on top of TE's NCCL-backed Expert Parallelism primitives +(``tex.ep_dispatch`` / ``tex.ep_combine``). The block runs:: + + gate -> topk -> ep_dispatch -> per-expert FFN (grouped GEMMs) + -> ep_combine -> output + +under a single ``jax.custom_vjp`` so the routing, dispatch, FFN and +combine steps fuse cleanly under XLA without leaking intermediate +residuals into the user-facing autograd graph. + +Sharding model +-------------- +* Inbound activations are 3D ``[B, S, H]`` sharded + ``((*data_parallelism_axes, ep_axis), None, None)``. The public + :func:`moe` soft-repins this on entry and warns when a reshard is + inserted. +* The EP primitives operate at global view (their custom_partitioning + rules handle per-shard execution). The FFN GEMMs run per-shard inside + a small ``shard_map`` whose ``in_specs`` and ``out_specs`` mirror the + same ``((dp, ep), ...)`` layout. + +Out-of-scope (for now) +---------------------- +FP8 / MXFP8 quantizer sets are not yet wired on this path; turning +them on requires recipe-aware residual specs and ``ScaledTensor`` +leaves across the ``shard_map`` boundary. ``aux_loss_coeff`` and +``expert_bias`` are supported (the former forces a per-step +all-gather over the routing-side logits, which lives off the critical +path and overlaps with the dispatch collective). """ -import math from dataclasses import dataclass -from enum import Enum from functools import partial -from typing import Any, NewType, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union +import warnings import jax import jax.numpy as jnp -from flax import struct as flax_struct -from jax.sharding import PartitionSpec as P +from jax.sharding import NamedSharding, PartitionSpec as P +from jax.tree_util import register_pytree_node_class from . import cpp_extensions as tex -from .permutation import ( - PureJaxPermState, - compute_ragged_all_to_all_params, - compute_reverse_ragged_all_to_all_params, - pure_jax_token_combine, - pure_jax_token_dispatch, - routing_map_to_selected_experts, -) from .quantize import ( - QuantizerSet, - ScaledTensor, TensorUsage, noop_quantizer_set, with_sharding_constraint_by_logical_axes, @@ -79,1070 +55,329 @@ from .router import ScoreFunction, _validate_score_function from .sharding import _get_mesh -# Triton-backed primitives are imported lazily: callers on the PURE_JAX -# permutation backend should not need ``triton`` installed. The TRITON -# branches in this module call ``_require_triton()`` first to raise a -# clear error if the import failed. -try: - from .triton_extensions.permutation import ( - make_chunk_sort_map, - make_row_id_map, - permute_with_mask_map, - permute_with_mask_map_and_pad, - sort_chunks_by_map, - unpermute_bwd_with_merging_probs, - unpermute_bwd_with_merging_probs_and_unpad, - unpermute_with_mask_map, - unpermute_with_mask_map_and_unpad, - ) - - _TRITON_AVAILABLE = True -except ImportError: - _TRITON_AVAILABLE = False - make_chunk_sort_map = None - make_row_id_map = None - permute_with_mask_map = None - permute_with_mask_map_and_pad = None - sort_chunks_by_map = None - unpermute_bwd_with_merging_probs = None - unpermute_bwd_with_merging_probs_and_unpad = None - unpermute_with_mask_map = None - unpermute_with_mask_map_and_unpad = None - - -def _require_triton(): - """Raise a clear error if Triton permutation kernels are unavailable.""" - if not _TRITON_AVAILABLE: - raise ImportError( - "PermutationBackend.TRITON requires" - " ``transformer_engine.jax.triton_extensions`` (and ``triton``)." - " Install Triton or pass PermutationBackend.PURE_JAX." - ) - - -PRNGKey = Any -Shape = Tuple[int, ...] -DType = NewType("DType", jnp.dtype) -Array = NewType("Array", jnp.ndarray) - +__all__ = ["moe"] -__all__ = ["moe", "PermutationBackend"] +# Per-expert dispatch-slot alignment fed to ``tex.ep_prepare`` as +# ``dispatch_output_per_expert_alignment``. NCCL EP HT requires the +# per-expert recv block to be at least 128-token aligned, and all current +# TE grouped-GEMM recipes (bf16/fp16/fp8/mxfp8) are satisfied by the +# same 128-token tile, so a single constant covers every supported path. +_ALIGN_SIZE = 128 -# ============================================================================= -# Enums -# ============================================================================= +def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray: + """Sharding constraint that keeps bwd cotangents in the primal dtype. -class PermutationBackend(Enum): - """Token-dispatch / combine backend used by :func:`moe`. - - * ``TRITON``: TE's fused Triton kernels. Faster than ``PURE_JAX`` - on current hardware and the recommended default. - * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain - XLA; useful as a numerical reference and on builds without - Triton available. - """ + Plain ``jax.lax.with_sharding_constraint`` is identity on the fwd + but does not constrain the dtype of the cotangent that flows back + through it. In this MoE bwd, ``d_x`` is built from two paths: - PURE_JAX = "pure_jax" - TRITON = "triton" + * ``d_x_from_dispatch`` from ``ep_dispatch_bwd`` -- primal dtype + (bf16 in mixed precision). + * ``d_x_from_gate = d_logits_2d @ gate_kernel.T`` where + ``d_logits_2d`` is produced by + ``fused_topk_with_score_function_bwd``. That primitive runs at + fp32 because the fwd promoted ``logits_2d`` to fp32 (the fused + topk/softmax/sigmoid kernels are only validated at fp32). - -# ============================================================================= -# Dispatch-state records (carried _dispatch -> _combine / *_bwd) -# ============================================================================= -# -# Two NamedTuples (one per permutation backend) so we get type -# discrimination at the consumer side via ``isinstance``. The backend- -# specific residuals are required fields; the EP-only residuals are -# Optional and are populated only when the run is EP-active. Each field -# is either an ``ndarray`` or ``None`` -- nothing static, since these -# values cross the shard_map pytree boundary and would otherwise be -# coerced into JitTracers. - - -@flax_struct.dataclass -class _PureJaxDispatchState: - """Residuals saved by :func:`_dispatch` on the PURE_JAX path. - - Registered as a JAX pytree via ``flax.struct.dataclass``: each - annotated field is a leaf, ``None`` is a non-leaf sentinel. The - matching spec built by :func:`_build_dispatch_specs` mirrors this - layout so shard_map's value and spec trees line up. + JAX's type promotion then makes ``d_x_from_gate + d_x_from_dispatch`` + fp32, so the user-visible ``d_x`` ends up wider than ``x``. That + doubles activation-grad bandwidth and breaks any downstream kernel + that pins a bf16 input layout. This wrapper inserts an explicit + cast back to the primal dtype on the bwd side and re-asserts the + same sharding there as well. """ - group_sizes: jnp.ndarray - sorted_indices: jnp.ndarray - routing_weights: jnp.ndarray - # EP-only: - all_shards_tokens_per_expert: Optional[jnp.ndarray] = None - local_perm_row_id_map: Optional[jnp.ndarray] = None - - -@flax_struct.dataclass -class _TritonDispatchState: - """Residuals saved by :func:`_dispatch` on the TRITON path.""" - - group_sizes: jnp.ndarray - row_id_map: jnp.ndarray - pad_offsets: Optional[jnp.ndarray] # populated only when align_size > 0 - merging_probs: jnp.ndarray - # EP-only: - all_shards_tokens_per_expert: Optional[jnp.ndarray] = None - local_perm_row_id_map: Optional[jnp.ndarray] = None - - -_DispatchState = Union[_PureJaxDispatchState, _TritonDispatchState] - + @jax.custom_vjp + def _constraint(y): + return jax.lax.with_sharding_constraint(y, sharding) -@flax_struct.dataclass -class _BodyCtx: - """Residuals carried fwd_rule -> bwd_rule by :func:`_body_fwd`. + def _constraint_fwd(y): + return jax.lax.with_sharding_constraint(y, sharding), jnp.zeros((), dtype=y.dtype) - Optional fields (``expert_bias``, ``aux_*``) are ``None`` when the - matching feature is disabled. :func:`_build_ctx_specs` mirrors that - layout so the shard_map spec and value trees match leaf-for-leaf. - """ + def _constraint_bwd(dtype_ref, grad): + return (jax.lax.with_sharding_constraint(grad.astype(dtype_ref.dtype), sharding),) - # Always present. - x: Any - gate_kernel: Any - logits_2d: Any - saved_scores: Any - routing_map: Any - dispatch: Any # _DispatchState - casted_sorted_x_lhs_trans: Any - casted_wi_rhs_trans: Any # combined [E, H, 2M] residual for fused wi_0|wi_1 bwd - gate_proj_out: Any - up_proj_out: Any - casted_intermediate_lhs_trans: Any - casted_wo_rhs_trans: Any - expert_outputs: Any - local_group_sizes: Any - # Feature-gated. - expert_bias: Any = None - aux_const_buf: Any = None - aux_tokens_per_expert: Any = None - aux_logits_for_score: Any = None - aux_saved_scores: Any = None + _constraint.defvjp(_constraint_fwd, _constraint_bwd) + return _constraint(x) # ============================================================================= -# ctx / dispatch-state key conventions +# Process-level NCCL EP bootstrap (must run eagerly, outside jax.jit) # ============================================================================= # -# Both ``ctx`` (carried fwd_rule -> bwd_rule) and the dispatch state -# (carried _dispatch -> _combine / _dispatch_bwd / _combine_bwd) are plain -# python dicts. Using a dict (rather than a flax_struct.dataclass) lets us -# vary the populated keys with the static config without breaking -# ``shard_map``'s ``out_specs`` structural match: the spec dict and the -# value dict are built with the SAME keys via :func:`_build_ctx_specs`. -# -# Below is the key glossary so the rest of the file reads cleanly. -# -# DispatchState (dict): values are jnp.ndarray unless noted -# Always present: -# "group_sizes" [n_groups] per-expert token counts -# (n_groups = E for no-EP, -# E_local for EP) -# "ep_active" bool (carried as a Python flag, -# not in the dict; passed -# alongside) -# PURE_JAX backend: -# "sorted_indices" [num_real + padding] argsort indices -# "routing_weights" [num_tokens, topk] per-token-per-expert weights -# TRITON backend: -# "row_id_map" [num_tokens, 2*E + 1] -# "pad_offsets" [E] or None -# "merging_probs" [num_tokens, E] -# EP-only: -# "all_shards_tokens_per_expert" [num_ep, E] -# "local_perm_row_id_map" [recv_buffer_rows] -# "local_perm_inv_row_id_map" [recv_buffer_rows] -# -# NOTE: per-shard compile-time-constant shapes (num_real_tokens, -# padding_size, pre/post_a2a_buffer_shape) are NOT stored in this -# dict; they are recomputed in _body_fwd/_body_bwd via -# _compute_static_shape_info and passed as Python ints / int tuples to -# the dispatch/combine helpers. Storing them in the dict would cause -# JAX's pytree-flatten across the shard_map boundary to coerce them -# into JitTracer 0-d arrays, which breaks Python-level control flow -# (e.g. ``if padding > 0``) and ``jnp.zeros(shape)`` in the bwd. -# -# See :class:`_BodyCtx` (NamedTuple) for the ctx layout and field -# documentation. :func:`_build_ctx_specs` returns a matching ``_BodyCtx`` -# of ``P(...)`` specs so shard_map's value/spec trees line up -# leaf-for-leaf. - - -# ============================================================================= -# Static shape helper -# ============================================================================= -# -# A set of per-shard shape/size values that the dispatch and combine -# helpers (both fwd and bwd) need. They're all derivable from existing -# static args, so we recompute them in both ``_body_fwd`` and -# ``_body_bwd`` and pass them as Python ints / int-tuples through -# explicit kwargs. We MUST NOT stash them inside the dynamic -# ``state`` / ``ctx`` dict: when the dict crosses the EP shard_map's -# out_specs/in_specs boundary, JAX's pytree-flatten coerces any Python -# int leaves into traced 0-d arrays, which then breaks dependent Python -# code in the bwd (e.g. ``if padding > 0`` and ``jnp.zeros(shape)``). - - -@dataclass(frozen=True) -class _StaticShapeInfo: - """Per-shard compile-time-constant shape info used by dispatch / - combine fwd and bwd. Fields are Python ints / int tuples (NOT jnp - arrays) so they can be passed as ordinary static keyword args. - - Attributes - ---------- - num_real_tokens : int - Per-shard count of real (non-padding) permuted tokens, - i.e. ``per_shard_num_tokens * num_experts_per_tok``. - padding_size : int - Per-shard number of alignment-padding tokens appended to the - sort buffer (``num_experts * (align_size - 1)`` when - ``align_size > 0``, else ``0``). - pre_a2a_buffer_shape : tuple[int, int] - ``(num_real_tokens + padding_size, hidden)`` -- the per-shard - shape of the sorted-inputs buffer sent over the EP - ragged_all_to_all in the fwd direction. - post_a2a_buffer_shape : Optional[tuple[int, int]] - ``(recv_buffer_rows, hidden)`` when EP is active, ``None`` - otherwise. - """ +# ``tex.ep_bootstrap`` does a NCCL UID allgather over the JAX runtime, which +# cannot run from inside a jit-traced function. The caller must bootstrap +# eagerly once per process before any jitted MoE call, then record the +# bootstrap signature via ``record_ep_bootstrap_signature_for_moe``. The +# per-call check below verifies the recorded signature is wide enough for +# the current MoE invocation (smaller per-call usage is fine since the C++ +# backend reserves worst-case buffers at bootstrap time). - num_real_tokens: int - padding_size: int - pre_a2a_buffer_shape: Tuple[int, int] - post_a2a_buffer_shape: Optional[Tuple[int, int]] +_te_ep_bootstrap_signature: Optional[Tuple[int, int, int, int, int]] = None -def _compute_static_shape_info( - *, - batch_size: int, - sequence_length: int, - hidden: int, +def record_ep_bootstrap_signature_for_moe( num_experts: int, - num_experts_per_tok: int, - align_size: int, - ep_active: bool, - num_ep: int = 1, - fsdp_sizes: Tuple[int, ...] = (), - recv_buffer_rows: int = 0, - batch_is_per_shard: bool = True, -) -> _StaticShapeInfo: - """Build a :class:`_StaticShapeInfo` for the current rank. - - ``batch_is_per_shard`` controls whether ``batch_size`` is already - sharded (True -- e.g. when this is called from inside a shard_map - body, where ``x.shape[0]`` reports the per-shard batch size) or - global (False -- e.g. when computing from x.shape outside the - shard_map body). + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + ep_size: int, +) -> None: + """Record the params passed to ``ep_bootstrap`` so the per-call check + in ``_moe_fwd_rule`` can verify compatibility. Call this once per + process immediately after ``ep_bootstrap``. """ - if ep_active and not batch_is_per_shard: - dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 - per_shard_batch = batch_size // (num_ep * dp_size) - else: - per_shard_batch = batch_size - per_shard_num_tokens = per_shard_batch * sequence_length - num_real_tokens = per_shard_num_tokens * num_experts_per_tok - padding_size = num_experts * (align_size - 1) if align_size > 0 else 0 - pre_a2a_buffer_shape = (num_real_tokens + padding_size, hidden) - post_a2a_buffer_shape = (recv_buffer_rows, hidden) if ep_active else None - return _StaticShapeInfo( - num_real_tokens=num_real_tokens, - padding_size=padding_size, - pre_a2a_buffer_shape=pre_a2a_buffer_shape, - post_a2a_buffer_shape=post_a2a_buffer_shape, + global _te_ep_bootstrap_signature + _te_ep_bootstrap_signature = ( + num_experts, + max_tokens_per_rank, + recv_capacity_per_rank, + hidden_dim, + ep_size, ) -# ============================================================================= -# Dispatch / combine helpers (no VJP boundary -- pure Python) -# ============================================================================= - - -def _dispatch( - inputs_2d: jnp.ndarray, - sparse_probs: jnp.ndarray, - routing_map: jnp.ndarray, - *, - backend: PermutationBackend, +def _te_ep_assert_compatible_bootstrap( num_experts: int, - num_experts_per_tok: int, - align_size: int, - # EP-only: - ep_active: bool, - ep_axis: Optional[str], - num_ep: int, - recv_buffer_rows: int, - shard_id: Optional[jnp.ndarray] = None, -) -> Tuple[jnp.ndarray, dict]: - """``permute -> (a2a -> local_permute) iff ep_active``. - - Returns ``(sorted_x, state)`` where ``sorted_x`` has shape - ``[buffer_rows, hidden]`` -- ``E`` groups (no-EP) or ``E_local`` groups - (EP) -- and ``state`` is a dict carrying everything :func:`_combine` - and the bwd helpers need to reverse the operation. - - Bypasses the ``custom_vjp``-wrapped public ``token_dispatch`` / - ``pure_jax_token_dispatch`` wrappers (well, mostly: PURE_JAX still - composes through ``pure_jax_token_dispatch`` because that helper has - no ``custom_vjp`` itself -- only its inner ``_sort_activations`` does, - which is fine since we never auto-diff through it from this layer). - For TRITON we call the underlying ``permute_with_mask_map`` / - ``permute_with_mask_map_and_pad`` primitives directly. - """ - num_tokens, hidden = inputs_2d.shape - topk = num_experts_per_tok - - # Backend-specific residuals collected here, then packaged into the - # appropriate _*DispatchState below. - sorted_indices = None - routing_weights_kept = None - row_id_map = None - pad_offsets = None - merging_probs = None - - # ------------------------------------------------------------------ - # Step 1: global permute (every shard routes its own tokens over the - # full expert axis). Backend-specific. - # ------------------------------------------------------------------ - if backend is PermutationBackend.PURE_JAX: - selected_experts, routing_weights = routing_map_to_selected_experts( - sparse_probs, routing_map, topk - ) - sorted_inputs, perm_state, group_sizes = pure_jax_token_dispatch( - inputs_2d, - selected_experts, - num_experts=num_experts, - num_experts_per_tok=topk, - align_size=align_size, + max_tokens_per_rank: int, + recv_capacity_per_rank: int, + hidden_dim: int, + ep_size: int, +) -> None: + """Verify a prior eager ``ep_bootstrap`` is wide enough for this call.""" + if _te_ep_bootstrap_signature is None: + raise RuntimeError( + "TE EP was not bootstrapped. Call" + " transformer_engine.jax.ep.ep_bootstrap(...) eagerly (outside" + " any jax.jit) once per process, then" + " transformer_engine.jax.moe.record_ep_bootstrap_signature_for_moe(...)" + " with the same params, before invoking moe()." ) - # NOTE: ``perm_state.num_real_tokens`` and ``perm_state.padding_size`` - # are compile-time Python ints; intentionally NOT stored in the - # returned state (would be coerced to JitTracer 0-d arrays under - # the EP shard_map's pytree flatten). Recompute via - # ``_compute_static_shape_info`` in the bwd / EP-combine - # call sites that need them. - sorted_indices = perm_state.sorted_indices - routing_weights_kept = routing_weights - else: - # TRITON backend -- inline the underlying primitive sequence - # (mirrors ``_token_dispatch_fwd_rule`` but exposes the residuals - # to our ctx instead of saving them inside another custom_vjp). - num_out_tokens = num_tokens * topk - row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) - tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) - if align_size > 0: - target_tokens_per_expert = ( - jnp.ceil(tokens_per_expert / align_size) * align_size - ).astype(jnp.int32) - pad_lengths = target_tokens_per_expert - tokens_per_expert - cum_pad = jnp.cumsum(pad_lengths) - pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]]) - worst_case_out_tokens = ( - (num_out_tokens + num_experts * (align_size - 1)) // align_size - ) * align_size - sorted_inputs, _ = permute_with_mask_map_and_pad( - inputs_2d, - row_id_map, - None, - pad_offsets, - num_tokens, - num_experts, - worst_case_out_tokens, - hidden, - align_size=align_size, - ) - group_sizes = target_tokens_per_expert - else: - sorted_inputs, _ = permute_with_mask_map( - inputs_2d, - row_id_map, - None, - num_tokens, - num_experts, - num_out_tokens, - hidden, - ) - pad_offsets = None - group_sizes = tokens_per_expert - merging_probs = sparse_probs - - def _build_state(group_sizes_val, ep_all=None, ep_local=None): - if backend is PermutationBackend.PURE_JAX: - return _PureJaxDispatchState( - group_sizes=group_sizes_val, - sorted_indices=sorted_indices, - routing_weights=routing_weights_kept, - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - return _TritonDispatchState( - group_sizes=group_sizes_val, - row_id_map=row_id_map, - pad_offsets=pad_offsets, - merging_probs=merging_probs, - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, + b_num_experts, b_max_tpr, b_recv_pr, b_hidden, b_ep_size = _te_ep_bootstrap_signature + if ( + num_experts != b_num_experts + or hidden_dim != b_hidden + or ep_size != b_ep_size + or max_tokens_per_rank > b_max_tpr + or recv_capacity_per_rank > b_recv_pr + ): + raise ValueError( + "TE EP was already bootstrapped with signature" + f" (num_experts={b_num_experts}, max_tokens_per_rank={b_max_tpr}," + f" recv_capacity_per_rank={b_recv_pr}, hidden_dim={b_hidden}," + f" ep_size={b_ep_size}); this moe() call needs" + f" (num_experts={num_experts}, max_tokens_per_rank={max_tokens_per_rank}," + f" recv_capacity_per_rank={recv_capacity_per_rank}, hidden_dim={hidden_dim}," + f" ep_size={ep_size}). Re-bootstrap with wider params (or matching exact" + " sizes) is required." ) - if not ep_active: - return sorted_inputs, _build_state(group_sizes) - - # ------------------------------------------------------------------ - # Step 2 (EP only): all_gather per-expert counts so every shard knows - # the [num_ep, num_experts] token-count matrix. - # ------------------------------------------------------------------ - all_shards_tokens_per_expert = jax.lax.all_gather( - group_sizes[None, :], - axis_name=ep_axis, - axis=0, - tiled=True, - ) - # ------------------------------------------------------------------ - # Step 3 (EP only): forward ragged_all_to_all over the EP axis. - # ------------------------------------------------------------------ - in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep - ) - post_a2a_buffer_shape = (recv_buffer_rows, hidden) - recv_buf = jnp.zeros(post_a2a_buffer_shape, dtype=sorted_inputs.dtype) - x_recv = jax.lax.ragged_all_to_all( - sorted_inputs, recv_buf, in_off, send_sz, out_off, recv_sz, axis_name=ep_axis - ) - - # ------------------------------------------------------------------ - # Step 4 (EP only): local permute -- (source_shard, expert) -> - # (expert, shard). Inlined ``local_permute_after_a2a`` so we control - # both the row_id_map and its inverse for the bwd. - # ------------------------------------------------------------------ - num_experts_local = num_experts // num_ep - local_expert_start = shard_id * num_experts_local - local_expert_columns = jax.lax.dynamic_slice( - all_shards_tokens_per_expert, - start_indices=(0, local_expert_start), - slice_sizes=(num_ep, num_experts_local), - ) - split_sizes = local_expert_columns.reshape(-1) # source-major - indices_matrix = jnp.arange(num_ep * num_experts_local, dtype=jnp.int32).reshape( - num_ep, num_experts_local - ) - sorted_chunk_indices = indices_matrix.T.reshape(-1) # source-major -> expert-major - num_chunks = num_ep * num_experts_local - # Build a SINGLE row_id_map. ``is_forward=True`` permutes - # source-major -> expert-major; ``is_forward=False`` is the exact - # inverse (this is exactly what ``_sort_chunks_by_index_bwd_rule`` - # uses on the saved residual). _MoEBlock builds two row_id_maps - # only because it calls ``sort_chunks_by_index`` twice -- once in - # ``local_permute_after_a2a`` and again in ``local_unpermute_before_a2a``; - # each of those wrappers calls ``make_chunk_sort_map`` internally. - # Here we share one map across (fwd permute, fwd inverse-permute, - # bwd permute, bwd inverse-permute). - local_perm_row_id_map = make_chunk_sort_map( - split_sizes, sorted_chunk_indices, recv_buffer_rows, num_chunks - ) - sorted_x, _ = sort_chunks_by_map( - x_recv, local_perm_row_id_map, None, recv_buffer_rows, hidden, is_forward=True - ) - local_group_sizes = jnp.sum(local_expert_columns, axis=0) - - # NOTE: pre_a2a_buffer_shape and post_a2a_buffer_shape are compile- - # time int tuples; intentionally NOT stored in the returned state - # (would be coerced to JitTracer 0-d arrays under the EP shard_map's - # pytree flatten). Recompute via ``_compute_static_shape_info`` in - # the bwd call sites that need them. For EP, ``group_sizes`` here is - # the per-local-expert count (the FFN runs over E_local groups, not - # E). The global ``group_sizes`` lives inside - # ``all_shards_tokens_per_expert`` if anyone needs it for - # diagnostics. - return sorted_x, _build_state( - local_group_sizes, - ep_all=all_shards_tokens_per_expert, - ep_local=local_perm_row_id_map, - ) +# ============================================================================= +# Residual container threaded fwd -> bwd +# ============================================================================= -def _combine( - expert_outputs: jnp.ndarray, - state: _DispatchState, - *, - backend: PermutationBackend, - ep_active: bool, - batch_size: int, - sequence_length: int, - dtype: jnp.dtype, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # Computed by _compute_static_shape_info in the caller, passed here - # rather than stored in ``state`` to survive shard_map crossings. - num_real_tokens: int, - padding_size: int, - pre_a2a_buffer_shape: Tuple[int, int], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Inverse of :func:`_dispatch`. - - Returns ``(output, expert_outputs_post_ep)``. ``output`` is the - ``[B, S, H]`` combined activations. ``expert_outputs_post_ep`` is - the FFN-output tensor in the shape that Step 3 of the combine - actually consumed (i.e. after the reverse ragged_all_to_all on EP - runs, or the original input on non-EP). The caller stashes this as - the bwd residual so that ``_combine_bwd``'s Step-3 inverse sees - the same tensor the forward Step 3 used. - """ - if ep_active: - # Step 1 (EP): inverse local permute. Reuse the SAME row_id_map - # built in _dispatch by setting is_forward=False (this is the - # exact inverse, identical to what - # ``_sort_chunks_by_index_bwd_rule`` does with the saved residual). - recv_buffer_rows, hidden = expert_outputs.shape - x_send_back, _ = sort_chunks_by_map( - expert_outputs, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=False, - ) - # Step 2 (EP): reverse ragged_all_to_all. - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - send_back_buf = jnp.zeros(pre_a2a_buffer_shape, dtype=expert_outputs.dtype) - expert_outputs = jax.lax.ragged_all_to_all( - x_send_back, - send_back_buf, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) +# Registered as a pytree so jax.custom_vjp can flatten/unflatten it across +# the fwd -> bwd boundary. ``cfg`` is the only static field (EpLayerConfig +# is a frozen dataclass of ints); the rest are jnp.ndarray, +# GroupedNoScaleTensor (already a pytree), or None when aux_loss_coeff == 0. +@register_pytree_node_class +@dataclass +class _Ctx: + """Residuals carried from the fwd rule into the bwd rule.""" + + x: jnp.ndarray + gate_kernel: jnp.ndarray + expert_bias: jnp.ndarray + logits_2d: jnp.ndarray + saved_scores: jnp.ndarray + routing_map: jnp.ndarray + cfg: Any + handle_mem: Any + token_counts: jnp.ndarray + recv_topk_weights: jnp.ndarray + casted_sorted_x_lhs_trans: Any + casted_wi_rhs_trans: Any + gate_proj_out: jnp.ndarray + up_proj_out: jnp.ndarray + casted_intermediate_lhs_trans: Any + casted_wo_rhs_trans: Any + expert_outputs: jnp.ndarray + local_group_sizes: jnp.ndarray + # Aux-loss residuals; None when aux_loss_coeff == 0. + aux_const_buf: Any = None + aux_tokens_per_expert: Any = None + aux_saved_scores: Any = None - # Step 3: global combine. ``expert_outputs`` here is the post-A2A - # tensor under EP, or the original input under non-EP -- whichever - # value Step 3 actually consumes. Returned as the second tuple - # element so the caller can stash it as the bwd residual. - if backend is PermutationBackend.PURE_JAX: - # Reuse the reference pure-jax implementation; it has no - # custom_vjp on its outer surface so we can call it freely. - perm_state = PureJaxPermState( - sorted_indices=state.sorted_indices, - num_real_tokens=num_real_tokens, - padding_size=padding_size, - ) - output = pure_jax_token_combine( - expert_outputs, - perm_state, - state.routing_weights, - num_experts_per_tok=num_experts_per_tok, - batch_size=batch_size, - sequence_length=sequence_length, - ) - return output, expert_outputs - # TRITON - num_tokens = state.row_id_map.shape[0] - num_experts = (state.row_id_map.shape[1] - 1) // 2 - hidden = expert_outputs.shape[-1] - if state.pad_offsets is not None: - out_2d, _ = unpermute_with_mask_map_and_unpad( - expert_outputs, - state.row_id_map, - state.merging_probs, - None, - state.pad_offsets, - num_tokens, - num_experts, - hidden, + def tree_flatten(self): + children = ( + self.x, + self.gate_kernel, + self.expert_bias, + self.logits_2d, + self.saved_scores, + self.routing_map, + self.handle_mem, + self.token_counts, + self.recv_topk_weights, + self.casted_sorted_x_lhs_trans, + self.casted_wi_rhs_trans, + self.gate_proj_out, + self.up_proj_out, + self.casted_intermediate_lhs_trans, + self.casted_wo_rhs_trans, + self.expert_outputs, + self.local_group_sizes, + self.aux_const_buf, + self.aux_tokens_per_expert, + self.aux_saved_scores, ) - else: - out_2d, _ = unpermute_with_mask_map( + aux_data = (self.cfg,) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + (cfg,) = aux_data + ( + x, + gate_kernel, + expert_bias, + logits_2d, + saved_scores, + routing_map, + handle_mem, + token_counts, + recv_topk_weights, + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, expert_outputs, - state.row_id_map, - state.merging_probs, - None, - num_tokens, - num_experts, - hidden, - ) - return out_2d.reshape(batch_size, sequence_length, hidden).astype(dtype), expert_outputs - - -def _combine_bwd( # pylint: disable=unused-argument - d_output: jnp.ndarray, - state: _DispatchState, - expert_outputs: jnp.ndarray, - *, - backend: PermutationBackend, - ep_active: bool, - batch_size: int, - sequence_length: int, - dtype: jnp.dtype, - num_experts: int, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # See ``_compute_static_shape_info`` and the note in ``_dispatch`` - # for why these are kwargs rather than state-dict entries. - num_real_tokens: int, - padding_size: int, - post_a2a_buffer_shape: Optional[Tuple[int, int]], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Inverse of :func:`_combine` on the cotangent. - - Returns ``(d_expert_outputs, d_routing_weights_or_merging_probs)``. - - ``expert_outputs`` is the *forward* output of the FFN (same value the - fwd handed to :func:`_combine`). It's required by the TRITON - combine_bwd kernel; for PURE_JAX we don't need it but accept it for - a symmetric signature. - """ - # Step 3 inverse: global combine bwd. - d_output_2d = d_output.reshape(-1, d_output.shape[-1]) - if backend is PermutationBackend.PURE_JAX: - # The pure-jax combine is: - # unsort = _sort_activations(expert_outputs, argsort(sorted_indices)) - # if pad: unsort = unsort[:num_real] - # reshape -> einsum BKE,BK -> BE -> reshape to BSE - # Hand-derive the bwd in plain JAX (no custom_vjp involved): - unsort_indices = jnp.argsort(state.sorted_indices) - topk = num_experts_per_tok - num_real = num_real_tokens - padding = padding_size - # Recover the unsorted intermediate that the fwd produced (we - # need it for the d_routing_weights pullback). Apply the same - # gather the fwd did. - unsort_intermediate = expert_outputs[unsort_indices] - if padding > 0: - unsort_intermediate = unsort_intermediate[:num_real] - # Bwd of einsum/reshape: - # output[B, E] = sum_K intermediate[B, K, E] * weights[B, K] - # d_intermediate[B, K, E] = d_output[B, E] * weights[B, K] - # d_weights[B, K] = sum_E d_output[B, E] * intermediate[B, K, E] - rw = state.routing_weights.reshape(-1, topk) - intermediate_3d = unsort_intermediate.reshape(rw.shape[0], topk, -1) - rw_cast = rw.astype(intermediate_3d.dtype) - d_intermediate_3d = jnp.einsum("BE,BK -> BKE", d_output_2d, rw_cast) - d_routing_weights = jnp.einsum("BE,BKE -> BK", d_output_2d, intermediate_3d).astype( - state.routing_weights.dtype - ) - d_routing_weights = d_routing_weights.reshape(state.routing_weights.shape) - d_unsort_intermediate = d_intermediate_3d.reshape(num_real, -1) - # Pad back with zeros if the fwd stripped padding. - if padding > 0: - d_unsort_intermediate = jnp.concatenate( - [ - d_unsort_intermediate, - jnp.zeros( - (padding, d_unsort_intermediate.shape[-1]), - dtype=d_unsort_intermediate.dtype, - ), - ], - axis=0, - ) - # Bwd of the gather is gather-by-original-indices: - # sorted = unsort[argsort(sorted_indices)] - # d_sorted = scatter d_unsort via argsort(sorted_indices) - # = d_unsort[sorted_indices] (gather by original sorted_indices, - # which is the inverse of argsort(sorted_indices)). - d_expert_outputs_global = d_unsort_intermediate[state.sorted_indices] - else: - # TRITON combine bwd: requires fwd_input (expert_outputs). - num_tokens = state.row_id_map.shape[0] - n_experts = (state.row_id_map.shape[1] - 1) // 2 - hidden = d_output_2d.shape[-1] - num_out_tokens = expert_outputs.shape[0] - if state.pad_offsets is not None: - d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs_and_unpad( - d_output_2d, - state.row_id_map, - expert_outputs, - state.merging_probs, - state.pad_offsets, - num_tokens, - n_experts, - num_out_tokens, - hidden, - ) - # The kernel only writes positions tokens map to; padded - # positions may contain NaN. Replace with zeros (matches - # ``_token_combine_bwd_rule``). - d_expert_outputs_global = jnp.where( - jnp.isnan(d_expert_outputs_global), 0.0, d_expert_outputs_global - ) - else: - d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs( - d_output_2d, - state.row_id_map, - expert_outputs, - state.merging_probs, - num_tokens, - n_experts, - num_out_tokens, - hidden, - ) - d_routing_weights = d_merging_probs - - if not ep_active: - return d_expert_outputs_global, d_routing_weights - - # Step 2 (EP) inverse: bwd of reverse ragged_all_to_all is a forward - # ragged_all_to_all using the SAME forward parameters (sender / - # receiver roles swap from the reverse direction back to forward). - in_off_f, send_sz_f, out_off_f, recv_sz_f = compute_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - recv_buf_for_bwd = jnp.zeros(post_a2a_buffer_shape, dtype=d_expert_outputs_global.dtype) - d_x_send_back = jax.lax.ragged_all_to_all( - d_expert_outputs_global, - recv_buf_for_bwd, - in_off_f, - send_sz_f, - out_off_f, - recv_sz_f, - axis_name=ep_axis, - ) - # Step 1 (EP) inverse: combine fwd applied is_forward=False; the - # bwd is is_forward=True with the SAME row_id_map. - recv_buffer_rows, hidden = d_x_send_back.shape - d_expert_outputs, _ = sort_chunks_by_map( - d_x_send_back, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=True, - ) - return d_expert_outputs, d_routing_weights - - -def _dispatch_bwd( - d_sorted_x: jnp.ndarray, - state: _DispatchState, - inputs_2d_shape: Tuple[int, ...], - *, - backend: PermutationBackend, - ep_active: bool, - num_experts: int, - num_experts_per_tok: int, - # Per-shard compile-time-constant shape info (Python ints / int tuples). - # See ``_compute_static_shape_info`` and the note in ``_dispatch`` - # for why these are kwargs rather than state-dict entries. - num_real_tokens: int, - padding_size: int, - pre_a2a_buffer_shape: Tuple[int, int], - # EP-only: - ep_axis: Optional[str], - shard_id: Optional[jnp.ndarray] = None, - num_ep: int = 1, -) -> jnp.ndarray: - """Inverse of :func:`_dispatch` on the cotangent. Returns ``d_inputs_2d``. - - The probs path through dispatch is always discarded (PURE_JAX never - threads probs through dispatch; TRITON technically does but the - caller drops ``permuted_probs``, so its cotangent is structurally - zero). The probs gradient instead flows back through - :func:`_combine_bwd`. - """ - if ep_active: - # Step 4 inverse: dispatch fwd applied is_forward=True; bwd is - # is_forward=False with the SAME row_id_map. - recv_buffer_rows, hidden = d_sorted_x.shape - d_x_recv, _ = sort_chunks_by_map( - d_sorted_x, - state.local_perm_row_id_map, - None, - recv_buffer_rows, - hidden, - is_forward=False, - ) - # Step 3 inverse: bwd of forward ragged_a2a is the reverse-direction - # ragged_a2a using the SAME params with sender/receiver swapped. - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - state.all_shards_tokens_per_expert, shard_id, num_ep - ) - recv_buf_pre = jnp.zeros(pre_a2a_buffer_shape, dtype=d_x_recv.dtype) - d_sorted_x = jax.lax.ragged_all_to_all( - d_x_recv, - recv_buf_pre, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) - - # Step 1 inverse: global permute bwd. - if backend is PermutationBackend.PURE_JAX: - # Fwd was: replicated = repeat(inputs_2d, topk, axis=0) - # padded = pad(replicated, (0, padding_size)) - # sorted = padded[sorted_indices] - # Bwd: d_padded = scatter via sorted_indices - # = d_sorted[argsort(sorted_indices)] - # d_replicated = d_padded[:num_real] - # d_inputs_2d = d_replicated.reshape(T, topk, H).sum(axis=1) - sorted_indices = state.sorted_indices - num_real = num_real_tokens - padding = padding_size - topk = num_experts_per_tok - unsort_indices = jnp.argsort(sorted_indices) - d_padded = d_sorted_x[unsort_indices] - if padding > 0: - d_replicated = d_padded[:num_real] - else: - d_replicated = d_padded - num_tokens = inputs_2d_shape[0] - hidden = inputs_2d_shape[-1] - d_inputs_2d = d_replicated.reshape(num_tokens, topk, hidden).sum(axis=1) - return d_inputs_2d - - # TRITON: bwd is unpermute_with_mask_map[_and_unpad]. - num_tokens = inputs_2d_shape[0] - hidden = inputs_2d_shape[-1] - if state.pad_offsets is not None: - d_inputs_2d, _ = unpermute_with_mask_map_and_unpad( - d_sorted_x, - state.row_id_map, - None, - None, - state.pad_offsets, - num_tokens, - num_experts, - hidden, - ) - else: - d_inputs_2d, _ = unpermute_with_mask_map( - d_sorted_x, - state.row_id_map, - None, - None, - num_tokens, - num_experts, - hidden, + local_group_sizes, + aux_const_buf, + aux_tokens_per_expert, + aux_saved_scores, + ) = children + return cls( + x=x, + gate_kernel=gate_kernel, + expert_bias=expert_bias, + logits_2d=logits_2d, + saved_scores=saved_scores, + routing_map=routing_map, + cfg=cfg, + handle_mem=handle_mem, + token_counts=token_counts, + recv_topk_weights=recv_topk_weights, + casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, + casted_wi_rhs_trans=casted_wi_rhs_trans, + gate_proj_out=gate_proj_out, + up_proj_out=up_proj_out, + casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, + casted_wo_rhs_trans=casted_wo_rhs_trans, + expert_outputs=expert_outputs, + local_group_sizes=local_group_sizes, + aux_const_buf=aux_const_buf, + aux_tokens_per_expert=aux_tokens_per_expert, + aux_saved_scores=aux_saved_scores, ) - return d_inputs_2d # ============================================================================= -# Per-shard body +# Per-shard FFN body (runs inside shard_map) # ============================================================================= -def _body_fwd( # pylint: disable=unused-argument - captured: dict, +def _ffn_fwd_per_shard( + recv_tokens_local: jnp.ndarray, + recv_topk_weights_local: jnp.ndarray, + token_counts_local: jnp.ndarray, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray], + wi_1_bias: Optional[jnp.ndarray], + wo_bias: Optional[jnp.ndarray], *, - # Statics - num_experts: int, - num_experts_per_tok: int, + num_local_experts: int, + slots_per_expert: int, activation_type: str, - score_function: ScoreFunction, - use_pre_softmax: bool, - num_groups: Optional[int], - group_topk: Optional[int], - scaling_factor: float, - aux_loss_coeff: float, - permutation_backend: PermutationBackend, - align_size: int, - gate_inside_vjp: bool, - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], - dtype: jnp.dtype, - # EP-only statics - ep_active: bool, - ep_axis: Optional[str], - data_parallelism_axes: Tuple[str, ...], - fsdp_sizes: Tuple[int, ...], - num_ep: int, - num_experts_local: int, - recv_buffer_rows: int, -) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: - """Per-shard forward body. Returns ``(output, aux_loss, ctx_dict)``. - - ``aux_loss`` is always materialized (zeros scalar when disabled) so - the ``shard_map``'s ``out_specs`` has a static structure. + apply_topk_weights_early: bool, +): + """Per-shard FFN forward. + + Operates on the shard-local ``[1, recv_pr, H]`` slice that + ``tex.ep_dispatch`` produces. Returns the expert outputs (shaped + ``[1, recv_pr, H_out]`` so the surrounding ``shard_map`` reassembles + them as ``[num_procs, recv_pr, H_out]``) plus the residuals consumed + by the bwd. + + ``token_counts_local`` (``[1, num_local_experts]``, from + ``tex.ep_prepare``) is passed to ``grouped_gemm`` as ``group_sizes`` + so cuBLAS skips both 0-token-routed experts and the dispatch + overalloc tail. """ - if not gate_inside_vjp: - raise NotImplementedError( - "gate_inside_vjp=False is deferred to a follow-up PR; for now" - " the gate GEMM lives inside the MoE VJP." - ) - - x = captured["inputs"] - gate_kernel = captured["gate_kernel"] - wi_0 = captured["wi_0"] - wi_1 = captured["wi_1"] - wo = captured["wo"] - wi_0_bias = captured.get("wi_0_bias") - wi_1_bias = captured.get("wi_1_bias") - wo_bias = captured.get("wo_bias") - expert_bias = captured.get("expert_bias") - - batch_size, sequence_length, hidden = x.shape - - # ---------------- Stage 1: gate ---------------- - gate_kernel_cast = gate_kernel.astype(x.dtype) - gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) - logits_2d = gate_logits.reshape(-1, num_experts) - inputs_2d = x.reshape(-1, hidden) - - # ---------------- Stage 2: routing ---------------- - # Under EP, expert_bias is sharded P(ep_axis); the router needs the - # full E-dim view, so all_gather it. - if ep_active and expert_bias is not None: - full_expert_bias = jax.lax.all_gather(expert_bias, axis_name=ep_axis, tiled=True) - else: - full_expert_bias = expert_bias - # Pass an empty array sentinel when expert_bias is unused (the - # underlying primitive expects a real ndarray, not None). - eb_arg = ( - full_expert_bias if full_expert_bias is not None else jnp.zeros((0,), dtype=jnp.float32) - ) - sparse_probs, routing_map, saved_scores = tex.fused_topk_with_score_function_fwd( - logits_2d, - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - num_groups=-1 if num_groups is None else num_groups, - group_topk=-1 if group_topk is None else group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - expert_bias=eb_arg, - compute_aux_scores=False, - ) - sparse_probs = sparse_probs.astype(dtype) - - # ---------------- Stage 2b: aux loss ---------------- - if aux_loss_coeff > 0.0: - if ep_active: - collective_axes: Any = ( - ep_axis if not data_parallelism_axes else (ep_axis, *data_parallelism_axes) - ) - global_logits_2d = jax.lax.all_gather( - logits_2d, axis_name=collective_axes, axis=0, tiled=True - ) - _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( - global_logits_2d, - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - num_groups=-1 if num_groups is None else num_groups, - group_topk=-1 if group_topk is None else group_topk, - scaling_factor=scaling_factor, - score_function=score_function, - expert_bias=eb_arg, - compute_aux_scores=False, - ) - aux_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) - aux_logits_for_score = global_logits_2d - else: - aux_tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0) - aux_logits_for_score = logits_2d - # Aux-side scores: clean per-expert scores (no grouped routing, - # no bias). compute_aux_scores=True takes a separate path that - # ignores the grouping knobs. - aux_probs, _aux_routing_map, aux_saved_scores = tex.fused_topk_with_score_function_fwd( - aux_logits_for_score.astype(jnp.float32), - topk=num_experts_per_tok, - use_pre_softmax=False, - num_groups=-1, - group_topk=-1, - scaling_factor=1.0, - score_function=score_function, - expert_bias=jnp.zeros((0,), dtype=jnp.float32), - compute_aux_scores=True, - ) - aux_loss, aux_const_buf = tex.fused_moe_aux_loss_fwd( - aux_probs.astype(jnp.float32), - aux_tokens_per_expert.astype(jnp.int32), - topk=num_experts_per_tok, - coeff=aux_loss_coeff, - ) - else: - aux_loss = jnp.zeros((), dtype=dtype) - aux_const_buf = None - aux_tokens_per_expert = None - aux_logits_for_score = None - aux_saved_scores = None - - # ---------------- Stage 3: dispatch ---------------- - shard_id = jax.lax.axis_index(ep_axis) if ep_active else None - sorted_x, dispatch_state = _dispatch( - inputs_2d, - sparse_probs, - routing_map, - backend=permutation_backend, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - ep_axis=ep_axis, - num_ep=num_ep, - recv_buffer_rows=recv_buffer_rows, - shard_id=shard_id, - ) - local_group_sizes = dispatch_state.group_sizes - - # ---------------- Stage 4: per-expert FFN (inlined) ---------------- - q_set_w0, q_set_w1, q_set_wo = quantizer_sets - if q_set_w0 == noop_quantizer_set: - wi_0 = wi_0.astype(sorted_x.dtype) - if q_set_w1 == noop_quantizer_set: - wi_1 = wi_1.astype(sorted_x.dtype) - if q_set_wo == noop_quantizer_set: - wo = wo.astype(sorted_x.dtype) - - # GEMM 1+2 (fused): up_proj_combined = sorted_x @ wi where - # wi := concat([wi_0, wi_1], axis=-1) -> shape [E, H, 2M] - # combined_out := sorted_x @ wi -> shape [T, 2M] - # Splitting the output back into ``gate_proj_out`` / ``up_proj_out`` - # is free (it's a slicing reshape). This collapses two grouped - # GEMMs and two grouped quantizes of ``sorted_x`` (one per kernel) - # into one of each. Bias is concatenated the same way. - # - # FP8/MXFP8 caveat: per-expert amax is now computed over [H, 2M] - # rather than [H, M] for each of wi_0 / wi_1 separately, so the - # representable range for one of the two halves may shift slightly - # vs. the pre-fusion code. Numerics tests cover this. - inter_M = wi_0.shape[-1] + hidden = recv_tokens_local.shape[-1] + sorted_x = recv_tokens_local.reshape(-1, hidden) + recv_w_flat = recv_topk_weights_local.reshape(-1) + local_group_sizes = token_counts_local.reshape(-1).astype(jnp.int32) + del slots_per_expert # not used since group_sizes is plumbed in dynamically + + wi_0 = wi_0.astype(sorted_x.dtype) + wi_1 = wi_1.astype(sorted_x.dtype) + wo = wo.astype(sorted_x.dtype) + + # Concat wi_0/wi_1 along the trailing axis (NOT stack on a new + # axis). grouped_gemm requires the 3D (G, K, N) weight layout with + # contracting_dims=((1,), (1,)); a 4D stack variant walks off the + # end of the RHS and returns NaN. wi_combined = jnp.concatenate([wi_0, wi_1], axis=-1) wi_combined_bias = ( jnp.concatenate([wi_0_bias, wi_1_bias], axis=-1) if wi_0_bias is not None else None ) - casted_sorted_x = tex.grouped_quantize(sorted_x, q_set_w0.x, local_group_sizes, flatten_axis=-1) - casted_wi = tex.grouped_quantize(wi_combined, q_set_w0.kernel, flatten_axis=-1) + + q_set = noop_quantizer_set + casted_sorted_x = tex.grouped_quantize(sorted_x, q_set.x, local_group_sizes, flatten_axis=-1) + casted_wi = tex.grouped_quantize(wi_combined, q_set.kernel, flatten_axis=-1) combined_out = tex.grouped_gemm( casted_sorted_x.get_tensor(usage=TensorUsage.LHS), casted_wi.get_tensor(usage=TensorUsage.RHS), contracting_dims=((1,), (1,)), bias=wi_combined_bias, ) - gate_proj_out = combined_out[..., :inter_M] - up_proj_out = combined_out[..., inter_M:] + gate_proj_out, up_proj_out = jnp.split(combined_out, 2, axis=-1) casted_sorted_x_lhs_trans = casted_sorted_x.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wi_rhs_trans = casted_wi.get_tensor(usage=TensorUsage.RHS_TRANS) - if isinstance(casted_sorted_x_lhs_trans, ScaledTensor): - casted_sorted_x_lhs_trans = casted_sorted_x_lhs_trans.checkpoint(q_set_w0.x) - if isinstance(casted_wi_rhs_trans, ScaledTensor): - casted_wi_rhs_trans = casted_wi_rhs_trans.checkpoint(q_set_w0.kernel) - # Activation: intermediate = act(gate_proj_out) * up_proj_out + # Activation inputs (gate_proj_out, up_proj_out) stay in the wi GEMM + # output dtype; the activation output (`intermediate`) stays in the + # dtype the wo GEMM / wo's quantized input consumes. For bf16 compute + # that's all bf16; for FP8/FP4 the downstream grouped_quantize is what + # transitions to the target precision. act_fn = _convert_to_activation_function(activation_type) intermediate = act_fn(gate_proj_out) * up_proj_out - # GEMM 3: expert_outputs = intermediate @ wo + if apply_topk_weights_early: + # Fold the per-token combine weights into the FFN intermediate; + # the downstream wo GEMM is linear so this is equivalent to the + # late-weighting path, modulo elementwise op fusion gains. w_b is + # cast to intermediate.dtype so the multiply doesn't promote + # expert_outputs above the EP buffer's element width + # (ep_bootstrap rejects max_token_dtype != bf16, and the NCCL EP + # HT mega-buffer is sized for 2-byte slots accordingly). + w_b = recv_w_flat[:, None].astype(intermediate.dtype) + mask_b = (recv_w_flat != 0).astype(intermediate.dtype)[:, None] + intermediate = intermediate * w_b * mask_b + casted_intermediate = tex.grouped_quantize( - intermediate, q_set_wo.x, local_group_sizes, flatten_axis=-1 + intermediate, q_set.x, local_group_sizes, flatten_axis=-1 ) - casted_wo = tex.grouped_quantize(wo, q_set_wo.kernel, flatten_axis=-1) + casted_wo = tex.grouped_quantize(wo, q_set.kernel, flatten_axis=-1) expert_outputs = tex.grouped_gemm( casted_intermediate.get_tensor(usage=TensorUsage.LHS), casted_wo.get_tensor(usage=TensorUsage.RHS), @@ -1151,524 +386,143 @@ def _body_fwd( # pylint: disable=unused-argument ) casted_intermediate_lhs_trans = casted_intermediate.get_tensor(usage=TensorUsage.LHS_TRANS) casted_wo_rhs_trans = casted_wo.get_tensor(usage=TensorUsage.RHS_TRANS) - if isinstance(casted_intermediate_lhs_trans, ScaledTensor): - casted_intermediate_lhs_trans = casted_intermediate_lhs_trans.checkpoint(q_set_wo.x) - if isinstance(casted_wo_rhs_trans, ScaledTensor): - casted_wo_rhs_trans = casted_wo_rhs_trans.checkpoint(q_set_wo.kernel) - - # ---------------- Stage 5: combine ---------------- - # Compute per-shard static shape info once and pass through both - # _combine and (later) the bwd helpers via kwargs -- never via the - # state dict, which gets pytree-flattened across shard_map and would - # coerce Python ints into JitTracer 0-d arrays. - _static_shape = _compute_static_shape_info( - batch_size=batch_size, - sequence_length=sequence_length, - hidden=hidden, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - num_ep=num_ep, - fsdp_sizes=fsdp_sizes, - recv_buffer_rows=recv_buffer_rows, - ) - # ``expert_outputs_residual`` is the post-A2A FFN-output tensor that - # Step 3 of the combine actually consumed. Saving this (rather than - # the pre-A2A shard-local FFN output) is what makes - # ``_combine_bwd``'s Step-3 inverse see the same value the forward - # Step 3 saw -- otherwise EP + TRITON yields wrong d_expert_outputs. - output, expert_outputs_residual = _combine( - expert_outputs, - dispatch_state, - backend=permutation_backend, - ep_active=ep_active, - batch_size=batch_size, - sequence_length=sequence_length, - dtype=dtype, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - pre_a2a_buffer_shape=_static_shape.pre_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) - # ---------------- Build ctx ---------------- - aux_enabled = aux_loss_coeff > 0.0 - ctx = _BodyCtx( - x=x, - gate_kernel=gate_kernel, - logits_2d=logits_2d, - saved_scores=saved_scores, - routing_map=routing_map, - dispatch=dispatch_state, - casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, - casted_wi_rhs_trans=casted_wi_rhs_trans, - gate_proj_out=gate_proj_out, - up_proj_out=up_proj_out, - casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, - casted_wo_rhs_trans=casted_wo_rhs_trans, - expert_outputs=expert_outputs_residual, - local_group_sizes=local_group_sizes, - expert_bias=expert_bias if expert_bias is not None else None, - aux_const_buf=aux_const_buf if aux_enabled else None, - aux_tokens_per_expert=aux_tokens_per_expert if aux_enabled else None, - aux_logits_for_score=aux_logits_for_score if aux_enabled else None, - aux_saved_scores=aux_saved_scores if aux_enabled else None, + expert_outputs_3d = expert_outputs.reshape(1, expert_outputs.shape[0], expert_outputs.shape[1]) + # Reshape local_group_sizes to (1, num_local_experts) so the + # surrounding shard_map can stitch per-shard counts back into the + # global (num_procs, num_local_experts) layout matching token_counts. + local_group_sizes_3d = local_group_sizes.reshape(1, num_local_experts) + residuals = ( + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes_3d, ) - - return output, aux_loss, ctx - - -def _body_bwd( # pylint: disable=unused-argument - ctx: _BodyCtx, - dy_pair: Tuple[jnp.ndarray, jnp.ndarray], + return expert_outputs_3d, residuals + + +def _ffn_bwd_per_shard( + d_expert_outputs_local: jnp.ndarray, + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out: jnp.ndarray, + up_proj_out: jnp.ndarray, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes: jnp.ndarray, + recv_topk_weights_local: jnp.ndarray, *, - num_experts: int, - num_experts_per_tok: int, activation_type: str, - score_function: ScoreFunction, - use_pre_softmax: bool, - num_groups: Optional[int], - group_topk: Optional[int], - scaling_factor: float, - aux_loss_coeff: float, - permutation_backend: PermutationBackend, - align_size: int, - gate_inside_vjp: bool, - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], - dtype: jnp.dtype, - ep_active: bool, - ep_axis: Optional[str], - data_parallelism_axes: Tuple[str, ...], - fsdp_sizes: Tuple[int, ...], - num_ep: int, - num_experts_local: int, - recv_buffer_rows: int, - # Static side info (kept here rather than inside ctx because they're - # python flags / shapes, not array leaves): - has_wi_bias: bool, - has_wo_bias: bool, - has_expert_bias: bool, - x_shape: Tuple[int, ...], -) -> dict: - """Per-shard backward body. Returns a dict of grads keyed identically - to the ``captured`` dict consumed by :func:`_body_fwd`.""" - if not gate_inside_vjp: - raise NotImplementedError("gate_inside_vjp=False is deferred to a follow-up PR.") - - d_output, d_aux_loss = dy_pair - # The fused FFN bwd quantizes via ``q_set_w0`` only (one quantize for - # the [E, H, 2M] fused wi tensor and one for the [T, 2M] fused dgrad), - # so ``q_set_w1`` is intentionally unused here. - q_set_w0, _q_set_w1, q_set_wo = quantizer_sets - batch_size, sequence_length, hidden = x_shape - shard_id = jax.lax.axis_index(ep_axis) if ep_active else None - - # Recompute per-shard static shape info from existing statics - # (Python ints / int tuples). Plumbed via kwargs to _combine_bwd - # and _dispatch_bwd -- NOT through the ctx dict, because the - # dict gets pytree-flattened across the bwd shard_map's in_specs - # and Python ints would be coerced into JitTracer 0-d arrays - # (breaking ``if padding > 0`` and ``jnp.zeros(shape)`` callsites). - # ``batch_size`` here is the GLOBAL batch size (captured in - # ``x_shape`` by the outer fwd rule), hence ``batch_is_per_shard=False``. - _static_shape = _compute_static_shape_info( - batch_size=batch_size, - sequence_length=sequence_length, - hidden=hidden, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - align_size=align_size, - ep_active=ep_active, - num_ep=num_ep, - fsdp_sizes=fsdp_sizes, - recv_buffer_rows=recv_buffer_rows, - batch_is_per_shard=False, - ) - - # Compute per-shard input shape: under the EP shard_map body, the - # gradient tensors live at per-shard shape, so the dispatch_bwd - # reshape target and ``d_x_from_dispatch.reshape(x_shape)`` below - # must use the per-shard shape rather than the captured global - # ``x_shape``. - if ep_active: - dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 - per_shard_batch = batch_size // (num_ep * dp_size) - per_shard_x_shape: Tuple[int, ...] = (per_shard_batch, sequence_length, hidden) - else: - per_shard_x_shape = x_shape - - # ---------------- Combine bwd ---------------- - d_expert_outputs, d_routing_weights = _combine_bwd( - d_output, - ctx.dispatch, - ctx.expert_outputs, - backend=permutation_backend, - ep_active=ep_active, - batch_size=batch_size, - sequence_length=sequence_length, - dtype=dtype, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - post_a2a_buffer_shape=_static_shape.post_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) + apply_topk_weights_early: bool, + has_bias: bool, +): + """Per-shard FFN backward. - # ---------------- FFN bwd: GEMM 3 (wo) ---------------- - casted_d_eo = tex.grouped_quantize( - d_expert_outputs, q_set_wo.dgrad, ctx.local_group_sizes, flatten_axis=-1 - ) + Mirrors :func:`_ffn_fwd_per_shard`. Returns + ``(d_sorted_x [1, recv_pr, H], d_recv_w [1, recv_pr], + d_wi_0, d_wi_1, d_wo, d_wi_0_bias, d_wi_1_bias, d_wo_bias)``. + """ + local_group_sizes = local_group_sizes.reshape(-1).astype(jnp.int32) + d_eo_2d = d_expert_outputs_local.reshape(-1, d_expert_outputs_local.shape[-1]) + recv_w_flat = recv_topk_weights_local.reshape(-1) + q_set = noop_quantizer_set + # cuBLAS grouped_gemm skips size_g == 0 groups without zero-filling + # the output slice; mask 0-token-expert wgrads to zero so the + # optimizer never sees uninit memory. + wgrad_group_active = (local_group_sizes > 0)[:, None, None] + + # wo bwd + casted_d_eo = tex.grouped_quantize(d_eo_2d, q_set.dgrad, local_group_sizes, flatten_axis=-1) + _casted_d_eo_lhs = casted_d_eo.get_tensor(usage=TensorUsage.LHS) + _casted_d_eo_rhs = casted_d_eo.get_tensor(usage=TensorUsage.RHS) d_intermediate = tex.grouped_gemm( - casted_d_eo.get_tensor(usage=TensorUsage.LHS), - ctx.casted_wo_rhs_trans, + _casted_d_eo_lhs, + casted_wo_rhs_trans, contracting_dims=((1,), (2,)), ) d_wo = tex.grouped_gemm( - ctx.casted_intermediate_lhs_trans, - casted_d_eo.get_tensor(usage=TensorUsage.RHS), + casted_intermediate_lhs_trans, + _casted_d_eo_rhs, contracting_dims=((0,), (0,)), ) - d_wo_bias = tex.grouped_dbias(d_expert_outputs, ctx.local_group_sizes) if has_wo_bias else None + d_wo = jnp.where(wgrad_group_active, d_wo, jnp.zeros_like(d_wo)) + d_wo_bias = tex.grouped_dbias(d_eo_2d, local_group_sizes) if has_bias else None - # ---------------- Activation bwd ---------------- - # intermediate = act(gate_proj_out) * up_proj_out - # d(gate_proj_out) = vjp(act, gate_proj_out)(d_intermediate * up_proj_out) - # d(up_proj_out) = d_intermediate * act(gate_proj_out) act_fn = _convert_to_activation_function(activation_type) - act_gate_proj_out, dact_gate_proj_pullback = jax.vjp(act_fn, ctx.gate_proj_out) - d_up_proj_out = d_intermediate * act_gate_proj_out - (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * ctx.up_proj_out) - - # ---------------- FFN bwd: GEMM 1+2 fused (wi_0 | wi_1) ---------------- - # Concat the two upstream grads along the output (M) axis, do one - # grouped quantize + one dgrad GEMM + one wgrad GEMM, then split. - # ``ctx.casted_wi_rhs_trans`` has shape [E, H, 2M] from the fwd - # fused quantize, so the dgrad math is: - # d_sorted_x = [d_gate | d_up] @ wi_rhs_trans - # = d_gate @ wi_0^T + d_up @ wi_1^T - inter_M = d_gate_proj_out.shape[-1] + if apply_topk_weights_early: + # intermediate' = intermediate * w * mask. Split the cotangent + # across both factors before the activation bwd consumes it. + # Cast w_b so the multiply stays in d_intermediate.dtype and + # d_sorted_x (downstream into ep_dispatch_bwd) stays bf16. + w_b = recv_w_flat[:, None].astype(d_intermediate.dtype) + mask_b = (recv_w_flat != 0).astype(d_intermediate.dtype)[:, None] + intermediate_unweighted = act_fn(gate_proj_out) * up_proj_out + d_recv_w_from_intermediate = jnp.sum( + d_intermediate * intermediate_unweighted * mask_b, axis=-1 + ).astype(recv_w_flat.dtype) + d_intermediate = d_intermediate * w_b * mask_b + else: + d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat) + + # Activation bwd, symmetric with the fwd: silu' and the two + # elementwise products run in the GEMM dtype (no fp32 island), so + # the chain rule composes through at the same precision the wi/wo + # GEMMs consume. + act_gp, dact_pullback = jax.vjp(act_fn, gate_proj_out) + d_up_proj_out = d_intermediate * act_gp + (d_gate_proj_out,) = dact_pullback(d_intermediate * up_proj_out) + + # wi bwd (fused gate/up via concat). Mirror the fused fwd: pack the + # gate/up cotangents along the trailing axis, run a single + # grouped_quantize + two grouped_gemm pair (one dgrad, one wgrad) + # against the fused casted_wi_rhs_trans residual, then split the + # wgrad result back into d_wi_0 / d_wi_1 halves with jnp.split. d_combined = jnp.concatenate([d_gate_proj_out, d_up_proj_out], axis=-1) casted_d_combined = tex.grouped_quantize( - d_combined, q_set_w0.dgrad, ctx.local_group_sizes, flatten_axis=-1 + d_combined, q_set.dgrad, local_group_sizes, flatten_axis=-1 ) d_sorted_x = tex.grouped_gemm( casted_d_combined.get_tensor(usage=TensorUsage.LHS), - ctx.casted_wi_rhs_trans, + casted_wi_rhs_trans, contracting_dims=((1,), (2,)), ) d_wi_combined = tex.grouped_gemm( - ctx.casted_sorted_x_lhs_trans, + casted_sorted_x_lhs_trans, casted_d_combined.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_0 = d_wi_combined[..., :inter_M] - d_wi_1 = d_wi_combined[..., inter_M:] - if has_wi_bias: - d_wi_combined_bias = tex.grouped_dbias(d_combined, ctx.local_group_sizes) - d_wi_0_bias = d_wi_combined_bias[..., :inter_M] - d_wi_1_bias = d_wi_combined_bias[..., inter_M:] + d_wi_combined = jnp.where(wgrad_group_active, d_wi_combined, jnp.zeros_like(d_wi_combined)) + d_wi_0, d_wi_1 = jnp.split(d_wi_combined, 2, axis=-1) + if has_bias: + d_wi_combined_bias = tex.grouped_dbias(d_combined, local_group_sizes) + d_wi_0_bias, d_wi_1_bias = jnp.split(d_wi_combined_bias, 2, axis=-1) else: d_wi_0_bias = None d_wi_1_bias = None - # ---------------- Dispatch bwd ---------------- - inputs_2d_shape = (per_shard_x_shape[0] * per_shard_x_shape[1], hidden) - d_inputs_2d = _dispatch_bwd( - d_sorted_x, - ctx.dispatch, - inputs_2d_shape=inputs_2d_shape, - backend=permutation_backend, - ep_active=ep_active, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - num_real_tokens=_static_shape.num_real_tokens, - padding_size=_static_shape.padding_size, - pre_a2a_buffer_shape=_static_shape.pre_a2a_buffer_shape, - ep_axis=ep_axis, - shard_id=shard_id, - num_ep=num_ep, - ) - d_x_from_dispatch = d_inputs_2d.reshape(per_shard_x_shape) - - # ---------------- Routing bwd ---------------- - # The probs cotangent comes from _combine_bwd. For PURE_JAX it's the - # cotangent of routing_weights (post-routing_map_to_selected_experts); - # we need to bridge back to sparse_probs. For TRITON it's already the - # cotangent of merging_probs == sparse_probs. - if d_routing_weights is not None: - if permutation_backend is PermutationBackend.PURE_JAX: - # routing_map_to_selected_experts: - # selected_experts = argsort(routing_map)[..., -topk:] - # weights = take_along_axis(sparse_probs, selected_experts, axis=-1) - # routing_map is bool (non-diff); the gradient of weights - # w.r.t. sparse_probs is a scatter-into-zero along the - # selected_experts indices. - selected_experts = jnp.argsort(ctx.routing_map, axis=-1)[..., -num_experts_per_tok:] - d_sparse_probs = jnp.zeros_like(ctx.saved_scores).astype(d_routing_weights.dtype) - d_sparse_probs = jnp.take_along_axis(d_sparse_probs, selected_experts, axis=-1) - # Actually scatter: build via jnp.zeros + .at[].set - d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=d_routing_weights.dtype) - d_sparse_probs = d_sparse_probs.at[ - jnp.arange(ctx.routing_map.shape[0])[:, None], selected_experts - ].set(d_routing_weights) - else: - d_sparse_probs = d_routing_weights.astype(jnp.float32) - else: - d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=jnp.float32) - - # Topk bwd primitive: returns d_logits (no d_expert_bias). - d_logits_2d_main = tex.fused_topk_with_score_function_bwd( - ctx.routing_map, - ctx.saved_scores, - d_sparse_probs.astype(ctx.saved_scores.dtype), - topk=num_experts_per_tok, - use_pre_softmax=use_pre_softmax, - scaling_factor=scaling_factor, - score_function=score_function, - compute_aux_scores=False, - ) - - # ---------------- Aux loss bwd ---------------- - if aux_loss_coeff > 0.0: - # Step 1: aux_loss bwd -> d_aux_probs - aux_num_tokens = ctx.aux_logits_for_score.shape[0] - d_aux_probs = tex.fused_moe_aux_loss_bwd( - ctx.aux_const_buf, - ctx.aux_tokens_per_expert.astype(jnp.int32), - d_aux_loss.reshape(()), - num_tokens=aux_num_tokens, - ) - # Step 2: aux-side topk bwd (compute_aux_scores=True path). - # The routing_map argument is ignored in this branch (the kernel - # uses saved_scores); pass any shape-correct integer tensor. - d_aux_logits = tex.fused_topk_with_score_function_bwd( - jnp.zeros(ctx.aux_logits_for_score.shape, dtype=jnp.bool_), - ctx.aux_saved_scores, - d_aux_probs.astype(ctx.aux_saved_scores.dtype), - topk=num_experts_per_tok, - use_pre_softmax=False, - scaling_factor=1.0, - score_function=score_function, - compute_aux_scores=True, - ) - # Step 3: under EP the aux logits were all_gathered along - # ``(ep_axis, *data_parallelism_axes)`` (the latter being FSDP - # axes that shard the batch). The bwd is the inverse of that - # multi-axis tiled all_gather: ``dynamic_slice`` to pick out - # this shard's local rows from the global cotangent. - # - # JAX's convention for tiled ``all_gather(axis_name=(a, b, ...))`` - # is row-major over the tuple: the shard at mesh position - # ``(i_a, i_b, ...)`` writes to rows - # ``[(i_a * size_b * ... + i_b * ... + ...) * local_T : - # + local_T)``. We invert that by computing the same flat - # index here and slicing. - if ep_active: - local_T_aux = ctx.logits_2d.shape[0] - flat_shard = shard_id # ep is the outermost axis in the gather tuple - for ax, sz in zip(data_parallelism_axes, fsdp_sizes): - flat_shard = flat_shard * sz + jax.lax.axis_index(ax) - d_aux_logits_local = jax.lax.dynamic_slice( - d_aux_logits.astype(ctx.logits_2d.dtype), - start_indices=(flat_shard * local_T_aux, 0), - slice_sizes=(local_T_aux, num_experts), - ) - else: - d_aux_logits_local = d_aux_logits.astype(d_logits_2d_main.dtype) - d_logits_2d = d_logits_2d_main + d_aux_logits_local.astype(d_logits_2d_main.dtype) - else: - d_logits_2d = d_logits_2d_main - - # ---------------- Gate bwd ---------------- - d_gate_logits = d_logits_2d.reshape(per_shard_x_shape[0], per_shard_x_shape[1], num_experts) - gate_kernel_cast = ctx.gate_kernel.astype(ctx.x.dtype) - d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast) - d_gate_kernel = jnp.einsum("bsh,bse->he", ctx.x, d_gate_logits).astype(ctx.gate_kernel.dtype) - d_x = d_x_from_gate + d_x_from_dispatch - - # Reduce per-rank partial contributions to match the out_specs - # declared by _build_grads_specs: - # gate_kernel : P() -> psum across (ep, *fsdp) - # wi_0/wi_1/wo : P(ep_axis, ...) -> psum across (*fsdp) only - # inputs : P((ep, fsdp), ...) -> already shard-local, no reduction - if ep_active: - replicate_all = (ep_axis,) + tuple(data_parallelism_axes) - d_gate_kernel = jax.lax.psum(d_gate_kernel, axis_name=replicate_all) - if data_parallelism_axes: - replicate_fsdp = tuple(data_parallelism_axes) - d_wi_0 = jax.lax.psum(d_wi_0, axis_name=replicate_fsdp) - d_wi_1 = jax.lax.psum(d_wi_1, axis_name=replicate_fsdp) - d_wo = jax.lax.psum(d_wo, axis_name=replicate_fsdp) - if has_wi_bias: - d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=replicate_fsdp) - d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=replicate_fsdp) - if has_wo_bias: - d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=replicate_fsdp) - - grads: dict = { - "inputs": d_x, - "gate_kernel": d_gate_kernel, - "wi_0": d_wi_0, - "wi_1": d_wi_1, - "wo": d_wo, - } - if has_wi_bias: - grads["wi_0_bias"] = d_wi_0_bias - grads["wi_1_bias"] = d_wi_1_bias - if has_wo_bias: - grads["wo_bias"] = d_wo_bias - if has_expert_bias: - # expert_bias has no gradient through topk (the topk bwd returns - # None for it). Emit a structural zero so the outer rule has - # something to package. - grads["expert_bias"] = jnp.zeros_like(ctx.expert_bias) - return grads - - -# ============================================================================= -# Spec builders for shard_map (lockstep with ctx_dict / captured_dict) -# ============================================================================= - - -def _build_in_specs( - ep_axis: str, - batch_pspec_axis: Any, - *, - has_bias: bool, - has_expert_bias: bool, -) -> dict: - """Build the ``in_specs`` dict for the EP fwd shard_map.""" - specs: dict = { - "inputs": P(batch_pspec_axis, None, None), - "gate_kernel": P(), - "wi_0": P(ep_axis, None, None), - "wi_1": P(ep_axis, None, None), - "wo": P(ep_axis, None, None), - } - if has_bias: - for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): - specs[name] = P(ep_axis, None) - if has_expert_bias: - specs["expert_bias"] = P(ep_axis) - return specs - - -def _build_dispatch_specs( # pylint: disable=unused-argument - ep_axis: str, - *, - backend: PermutationBackend, - ep_active: bool, - align_size: int, -) -> _DispatchState: - """Build the shard_map ``out_specs`` for the dispatch state. - - Returns a :data:`_DispatchState` (either :class:`_PureJaxDispatchState` - or :class:`_TritonDispatchState`) whose fields are - :class:`PartitionSpec` placeholders. Optional fields are set to - ``P()`` when populated by :func:`_dispatch` and to ``None`` when - intentionally omitted, so the spec's pytree structure mirrors the - value's structure leaf-for-leaf. - """ - ep_all = P() if ep_active else None - ep_local = P() if ep_active else None - if backend is PermutationBackend.PURE_JAX: - return _PureJaxDispatchState( - group_sizes=P(), - sorted_indices=P(), - routing_weights=P(), - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - return _TritonDispatchState( - group_sizes=P(), - row_id_map=P(), - pad_offsets=P() if align_size > 0 else None, - merging_probs=P(), - all_shards_tokens_per_expert=ep_all, - local_perm_row_id_map=ep_local, - ) - - -def _build_ctx_specs( # pylint: disable=unused-argument - ep_axis: str, - batch_pspec_axis: Any, - *, - backend: PermutationBackend, - ep_active: bool, - has_bias: bool, - has_expert_bias: bool, - aux_loss_enabled: bool, - align_size: int, -) -> _BodyCtx: - """Build the spec :class:`_BodyCtx` mirroring :func:`_body_fwd`'s ctx. - - Fields gated off by the static config (``expert_bias``, ``aux_*``) - are ``None`` here so the spec pytree matches the value pytree - leaf-for-leaf. - """ - return _BodyCtx( - # Per-shard local activations along the batch axis. - x=P(batch_pspec_axis, None, None), - gate_kernel=P(), - logits_2d=P(batch_pspec_axis, None), - saved_scores=P(batch_pspec_axis, None), - routing_map=P(batch_pspec_axis, None), - dispatch=_build_dispatch_specs( - ep_axis, backend=backend, ep_active=ep_active, align_size=align_size - ), - # FFN residuals: the LHS_TRANS / RHS_TRANS variants of - # grouped_quantize have leading "rows"/"experts" dims that are - # already shard-local (post-dispatch). Use P(ep_axis,...) on - # leading dim; that works whether the leaf is a plain ndarray - # or a ScaledTensor (shard_map applies the spec leaf-wise to - # the registered ScaledTensor pytree). - casted_sorted_x_lhs_trans=P(), - casted_wi_rhs_trans=P(ep_axis, None, None), - gate_proj_out=P(), - up_proj_out=P(), - casted_intermediate_lhs_trans=P(), - casted_wo_rhs_trans=P(ep_axis, None, None), - expert_outputs=P(), - local_group_sizes=P(), - expert_bias=P(ep_axis) if has_expert_bias else None, - aux_const_buf=P() if aux_loss_enabled else None, - aux_tokens_per_expert=P() if aux_loss_enabled else None, - aux_logits_for_score=P() if aux_loss_enabled else None, - aux_saved_scores=P() if aux_loss_enabled else None, - ) - - -def _build_grads_specs( - ep_axis: str, - batch_pspec_axis: Any, - *, - has_bias: bool, - has_expert_bias: bool, -) -> dict: - """Spec dict for the grads dict returned by :func:`_body_bwd`.""" - return _build_in_specs( - ep_axis, - batch_pspec_axis, - has_bias=has_bias, - has_expert_bias=has_expert_bias, + d_sorted_x_3d = d_sorted_x.reshape(1, d_sorted_x.shape[0], d_sorted_x.shape[1]) + d_recv_w_3d = d_recv_w_from_intermediate.reshape(1, -1) + return ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, ) # ============================================================================= -# Top-level VJP rules +# Full fwd / bwd rules (custom_vjp halves) # ============================================================================= -def _moe_fwd_rule( # pylint: disable=unused-argument - # Args MUST match the positional order of ``_moe`` (diff first, - # then nondiff). See ``_moe_bwd_rule`` for the opposite convention. +def _moe_fwd_rule( x, gate_kernel, wi_0, @@ -1687,170 +541,337 @@ def _moe_fwd_rule( # pylint: disable=unused-argument group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, + apply_topk_weights_early, ): - x = with_sharding_constraint_by_logical_axes(x, input_axes) - ep_active = ep_axis is not None - body_kwargs = { - "num_experts": num_experts, - "num_experts_per_tok": num_experts_per_tok, - "activation_type": activation_type, - "score_function": score_function, - "use_pre_softmax": use_pre_softmax, - "num_groups": num_groups, - "group_topk": group_topk, - "scaling_factor": scaling_factor, - "aux_loss_coeff": aux_loss_coeff, - "permutation_backend": permutation_backend, - "align_size": align_size, - "gate_inside_vjp": gate_inside_vjp, - "quantizer_sets": quantizer_sets, - "dtype": dtype, - "ep_axis": ep_axis, - "data_parallelism_axes": data_parallelism_axes, - } - captured: dict = { - "inputs": x, - "gate_kernel": gate_kernel, - "wi_0": wi_0, - "wi_1": wi_1, - "wo": wo, - } - has_bias = wi_0_bias is not None - has_expert_bias = expert_bias is not None - if has_bias: - captured["wi_0_bias"] = wi_0_bias - captured["wi_1_bias"] = wi_1_bias - captured["wo_bias"] = wo_bias - if has_expert_bias: - captured["expert_bias"] = expert_bias - - if not ep_active: - output, aux_loss, ctx = _body_fwd( - captured, - **body_kwargs, - ep_active=False, - fsdp_sizes=(), - num_ep=1, - num_experts_local=num_experts, - recv_buffer_rows=0, - ) - # Carry static side info to the bwd rule alongside ctx. These - # are Python ints/bools/tuples (NOT pytree leaves), so we - # bundle them as a plain dict rather than putting them on the - # ``_BodyCtx`` NamedTuple where shard_map would try to flatten - # them into JitTracers. - static = { - "has_wi_bias": has_bias, - "has_wo_bias": has_bias, - "has_expert_bias": has_expert_bias, - "x_shape": x.shape, - "num_experts_local": num_experts, - "recv_buffer_rows": 0, - } - return (output, aux_loss), (ctx, static) - - # ---------------- EP path ---------------- + """Forward: gate -> topk -> ep_dispatch -> shard_map(FFN) -> ep_combine. + + Returns ``(output, aux_loss)``. ``aux_loss`` is a zero scalar when + ``aux_loss_coeff == 0``. + """ + del gate_kernel_axes, wi_kernel_axes, wo_kernel_axes # used in bwd only from jax.experimental.shard_map import shard_map + x = with_sharding_constraint_by_logical_axes(x, input_axes) + mesh = _get_mesh() if mesh is None or mesh.empty: - raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + if ep_axis is None: + raise ValueError("moe(...) requires ep_axis to be set (TE EP backend).") num_ep = mesh.shape[ep_axis] if num_experts % num_ep != 0: raise ValueError(f"num_experts={num_experts} must be divisible by EP size={num_ep}") - num_experts_local = num_experts // num_ep + num_local_experts = num_experts // num_ep - # Reject overlapping EP / FSDP axes. Listing ep_axis in - # data_parallelism_axes would produce a duplicate-axis PartitionSpec - # ((ep, ep, ...)) which JAX rejects, and would also double-count - # num_ep in dp_size (under-sizing recv_buffer_rows by a factor of - # num_ep). Catch it up front with a clear error. + dp_size = 1 for ax in data_parallelism_axes: - if ax not in mesh.shape: - raise ValueError( - f"data_parallelism_axes contains {ax!r} but mesh has" - f" axes {tuple(mesh.shape.keys())}" - ) - if ax == ep_axis: - raise ValueError( - f"data_parallelism_axes={data_parallelism_axes!r} contains the EP" - f" axis {ep_axis!r}; EP is implicit in the batch sharding and must" - " not also be listed as a data-parallel axis." - ) + dp_size *= mesh.shape[ax] + num_procs = num_ep * dp_size + + B, S, H = x.shape + K = num_experts_per_tok + if B % num_procs != 0: + raise ValueError(f"batch={B} not divisible by ep*dp={num_procs}") + + # Per-rank send capacity: B/num_procs rows x S tokens per rank. + max_tokens_per_rank = (B // num_procs) * S + # Per-rank receive capacity. NCCL EP HT lays out the per-rank receive + # buffer as ``[num_local_experts, num_ep * max_tokens_per_rank, hidden]`` + # (see nccl_ep.cc::init kernel buffer sizing + the LL combine assertion + # at nccl_ep.cc:2185 which spells out the same layout). The natural + # dropless K-expanded count + # ``ceil((B/dp)*S*K / num_local_experts)`` does NOT match: it ignores + # the worst-case where all of one EP group's tokens land on a single + # local expert. We must size to that worst case or NCCL EP's HT kernel + # rejects the dispatch buffer with ``invalid argument``. + natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S + # NCCL EP requires each expert-major output block to be at least + # ``_ALIGN_SIZE`` (=128) tokens; see the constant's docstring. + slots_per_expert = ((natural_spe + _ALIGN_SIZE - 1) // _ALIGN_SIZE) * _ALIGN_SIZE + recv_pr = num_local_experts * slots_per_expert + + _te_ep_assert_compatible_bootstrap( + num_experts=num_experts, + max_tokens_per_rank=max_tokens_per_rank, + recv_capacity_per_rank=recv_pr, + hidden_dim=H, + ep_size=num_ep, + ) if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - batch_pspec_axis = (ep_axis, *data_parallelism_axes) - dp_size = 1 - for ax in data_parallelism_axes: - dp_size *= mesh.shape[ax] + # ep must be innermost: ep_bootstrap forms NCCL EP comms from + # consecutive global ranks (dp_color = rank // ep_size), so the + # comm only stays within one model replica under (outer_dp, ep). + batch_pspec_axis = (*data_parallelism_axes, ep_axis) + ep3_spec = P(batch_pspec_axis, None, None) + ep2_spec = P(batch_pspec_axis, None) + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec)) + + # ---------------- Gate (global view) ---------------- + # tex.fused_topk_with_score_function is only validated against its + # pytorch reference at fp32 (see tests/pytorch/test_fused_router.py: + # parametrize gates dtype on torch.float32 only; the tolerance helper + # raises NotImplementedError for any other dtype). Keeping logits in + # the activation dtype (e.g. bf16) lets sigmoid / softmax / topk + # accumulate at low precision and silently produce NaNs on tokens + # whose normalised weights underflow. Cast to fp32 here to stay in + # the validated regime. + gate_kernel_cast = gate_kernel.astype(x.dtype) + gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) + logits_2d = gate_logits.reshape(-1, num_experts).astype(jnp.float32) - global_batch_size, sequence_length, _hidden = x.shape - topk = num_experts_per_tok - if global_batch_size % (num_ep * dp_size) != 0: - raise ValueError(f"batch={global_batch_size} not divisible by ep*dp={num_ep * dp_size}") - recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk - if align_size > 0: - recv_buffer_rows += num_experts * (align_size - 1) + # ---------------- Routing (global view) ---------------- + # expert_bias is an empty (shape-(0,)) sentinel when the caller did + # not enable it; the primitive treats that as "no bias". + eb_arg = expert_bias if expert_bias.shape != (0,) else jnp.zeros((0,), dtype=jnp.float32) + sparse_probs, routing_map, saved_scores = tex.fused_topk_with_score_function_fwd( + logits_2d, + topk=K, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, + ) + sparse_probs = sparse_probs.astype(dtype) - in_specs = _build_in_specs( - ep_axis, - batch_pspec_axis, - has_bias=has_bias, - has_expert_bias=has_expert_bias, + # ---------------- Aux loss (global view, replicated) ---------------- + # ``fused_moe_aux_loss_fwd`` sums probs and tokens_per_expert across + # all tokens, which is wrong when T is sharded. Force-replicate the + # gate logits and recompute the routing map at global view so the + # kernel sees a complete [T_global, E] tensor. The replication is a + # single all-gather over (*dp, ep) and lives off the dispatch + # critical path. + if aux_loss_coeff > 0.0: + global_logits_2d = jax.lax.with_sharding_constraint(logits_2d, NamedSharding(mesh, P())) + _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( + global_logits_2d, + topk=K, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, + ) + aux_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) + # compute_aux_scores=True takes a separate kernel path: clean + # per-expert softmax, no grouping / bias / scaling. + aux_probs, _aux_rm, aux_saved_scores = tex.fused_topk_with_score_function_fwd( + global_logits_2d.astype(jnp.float32), + topk=K, + use_pre_softmax=False, + num_groups=-1, + group_topk=-1, + scaling_factor=1.0, + score_function=score_function, + expert_bias=jnp.zeros((0,), dtype=jnp.float32), + compute_aux_scores=True, + ) + aux_loss, aux_const_buf = tex.fused_moe_aux_loss_fwd( + aux_probs.astype(jnp.float32), + aux_tokens_per_expert.astype(jnp.int32), + topk=K, + coeff=aux_loss_coeff, + ) + aux_loss = aux_loss.astype(dtype) + else: + aux_loss = jnp.zeros((), dtype=dtype) + aux_const_buf = None + aux_tokens_per_expert = None + aux_saved_scores = None + + # ---------------- Routing -> (topk_idx, topk_w) at 3D ---------------- + # argsort on a bool tensor places True last (False=0 < True=1), so the + # last K indices are the selected expert IDs. + selected_experts = jnp.argsort(routing_map, axis=-1)[..., -K:] + routing_weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + topk_idx_3d = selected_experts.reshape(B, S, K).astype(jnp.int32) + topk_w_3d = routing_weights.reshape(B, S, K).astype(jnp.float32) + # tex.ep_prepare/dispatch's partition only folds ep_axis into a replicated + # leading dim, not the outer dp/fsdp axes, so a replicated topk_idx makes + # each rank see B/ep rows (not B/num_procs) and overrun the bootstrap-sized + # send buffer. Pin both routing tensors to the (outer, ep) leading sharding + # so per-rank token counts match max_tokens_per_rank. + topk_idx_3d = jax.lax.with_sharding_constraint(topk_idx_3d, NamedSharding(mesh, ep3_spec)) + topk_w_3d = jax.lax.with_sharding_constraint(topk_w_3d, NamedSharding(mesh, ep3_spec)) + + # ---------------- TE EP dispatch (global view) ---------------- + cfg = tex.EpLayerConfig( + top_k=K, + dispatch_output_per_expert_alignment=slots_per_expert, ) - output_spec = P(batch_pspec_axis, None, None) - aux_spec = P() - ctx_spec = _build_ctx_specs( - ep_axis, - batch_pspec_axis, - backend=permutation_backend, - ep_active=True, - has_bias=has_bias, - has_expert_bias=has_expert_bias, - aux_loss_enabled=(aux_loss_coeff > 0.0), - align_size=align_size, + token_counts, handle_mem = tex.ep_prepare(cfg, topk_idx_3d) + recv_tokens, recv_topk_weights = tex.ep_dispatch_fwd( + cfg, handle_mem, topk_idx_3d, x, topk_w_3d, recv_pr + ) + recv_tokens = jax.lax.with_sharding_constraint(recv_tokens, NamedSharding(mesh, ep3_spec)) + recv_topk_weights = jax.lax.with_sharding_constraint( + recv_topk_weights, NamedSharding(mesh, ep2_spec) + ) + + # ---------------- FFN (per-shard via shard_map) ---------------- + has_bias = wi_0_bias is not None + kernel_spec = P(ep_axis, None, None) + bias_spec = P(ep_axis, None) if has_bias else None + # token_counts is the per-shard (1, num_local_experts) padded + # per-expert count from ep_prepare; piped into _ffn_fwd_per_shard + # as the grouped_gemm group_sizes so cuBLAS skips both 0-token + # experts and the trailing overalloc tail. + ffn_in_specs = (ep3_spec, ep2_spec, ep2_spec, kernel_spec, kernel_spec, kernel_spec) + ffn_in_args = [recv_tokens, recv_topk_weights, token_counts, wi_0, wi_1, wo] + if has_bias: + ffn_in_specs = ffn_in_specs + (bias_spec, bias_spec, bias_spec) + ffn_in_args.extend([wi_0_bias, wi_1_bias, wo_bias]) + + # FFN residuals live entirely on the local ep rank, so the leading + # "experts" / "rows" dims map to P() (already shard-local). wi is + # fused via jnp.concatenate along the trailing (output) axis + # (see _ffn_fwd_per_shard for rationale), so the residual is a + # single 3D casted_wi_rhs_trans of shape + # (num_local_experts, hidden, 2*H_inter). local_group_sizes is + # now per-shard dynamic (= per-shard token_counts), so its + # residual spec mirrors ep2_spec (one row per ep rank). + residuals_spec = ( + P(), # casted_sorted_x_lhs_trans + P(ep_axis, None, None), # casted_wi_rhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans + P(ep_axis, None, None), # casted_wo_rhs_trans + ep2_spec, # local_group_sizes (1, num_local_experts) per shard ) + out_specs = (ep3_spec, residuals_spec) - _fsdp_sizes: Tuple[int, ...] = tuple(mesh.shape[ax] for ax in data_parallelism_axes) - - def _shardmap_body(captured_local): - return _body_fwd( - captured_local, - **body_kwargs, - ep_active=True, - fsdp_sizes=_fsdp_sizes, - num_ep=num_ep, - num_experts_local=num_experts_local, - recv_buffer_rows=recv_buffer_rows, + def _body(*args): + if has_bias: + (r_tok, r_w, tc, w0, w1, w_o, w0b, w1b, wob) = args + else: + (r_tok, r_w, tc, w0, w1, w_o) = args + w0b = w1b = wob = None + # NOTE: tex.ep_dispatch_fwd's NCCL EP HT path leaves the recv + # buffer uninitialised on fully-empty-receiver ranks (and at + # padded slots on partially-loaded ranks). We don't need a + # zero-init guard here anymore because: + # 1. ``tc`` (per-expert padded counts) is plumbed into + # grouped_gemm as group_sizes, so cuBLAS skips both + # 0-token experts and the trailing overalloc tail. + # 2. The per-group wgrad masks in _ffn_bwd_per_shard zero + # ``d_wo`` / ``d_wi_combined`` slices for 0-token-globally + # experts (cuBLAS skips size_g==0 groups without + # zero-filling, which would otherwise leak NaN into the + # user's optimizer). + # 3. All other downstream consumers (ep_combine, + # ep_dispatch_bwd) are handle_mem-aware and read only + # valid positions. + # If a future caller adds a non-group-aware reader of r_tok + # (e.g. an inspect probe over the full recv tile), re-add the + # ``jax.lax.cond(jnp.any(r_w != 0), identity, zeros_like)`` + # guard here. + return _ffn_fwd_per_shard( + r_tok, + r_w, + tc, + w0, + w1, + w_o, + w0b, + w1b, + wob, + num_local_experts=num_local_experts, + slots_per_expert=slots_per_expert, + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, ) - output, aux_loss, ctx = shard_map( - _shardmap_body, + expert_outputs, ffn_residuals = shard_map( + _body, mesh=mesh, - in_specs=(in_specs,), - out_specs=(output_spec, aux_spec, ctx_spec), + in_specs=ffn_in_specs, + out_specs=out_specs, check_rep=False, - )(captured) + )(*ffn_in_args) + expert_outputs = jax.lax.with_sharding_constraint(expert_outputs, NamedSharding(mesh, ep3_spec)) + + # ---------------- TE EP combine (global view) ---------------- + out_partition_spec = (batch_pspec_axis, None, None) + if apply_topk_weights_early: + # expert_outputs is already weighted upstream. + output = tex.ep_combine_fwd( + cfg, + handle_mem, + expert_outputs, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, + ) + else: + # IEEE 754: NaN * 0 = NaN, so a multiplicative mask cannot kill + # the NaNs ep_dispatch_fwd leaves at padded slots of recv_tokens + # (they ride through the FFN into expert_outputs at the same + # padded positions): mean=NaN on expert_outputs[padded] then + # propagates into the combine output when the kernel's read + # pattern overlaps the padded region. Use jnp.where to overwrite + # padded positions with a literal 0 before combine. + w = recv_topk_weights[..., None].astype(expert_outputs.dtype) + mask_bool = (recv_topk_weights != 0)[..., None] + weighted = jnp.where(mask_bool, expert_outputs * w, jnp.zeros_like(expert_outputs)) + output = tex.ep_combine_fwd( + cfg, + handle_mem, + weighted, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, + ) + + ( + casted_sorted_x_lhs_trans, + casted_wi_rhs_trans, + gate_proj_out, + up_proj_out, + casted_intermediate_lhs_trans, + casted_wo_rhs_trans, + local_group_sizes, + ) = ffn_residuals + + ctx = _Ctx( + x=x, + gate_kernel=gate_kernel, + expert_bias=expert_bias, + logits_2d=logits_2d, + saved_scores=saved_scores, + routing_map=routing_map, + cfg=cfg, + handle_mem=handle_mem, + token_counts=token_counts, + recv_topk_weights=recv_topk_weights, + casted_sorted_x_lhs_trans=casted_sorted_x_lhs_trans, + casted_wi_rhs_trans=casted_wi_rhs_trans, + gate_proj_out=gate_proj_out, + up_proj_out=up_proj_out, + casted_intermediate_lhs_trans=casted_intermediate_lhs_trans, + casted_wo_rhs_trans=casted_wo_rhs_trans, + expert_outputs=expert_outputs, + local_group_sizes=local_group_sizes, + aux_const_buf=aux_const_buf, + aux_tokens_per_expert=aux_tokens_per_expert, + aux_saved_scores=aux_saved_scores, + ) static = { - "has_wi_bias": has_bias, - "has_wo_bias": has_bias, - "has_expert_bias": has_expert_bias, + "has_bias": has_bias, "x_shape": x.shape, - "num_experts_local": num_experts_local, - "recv_buffer_rows": recv_buffer_rows, + "recv_pr": recv_pr, } return (output, aux_loss), (ctx, static) @@ -1865,128 +886,272 @@ def _moe_bwd_rule( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, - ctx, - dy_pair, + apply_topk_weights_early, + residuals, + cotangents, ): - ctx, static = ctx # split tensor residuals from static side info - has_wi_bias = static["has_wi_bias"] - has_wo_bias = static["has_wo_bias"] - has_expert_bias = static["has_expert_bias"] - x_shape = static["x_shape"] - num_experts_local = static["num_experts_local"] - recv_buffer_rows = static["recv_buffer_rows"] + """Backward mirror of :func:`_moe_fwd_rule`.""" + del num_groups, group_topk, dtype # captured in residuals / unused in bwd + from jax.experimental.shard_map import shard_map - ep_active = ep_axis is not None - mesh = _get_mesh() if ep_active else None - fsdp_sizes: Tuple[int, ...] = ( - tuple(mesh.shape[ax] for ax in data_parallelism_axes) if ep_active else () - ) - body_kwargs = { - "num_experts": num_experts, - "num_experts_per_tok": num_experts_per_tok, - "activation_type": activation_type, - "score_function": score_function, - "use_pre_softmax": use_pre_softmax, - "num_groups": num_groups, - "group_topk": group_topk, - "scaling_factor": scaling_factor, - "aux_loss_coeff": aux_loss_coeff, - "permutation_backend": permutation_backend, - "align_size": align_size, - "gate_inside_vjp": gate_inside_vjp, - "quantizer_sets": quantizer_sets, - "dtype": dtype, - "ep_axis": ep_axis, - "data_parallelism_axes": data_parallelism_axes, - "fsdp_sizes": fsdp_sizes, - "num_ep": 1 if not ep_active else mesh.shape[ep_axis], - "num_experts_local": num_experts_local, - "recv_buffer_rows": recv_buffer_rows, - "has_wi_bias": has_wi_bias, - "has_wo_bias": has_wo_bias, - "has_expert_bias": has_expert_bias, - "x_shape": x_shape, - } + d_output, d_aux_loss = cotangents - if not ep_active: - grads = _body_bwd(ctx, dy_pair, ep_active=False, **body_kwargs) - # Apply sharding constraints on grads. - grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( - grads["gate_kernel"], gate_kernel_axes - ) - grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) - grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) - grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) - grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) - return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + ctx, static = residuals + has_bias = static["has_bias"] + x_shape = static["x_shape"] + recv_pr = static["recv_pr"] - from jax.experimental.shard_map import shard_map + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + num_ep = mesh.shape[ep_axis] + dp_size = 1 + for ax in data_parallelism_axes: + dp_size *= mesh.shape[ax] + B, S, _ = x_shape + K = num_experts_per_tok if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: - batch_pspec_axis = (ep_axis, *data_parallelism_axes) - ctx_spec = _build_ctx_specs( - ep_axis, - batch_pspec_axis, - backend=permutation_backend, - ep_active=True, - has_bias=has_wi_bias, - has_expert_bias=has_expert_bias, - aux_loss_enabled=(aux_loss_coeff > 0.0), - align_size=align_size, + batch_pspec_axis = (*data_parallelism_axes, ep_axis) + ep3_spec = P(batch_pspec_axis, None, None) + ep2_spec = P(batch_pspec_axis, None) + out_partition_spec = (batch_pspec_axis, None, None) + + # ---------------- Combine bwd (global view) ---------------- + d_output = jax.lax.with_sharding_constraint(d_output, NamedSharding(mesh, ep3_spec)) + grad_pre_combine = tex.ep_combine_bwd(ctx.cfg, ctx.handle_mem, d_output, recv_pr) + grad_pre_combine = jax.lax.with_sharding_constraint( + grad_pre_combine, NamedSharding(mesh, ep3_spec) ) - dy_specs = (P(batch_pspec_axis, None, None), P()) - grads_spec = _build_grads_specs( - ep_axis, batch_pspec_axis, has_bias=has_wi_bias, has_expert_bias=has_expert_bias + + if apply_topk_weights_early: + # combine_fwd consumed already-weighted expert_outputs; the recv_w + # cotangent flows through the early-weighting step inside the FFN bwd. + d_expert_outputs = grad_pre_combine + d_recv_w_from_combine = jnp.zeros_like(ctx.recv_topk_weights) + else: + # ep_dispatch_fwd can land NaN into recv_topk_weights on padded + # slots. Untreated, `(NaN != 0) == True` in IEEE, + # so the multiplicative mask cannot suppress the NaN and it + # propagates through grad_pre_combine * w * mask into d_expert_outputs + # and then into every downstream gradient (gate_kernel ends up + # all-NaN). Sanitize once here. + recv_w_clean = jnp.where(jnp.isnan(ctx.recv_topk_weights), 0, ctx.recv_topk_weights) + w = recv_w_clean[..., None].astype(grad_pre_combine.dtype) + mask_bool = (recv_w_clean != 0)[..., None] + d_expert_outputs = jnp.where( + mask_bool, grad_pre_combine * w, jnp.zeros_like(grad_pre_combine) + ) + # Same masking strategy for the cotangent on recv_topk_weights: + # grad_pre_combine has NaN at padded slots and ctx.expert_outputs + # may too, so the per-element product must be jnp.where'd before + # the sum reduction. + d_recv_w_from_combine = jnp.where( + mask_bool, + grad_pre_combine * ctx.expert_outputs, + jnp.zeros_like(grad_pre_combine), + ).sum(axis=-1) + d_recv_w_from_combine = d_recv_w_from_combine.astype(ctx.recv_topk_weights.dtype) + + # ---------------- FFN bwd (per-shard via shard_map) ---------------- + kernel_spec = P(ep_axis, None, None) + bias_spec = P(ep_axis, None) if has_bias else None + + bwd_in_specs = ( + ep3_spec, # d_expert_outputs + P(), # casted_sorted_x_lhs_trans + P(ep_axis, None, None), # casted_wi_rhs_trans + P(), # gate_proj_out + P(), # up_proj_out + P(), # casted_intermediate_lhs_trans + P(ep_axis, None, None), # casted_wo_rhs_trans + ep2_spec, # local_group_sizes (1, num_local_experts) per shard + ep2_spec, # recv_topk_weights + ) + bwd_in_args = [ + d_expert_outputs, + ctx.casted_sorted_x_lhs_trans, + ctx.casted_wi_rhs_trans, + ctx.gate_proj_out, + ctx.up_proj_out, + ctx.casted_intermediate_lhs_trans, + ctx.casted_wo_rhs_trans, + ctx.local_group_sizes, + ctx.recv_topk_weights, + ] + bwd_out_specs = ( + ep3_spec, # d_sorted_x + ep2_spec, # d_recv_w_from_intermediate + kernel_spec, # d_wi_0 + kernel_spec, # d_wi_1 + kernel_spec, # d_wo + bias_spec if has_bias else None, # d_wi_0_bias + bias_spec if has_bias else None, # d_wi_1_bias + bias_spec if has_bias else None, # d_wo_bias ) - def _bwd_body(ctx_local, dy_local): - return _body_bwd(ctx_local, dy_local, ep_active=True, **body_kwargs) + def _bwd_body(*args): + ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) = _ffn_bwd_per_shard( + *args, + activation_type=activation_type, + apply_topk_weights_early=apply_topk_weights_early, + has_bias=has_bias, + ) + # Weight grads accumulate per-DP-shard inside the body; psum across + # DP axes so each replica sees the full sum (matches out_specs + # P(ep_axis, ...) which is DP-replicated). + if data_parallelism_axes: + dp = tuple(data_parallelism_axes) + d_wi_0 = jax.lax.psum(d_wi_0, axis_name=dp) + d_wi_1 = jax.lax.psum(d_wi_1, axis_name=dp) + d_wo = jax.lax.psum(d_wo, axis_name=dp) + if has_bias: + d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=dp) + d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=dp) + d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=dp) + return ( + d_sorted_x_3d, + d_recv_w_3d, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) - grads = shard_map( + ( + d_sorted_x, + d_recv_w_from_intermediate, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias, + d_wi_1_bias, + d_wo_bias, + ) = shard_map( _bwd_body, mesh=mesh, - in_specs=(ctx_spec, dy_specs), - out_specs=grads_spec, + in_specs=bwd_in_specs, + out_specs=bwd_out_specs, check_rep=False, - )(ctx, dy_pair) + )( + *bwd_in_args + ) - grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( - grads["gate_kernel"], gate_kernel_axes + d_recv_w_total = d_recv_w_from_combine + d_recv_w_from_intermediate + + # ---------------- Dispatch bwd (global view) ---------------- + d_sorted_x = jax.lax.with_sharding_constraint(d_sorted_x, NamedSharding(mesh, ep3_spec)) + d_recv_w_total = jax.lax.with_sharding_constraint(d_recv_w_total, NamedSharding(mesh, ep2_spec)) + d_x_from_dispatch, d_topk_w = tex.ep_dispatch_bwd( + ctx.cfg, + ctx.handle_mem, + d_sorted_x, + d_recv_w_total, + num_local_tokens=(B, S), + out_partition_spec=out_partition_spec, + ) + + # ---------------- Routing bwd (global view) ---------------- + # The cotangent on routing_weights is a sparse scatter into sparse_probs + # at the selected_experts indices. + selected_experts = jnp.argsort(ctx.routing_map, axis=-1)[..., -K:] + d_topk_w_flat = d_topk_w.reshape(-1, K) + d_sparse_probs = jnp.zeros(ctx.routing_map.shape, dtype=d_topk_w_flat.dtype) + d_sparse_probs = d_sparse_probs.at[ + jnp.arange(ctx.routing_map.shape[0])[:, None], selected_experts + ].set(d_topk_w_flat) + + d_logits_2d = tex.fused_topk_with_score_function_bwd( + ctx.routing_map, + ctx.saved_scores, + d_sparse_probs.astype(ctx.saved_scores.dtype), + topk=K, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=False, ) - grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) - grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) - grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) - grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) - return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + # ---------------- Aux loss bwd (global view, replicated) ---------------- + # Reverse the fwd's all-gather/aux pipeline: aux_loss_bwd produces + # d_aux_probs, then topk_bwd(compute_aux_scores=True) produces the + # extra d_logits contribution. The replicated tensor adds into the + # T-sharded routing-side d_logits via JAX's normal broadcast. + if aux_loss_coeff > 0.0: + T_global = ctx.logits_2d.shape[0] + d_aux_loss_scalar = d_aux_loss.reshape(()).astype(jnp.float32) + d_aux_probs = tex.fused_moe_aux_loss_bwd( + ctx.aux_const_buf, + ctx.aux_tokens_per_expert.astype(jnp.int32), + d_aux_loss_scalar, + num_tokens=int(T_global), + ) + # routing_map is ignored by the kernel when compute_aux_scores=True, + # so pass a zero placeholder of the right shape/dtype. + zero_routing_map = jnp.zeros(ctx.aux_saved_scores.shape, dtype=ctx.routing_map.dtype) + d_logits_aux = tex.fused_topk_with_score_function_bwd( + zero_routing_map, + ctx.aux_saved_scores, + d_aux_probs.astype(ctx.aux_saved_scores.dtype), + topk=K, + use_pre_softmax=False, + scaling_factor=1.0, + score_function=score_function, + compute_aux_scores=True, + ) + d_logits_2d = d_logits_2d + d_logits_aux.astype(d_logits_2d.dtype) + + # ---------------- Gate bwd (global view) ---------------- + d_gate_logits = d_logits_2d.reshape(B, S, num_experts) + gate_kernel_cast = ctx.gate_kernel.astype(ctx.x.dtype) + d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast) + d_gate_kernel = jnp.einsum("bsh,bse->he", ctx.x, d_gate_logits).astype(ctx.gate_kernel.dtype) + d_x = d_x_from_gate + d_x_from_dispatch + + # Pin output grads to the declared logical axes so downstream + # optimizers see consistent shardings. + d_x = with_sharding_constraint_by_logical_axes(d_x, input_axes) + d_gate_kernel = with_sharding_constraint_by_logical_axes(d_gate_kernel, gate_kernel_axes) + d_wi_0 = with_sharding_constraint_by_logical_axes(d_wi_0, wi_kernel_axes) + d_wi_1 = with_sharding_constraint_by_logical_axes(d_wi_1, wi_kernel_axes) + d_wo = with_sharding_constraint_by_logical_axes(d_wo, wo_kernel_axes) + + # expert_bias has no learnable bwd path through fused_topk: the + # primitive's bwd returns None for the bias slot. Match that with a + # zero cotangent of the right shape so custom_vjp's arity check + # passes. + d_expert_bias = jnp.zeros_like(ctx.expert_bias) -def _grads_dict_to_tuple( - grads: dict, has_wi_bias: bool, has_wo_bias: bool, has_expert_bias: bool -) -> Tuple: - """Pack the body_bwd's grads dict into the positional tuple JAX expects.""" return ( - grads["inputs"], - grads["gate_kernel"], - grads["wi_0"], - grads["wi_1"], - grads["wo"], - grads.get("wi_0_bias") if has_wi_bias else None, - grads.get("wi_1_bias") if has_wi_bias else None, - grads.get("wo_bias") if has_wo_bias else None, - grads.get("expert_bias") if has_expert_bias else None, + d_x, + d_gate_kernel, + d_wi_0, + d_wi_1, + d_wo, + d_wi_0_bias if has_bias else None, + d_wi_1_bias if has_bias else None, + d_wo_bias if has_bias else None, + d_expert_bias, ) @@ -1995,7 +1160,7 @@ def _grads_dict_to_tuple( # ============================================================================= -@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 29))) +@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 26))) def _moe( x, gate_kernel, @@ -2015,23 +1180,16 @@ def _moe( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, + apply_topk_weights_early, ): - # Call in `_moe`'s own signature order to match what JAX will pass - # the fwd rule via ``_argnums_partial``. See the comment block at - # the top of ``_moe_fwd_rule`` for why this differs from - # ``_moe_bwd_rule``'s convention. - output_pair, _ = _moe_fwd_rule( + primal, _ = _moe_fwd_rule( x, gate_kernel, wi_0, @@ -2050,19 +1208,16 @@ def _moe( group_topk, scaling_factor, aux_loss_coeff, - permutation_backend, - align_size, - gate_inside_vjp, ep_axis, data_parallelism_axes, input_axes, gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, - quantizer_sets, dtype, + apply_topk_weights_early, ) - return output_pair + return primal _moe.defvjp(_moe_fwd_rule, _moe_bwd_rule) @@ -2079,56 +1234,106 @@ def moe( wo_bias: Optional[jnp.ndarray] = None, expert_bias: Optional[jnp.ndarray] = None, *, - # Architecture num_experts: int, num_experts_per_tok: int, activation_type: str = "silu", - # Routing score_function: Union[str, ScoreFunction] = "softmax", use_pre_softmax: bool = False, num_groups: Optional[int] = None, group_topk: Optional[int] = None, scaling_factor: float = 1.0, aux_loss_coeff: float = 0.0, - # Permutation - permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX, - align_size: int = 0, - # Gate placement (Phuong: "perhaps as an option") - gate_inside_vjp: bool = True, - # Parallelism (resolved by caller from MeshResource) - ep_axis: Optional[str] = None, + apply_topk_weights_early: bool = False, + ep_axis: str, data_parallelism_axes: Tuple[str, ...] = (), - # Logical axes for sharding constraints input_axes: Tuple[Optional[str], ...] = (), gate_kernel_axes: Tuple[Optional[str], ...] = (), wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp"), wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed"), - # Quantization - quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet] = ( - noop_quantizer_set, - noop_quantizer_set, - noop_quantizer_set, - ), dtype: jnp.dtype = jnp.float32, ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Run a full MoE block under a single fused custom_vjp. + """Run a full MoE block under a single fused custom_vjp on the TE EP path. + + Returns ``(output, aux_loss)``. ``aux_loss`` is ``None`` when + ``aux_loss_coeff == 0`` and a 0-d scalar otherwise. - Parameters and return are documented at the call site of - ``_MoEBlock.__call__``. See module docstring for design rationale. + Parameters + ---------- + expert_bias : Optional[jnp.ndarray] + ``[num_experts]`` learnable router bias added before the top-k + when ``score_function='sigmoid'``. Pass ``None`` to disable. + The bias has no gradient through the top-k primitive itself (it + only steers expert selection); a zero cotangent is returned for + it. + aux_loss_coeff : float + Per-step expert-load-balance loss coefficient. ``0.0`` (default) + disables the aux loss entirely. When non-zero, an extra + all-gather over the routing-side logits is inserted so the + ``fused_moe_aux_loss`` kernel sees a global ``[T_global, E]`` + view; this lives off the dispatch critical path. + + Note that the per-expert dispatch-slot alignment is fixed internally + at 128 tokens (``_ALIGN_SIZE``); see that constant's docstring for + rationale and how to extend if a future recipe needs >128. + + Axis-name parameters: + + * ``ep_axis`` and ``data_parallelism_axes`` are *physical mesh + axis names* -- they index ``jax.sharding.Mesh.shape`` directly + (to compute ``num_ep`` / ``dp_size`` and to construct + ``P((dp..., ep), None, None)`` for the per-shard + ``jax.lax.with_sharding_constraint`` calls that JAX requires + to refer to real mesh axes). + * ``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, + ``wo_kernel_axes`` are *logical axis names* (e.g. + ``"batch"``, ``"embed"``, ``"mlp"``, ``"exp"``) -- they get + resolved via the active Flax logical-axis rules and consumed + by ``with_sharding_constraint_by_logical_axes``. They are + ``Optional[str]`` tuples so a rule of ``None`` means + "replicated on this axis". + + Logical-axis support for ``ep_axis`` / ``data_parallelism_axes`` + is intentionally out of scope: the EP comm-group construction + (``dp_color = rank // ep_size``) and the bootstrap signature + check both require concrete integer sizes, so a logical name + would have to be resolved to a physical one anyway before any + EP primitive is called. If a downstream pipeline needs to plumb + logical names all the way to ``moe()``, do the rule lookup at + the call site. + + See module docstring for the rest of the parameter semantics and the + surrounding design rationale. """ - if not isinstance(permutation_backend, PermutationBackend): - raise TypeError( - f"permutation_backend must be a PermutationBackend, got {permutation_backend!r}" - ) - if permutation_backend is PermutationBackend.TRITON: - _require_triton() - # Normalize string score_function ("softmax" / "sigmoid") to the - # ScoreFunction enum once here. The underlying primitive - # ``tex.fused_topk_with_score_function_fwd`` expects an int-coercible - # value (the enum has integer .value), and the public router wrapper - # we bypass also normalizes here. score_function = _validate_score_function(score_function) + # Enforce ((outer_dp..., ep), None, None) on inbound activations. The + # EP comm groups consecutive global ranks (dp_color = rank // ep_size), + # so ep MUST be innermost in the partition spec. Soft re-pin: free if + # upstream already matches, single reshard otherwise. + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh.") + expected_leading: Any = (*data_parallelism_axes, ep_axis) if data_parallelism_axes else ep_axis + expected_spec = P(expected_leading, None, None) + actual_spec = getattr(getattr(x, "sharding", None), "spec", None) + if actual_spec is not None and tuple(actual_spec) != tuple(expected_spec): + warnings.warn( + f"moe(...): inbound x sharding {actual_spec} does not match expected " + f"{expected_spec}; inserting a reshard. Apply " + "jax.lax.with_sharding_constraint upstream to avoid this overhead.", + UserWarning, + stacklevel=2, + ) + x = _with_sharding_constraint_cast_bwd(x, NamedSharding(mesh, expected_spec)) + + # custom_vjp can't trace through None args; lower expert_bias to an + # empty shape-(0,) tensor that fused_topk_with_score_function treats + # as "no bias". + if expert_bias is None: + expert_bias_arg = jnp.zeros((0,), dtype=jnp.float32) + else: + expert_bias_arg = expert_bias + output, aux_loss = _moe( x, gate_kernel, @@ -2138,27 +1343,24 @@ def moe( wi_0_bias, wi_1_bias, wo_bias, - expert_bias, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - activation_type=activation_type, - score_function=score_function, - use_pre_softmax=use_pre_softmax, - num_groups=num_groups, - group_topk=group_topk, - scaling_factor=scaling_factor, - aux_loss_coeff=aux_loss_coeff, - permutation_backend=permutation_backend, - align_size=align_size, - gate_inside_vjp=gate_inside_vjp, - ep_axis=ep_axis, - data_parallelism_axes=data_parallelism_axes, - input_axes=input_axes, - gate_kernel_axes=gate_kernel_axes, - wi_kernel_axes=wi_kernel_axes, - wo_kernel_axes=wo_kernel_axes, - quantizer_sets=quantizer_sets, - dtype=dtype, + expert_bias_arg, + num_experts, + num_experts_per_tok, + activation_type, + score_function, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + float(aux_loss_coeff), + ep_axis, + data_parallelism_axes, + input_axes, + gate_kernel_axes, + wi_kernel_axes, + wo_kernel_axes, + dtype, + apply_topk_weights_early, ) if aux_loss_coeff <= 0.0: aux_loss = None diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index da527fdf18..2e8e611fa3 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -331,7 +331,12 @@ class MeshResource: fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None - ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None + ep_resource: Axis name for expert parallelism. Dispatch input tokens + must be sharded on their leading dim by ``ep_resource`` (alone or + compound with ``dp_resource`` / ``fsdp_resource`` as outer, e.g. + ``PartitionSpec(("dp", "ep"), None, None)``). Dispatch output + ``[ep_size, recv_capacity, H]`` is always sharded by ``ep_resource`` + on the leading ``ep_size`` dim. """ dp_resource: str = None @@ -474,3 +479,8 @@ def dp_or_fsdp_axis_size(): dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) return dp_size if dp_size > 1 else fsdp_size + + +def ep_axis_size(): + """Get the size of the dispatch/EP axis (ep_resource). Returns 1 if unset.""" + return get_mesh_axis_size(global_mesh_resource().ep_resource)