Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
ee7dfff
Expert Parallelism: common C API + NCCL EP v0.1 backend
phu0ngng May 22, 2026
de64b7c
Expert Parallelism: persistent ncclEpHandle cache with allow_handle_m…
phu0ngng May 23, 2026
a234333
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2026
ef387fe
Build: NCCL_HOME discovery supports Debian/Ubuntu multiarch lib paths
phu0ngng May 27, 2026
9535f87
bump NCCL
phu0ngng May 27, 2026
f79914f
Expert Parallelism: require token_dtype in NVTEEpGroupConfig and enfo…
phu0ngng May 28, 2026
865536f
Expert Parallelism: document ep_comm lifetime, v0.1 single-GPU scope,…
phu0ngng May 28, 2026
e659dd8
Expert Parallelism: drop version label from initialize scope note
phu0ngng May 28, 2026
d95f610
Expert Parallelism: pointer-keyed LRU handle cache; drop register_lay…
phu0ngng Jun 8, 2026
52b7e94
bump nccl to latest v0.1
phu0ngng Jun 8, 2026
c35db03
tests/cpp_distributed: drop unused NCCL EP header include path
phu0ngng Jun 8, 2026
5656578
common/ep: fold nvte_ep_* stubs into ep_api.cpp under #if NVTE_WITH_N…
phu0ngng Jun 8, 2026
97cb6ef
common/ep: dlopen libnccl_ep.so so libtransformer_engine.so loads wit…
phu0ngng Jun 9, 2026
32444dc
common/ep: add BUILD_RPATH=NCCL_EP_LIB_DIR for in-tree dev builds
phu0ngng Jun 9, 2026
adab125
common/ep: polish ep.h docstrings; drop unused NVTE_CHECK_NCCL from l…
phu0ngng Jun 9, 2026
a8aa363
common/ep: expose zero_copy in NVTEEpGroupConfig; map to NCCL_EP_ZERO…
phu0ngng Jun 9, 2026
19b7209
tests/cpp_distributed: exercise zero_copy=ON in EPZeroCopyTest.Identi…
phu0ngng Jun 9, 2026
9bee544
tests/cpp_distributed: tighten EPZeroCopyTest comments
phu0ngng Jun 9, 2026
859c2ec
common/CMakeLists: correct NCCL resolution comment (not bundled with …
phu0ngng Jun 9, 2026
3120085
common/CMakeLists: shorten NCCL/GIN headers comments
phu0ngng Jun 9, 2026
6f9620f
setup,common: bundle libnccl_ep.so.0 next to libtransformer_engine.so…
phu0ngng Jun 9, 2026
f9e9382
.gitmodules: drop nccl branch pin and align indentation with other su…
phu0ngng Jun 9, 2026
e02028f
setup: gate NCCL EP build on arch >= 90 or native; drop sm_90 fallback
phu0ngng Jun 9, 2026
049a986
common,setup,tests: discover nccl.h via find_path/NCCL_INCLUDE_DIR; d…
phu0ngng Jun 9, 2026
54b815b
common/ep: simplify make_nccl_ep_tensor to take NVTETensor and option…
phu0ngng Jun 9, 2026
1fd404c
common/ep: move te_dtype_to_nccl_dtype out of EPBackend into anon nam…
phu0ngng Jun 9, 2026
0ac254c
common/ep: reword multicast check; drop NVLS framing
phu0ngng Jun 9, 2026
2575b78
common,tests: replace unicode em-dash and box-drawing chars with ASCI…
phu0ngng Jun 9, 2026
132d75e
bump nccl to latest v0.1
phu0ngng Jun 9, 2026
1253d98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2026
3c6fdb4
nccl commit to 2.31.0a4-1
phu0ngng Jun 9, 2026
84c8425
common/CMakeLists: point NCCL_EP_INCLUDE_DIR at build/include staged …
phu0ngng Jun 9, 2026
28197bf
common/CMakeLists: clarify NCCL EP missing-header instructions
phu0ngng Jun 10, 2026
3f12667
common/ep: use int64_t instead of long for handle-cache size env (cpp…
phu0ngng Jun 10, 2026
7187a26
common/ep: fix dangling sizes pointer in make_nccl_ep_tensor (NVTESha…
phu0ngng Jun 10, 2026
b6f10e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2026
b9a3c70
jax: add EP bindings on pointer-keyed cache with EpLayerConfig and bf…
phu0ngng Jun 8, 2026
48b0557
jax/ep: drop topk_weights from ep_combine; caller must pre-multiply
phu0ngng Jun 8, 2026
86855e2
tests/jax/ep: mask uninitialized recv_tokens tail in dispatch_vjp
phu0ngng Jun 8, 2026
9b470a9
examples/jax/ep: add ep_bench.py + run_ep_bench.sh
phu0ngng Jun 8, 2026
6251f80
examples/jax/ep: ep_moe.py runs --iters fwd+bwd steps (default 3)
phu0ngng Jun 8, 2026
f04fb84
jax/ep: tighten sharding contract, drop helpers, route bwd through TE…
phu0ngng Jun 8, 2026
5b5e059
jax/ep: derive ep_size and num_ep_groups from active mesh in ep_boots…
phu0ngng Jun 9, 2026
6a7ecf0
examples/jax/ep: rename ep_handle to layer_cfg in ep_moe.py (matches …
phu0ngng Jun 10, 2026
8460faa
jax/ep: add primitive docstrings and silence missing-kwoa false posit…
phu0ngng Jun 10, 2026
50fa69d
jax/ep: apply black formatting (pre-commit hook output)
phu0ngng Jun 10, 2026
7c7d6f7
build_tools/jax: gate NCCL EP on NVTE_BUILD_WITH_NCCL_EP (default on)…
phu0ngng Jun 10, 2026
40f1f5e
jax/ep: collapse 5 FFI attr structs into single EpConfig
phu0ngng Jun 10, 2026
238b21a
jax/ep: dedup _ep_outer_axis, normalize _ep_spec_ok, unify outer_abst…
phu0ngng Jun 10, 2026
9176eed
jax/ep: apply clang-format and silence pylint unused-arg in lowering
phu0ngng Jun 10, 2026
e5a2072
[JAX] Resync onto upstream PR #3036, restore TE-EP-only MoE block
tdophung Jun 10, 2026
a14f20d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2026
fe4069b
tests/jax: trim TE-EP MoE suite (drop bootstrap, flax-wrapper, bias-z…
tdophung Jun 11, 2026
ff6cf3d
jax/router: fix two bwd custom_partitioning bugs (aux-loss rank, topk…
tdophung Jun 11, 2026
d3695db
jax/ep: skip size-1 dp/fsdp axis in _ep_outer_axis
tdophung Jun 11, 2026
a5d3f15
jax/flax: realign _MoEBlock with post-resync moe() signature
tdophung Jun 11, 2026
57c1615
jax/moe: plumb token_counts to grouped_gemm and zero 0-token wgrad sl…
tdophung Jun 11, 2026
6f1d3e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2026
5e524d0
jax/flax,tests: rename use_bias/use_expert_bias for symmetry (PR #3116)
tdophung Jun 11, 2026
3c24517
jax/moe: address PR #3116 review feedback (hardcode align + expand in…
tdophung Jun 11, 2026
fe44697
jax/moe: strip PR-response framing from comments; drop sparse_probs N…
tdophung Jun 12, 2026
dc55faa
jax/moe: drop fp32 island around silu+multiply (fwd, bwd, reference)
tdophung Jun 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ tensor_dumps/
artifacts/
.DS_Store
.claude/

# NCCL EP shared library staged by setup.py for wheel packaging.
/transformer_engine/libnccl_ep.so*
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "3rdparty/nccl"]
path = 3rdparty/nccl
url = https://github.com/NVIDIA/nccl.git
1 change: 1 addition & 0 deletions 3rdparty/nccl
Submodule nccl added at 808d24
3 changes: 3 additions & 0 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def setup_jax_extension(
if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
cxx_flags.append("-DNVTE_WITH_CUBLASMP")

if bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1"))):
cxx_flags.append("-DNVTE_WITH_NCCL_EP")

# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

Expand Down
308 changes: 308 additions & 0 deletions examples/jax/ep/bench/ep_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX EP perf bench — dispatch/combine (raw fwd + custom_vjp wrapper) on a 1DP x EP mesh.

One process per GPU; launch via run_ep_bench.sh. Each stage is jitted and
timed separately with NVTX ranges (prepare runs once outside the loop).
Rank-0 prints mean wall in us; nsys / --xplane attribute kernels per stage.
"""

import argparse
import os
import sys
import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from transformer_engine.jax.cpp_extensions import ep as tex_ep
from transformer_engine.jax.ep import EpLayerConfig, ep_bootstrap, ep_dispatch, ep_combine
from transformer_engine.jax.sharding import MeshResource, global_shard_guard


def _parse_args():
p = argparse.ArgumentParser(description="TE-JAX EP perf bench (dispatch_fwd + combine_fwd)")
p.add_argument("--coordinator-address", required=True)
p.add_argument("--process-id", type=int, required=True)
p.add_argument("--num-processes", type=int, required=True)
p.add_argument("--tokens-per-rank", type=int, default=8192)
p.add_argument("--hidden", type=int, default=7168)
p.add_argument("--top-k", type=int, default=8)
p.add_argument("--num-experts", type=int, default=256)
p.add_argument("--dp-size", type=int, default=1)
p.add_argument("--warmup", type=int, default=2)
p.add_argument("--iters", type=int, default=10)
p.add_argument(
"--max-num-sms",
type=int,
default=0,
help="Max SMs for dispatch / combine / preprocess kernels (0 = auto).",
)
p.add_argument(
"--mode-label",
default=None,
help="Optional label suffix for NVTX range names so nsys can partition kernels.",
)
p.add_argument(
"--second-step",
action="store_true",
help=(
"Time only the 2nd step (1 warmup iter, 1 timed iter). Use to isolate "
"JIT-cache-warm-but-no-steady-state-batching overhead from steady-state perf."
),
)
p.add_argument(
"--xplane",
default=None,
help="If set, jax.profiler dumps an XPlane trace into this dir (rank 0 only).",
)
return p.parse_args()


def _distributed_init(args):
jax.distributed.initialize(
coordinator_address=args.coordinator_address,
num_processes=args.num_processes,
process_id=args.process_id,
local_device_ids=[args.process_id],
)


def _build_mesh(args):
n = args.num_processes
assert n % args.dp_size == 0
ep = n // args.dp_size
devs = np.asarray(jax.devices()).reshape(args.dp_size, ep)
return Mesh(devs, ("dp", "ep")), ep


def _make_inputs(args, ep_size):
"""Round-robin routing, uniform top-k weights; each rank sees ``args.tokens_per_rank`` tokens."""
n = args.num_processes
T = args.tokens_per_rank
H = args.hidden
K = args.top_k
E = args.num_experts
del ep_size

topk_idx = np.empty((n * T, K), dtype=np.int32)
for t in range(n * T):
for k in range(K):
topk_idx[t, k] = (t * K + k) % E
topk_idx = jnp.asarray(topk_idx)
topk_w = jnp.full((n * T, K), 1.0 / K, dtype=jnp.float32)
tokens = jnp.asarray(
np.random.default_rng(0).standard_normal((n * T, H), dtype=np.float32) * 0.5,
dtype=jnp.bfloat16,
)
return tokens, topk_idx, topk_w


def main():
args = _parse_args()
_distributed_init(args)
mesh, ep_size = _build_mesh(args)
mr = MeshResource(dp_resource="dp", ep_resource="ep")
rank = args.process_id

local_experts = args.num_experts // ep_size
recv_capacity_per_rank = args.num_processes * args.tokens_per_rank * args.top_k // 2

if rank == 0:
print(
f"[ep_bench] world={args.num_processes} dp={args.dp_size} ep={ep_size}"
f" T={args.tokens_per_rank} H={args.hidden} K={args.top_k}"
f" E={args.num_experts} (local={local_experts}) recv_pr={recv_capacity_per_rank}"
+ (f" mode={args.mode_label}" if args.mode_label else ""),
flush=True,
)

nvtx_suffix = f"[{args.mode_label}]" if args.mode_label else ""

in_spec = PartitionSpec(("dp", "ep"), None)
ep_spec_3d = PartitionSpec(("dp", "ep"), None, None)
ep_spec_2d = PartitionSpec(("dp", "ep"), None)
out_spec = (("dp", "ep"), None)
T_global = args.num_processes * args.tokens_per_rank

with mesh, global_shard_guard(mr):
ep_bootstrap(
world_size=args.num_processes,
rank=rank,
num_experts=args.num_experts,
max_tokens_per_rank=args.tokens_per_rank,
recv_capacity_per_rank=recv_capacity_per_rank,
hidden_dim=args.hidden,
max_num_sms=args.max_num_sms,
)

tokens, topk_idx, topk_w = _make_inputs(args, ep_size)
idx_s = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(mesh, in_spec))
tok_s = jax.lax.with_sharding_constraint(tokens, NamedSharding(mesh, in_spec))
w_s = jax.lax.with_sharding_constraint(topk_w, NamedSharding(mesh, in_spec))

cfg = EpLayerConfig(top_k=args.top_k, dispatch_output_per_expert_alignment=16)

@jax.jit
def run_prepare(idx):
tc, hm = tex_ep.ep_prepare(cfg, idx)
return tc, hm

@jax.jit
def run_dispatch(hm, idx, toks, w):
recv_t, recv_w = tex_ep.ep_dispatch_fwd(cfg, hm, idx, toks, w, recv_capacity_per_rank)
recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(mesh, ep_spec_3d))
recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(mesh, ep_spec_2d))
return recv_t, recv_w

@jax.jit
def run_dispatch_vjp(idx, toks, w):
recv_t, recv_w, _hm, _tc = ep_dispatch(cfg, idx, toks, w, recv_capacity_per_rank)
recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(mesh, ep_spec_3d))
recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(mesh, ep_spec_2d))
return recv_t, recv_w

@jax.jit
def run_combine(hm, recv_t):
out = tex_ep.ep_combine_fwd(
cfg,
hm,
recv_t,
T_global,
out_partition_spec=out_spec,
)
return out

@jax.jit
def run_combine_vjp(hm, tc, recv_t):
# ep_combine is unweighted; bench feeds expert_out directly (caller
# would otherwise pre-multiply by recv_topk_weights + mask).
out = ep_combine(cfg, hm, tc, recv_t, T_global, out_sharding=out_spec)
return out

tc, handle_mem = run_prepare(idx_s)
tc.block_until_ready()
handle_mem.block_until_ready()

recv_t0, recv_w0 = run_dispatch(handle_mem, idx_s, tok_s, w_s)
recv_t0.block_until_ready()
recv_w0.block_until_ready()

warmup_n = 1 if args.second_step else args.warmup
iters_n = 1 if args.second_step else args.iters

for _ in range(warmup_n):
r, _rw = run_dispatch(handle_mem, idx_s, tok_s, w_s)
r.block_until_ready()
o = run_combine(handle_mem, r)
o.block_until_ready()
run_dispatch_vjp(idx_s, tok_s, w_s)[0].block_until_ready()
run_combine_vjp(handle_mem, tc, recv_t0).block_until_ready()

if args.xplane and rank == 0:
os.makedirs(args.xplane, exist_ok=True)
jax.profiler.start_trace(args.xplane)

try:
import nvtx as _nvtx

def _push(name):
_nvtx.push_range(message=name)

def _pop():
_nvtx.pop_range()

except ImportError:

def _push(name):
pass

def _pop():
pass

def _time_stage_wall_us(name, fn):
# First timed iter still carries an autotune outlier even after JIT
# warmup; run iters_n + 1, drop iter 0 from the average, and push
# the NVTX range AFTER iter 0 so nsys' nvtx_kern_sum excludes the
# outlier too.
total_ns = 0
counted = 0
for i in range(iters_n + 1):
if i == 1:
_push(f"{name}{nvtx_suffix}")
t0 = time.perf_counter_ns()
fn()
dt = time.perf_counter_ns() - t0
if i == 0:
continue
total_ns += dt
counted += 1
_pop()
return total_ns / 1e3 / counted

def _do_dispatch():
r, _ = run_dispatch(handle_mem, idx_s, tok_s, w_s)
r.block_until_ready()

def _do_dispatch_vjp():
r, _ = run_dispatch_vjp(idx_s, tok_s, w_s)
r.block_until_ready()

def _do_combine():
o = run_combine(handle_mem, recv_t0)
o.block_until_ready()

def _do_combine_vjp():
o = run_combine_vjp(handle_mem, tc, recv_t0)
o.block_until_ready()

d_wall_us = _time_stage_wall_us("dispatch_fwd", _do_dispatch)
dv_wall_us = _time_stage_wall_us("ep_dispatch_vjp", _do_dispatch_vjp)
c_wall_us = _time_stage_wall_us("combine_fwd", _do_combine)
cv_wall_us = _time_stage_wall_us("ep_combine_vjp", _do_combine_vjp)

if args.xplane and rank == 0:
jax.profiler.stop_trace()

if rank == 0:
label = f" [{args.mode_label}]" if args.mode_label else ""
print("", flush=True)
print(f"| stage | mean wall (us){label} |", flush=True)
print("|-------------------|---------------:|", flush=True)
print(f"| dispatch_fwd | {d_wall_us:14.1f} |", flush=True)
print(f"| ep_dispatch_vjp | {dv_wall_us:14.1f} |", flush=True)
print(f"| combine_fwd | {c_wall_us:14.1f} |", flush=True)
print(f"| ep_combine_vjp | {cv_wall_us:14.1f} |", flush=True)
print(f"| (dispatch vjp-fwd)| {dv_wall_us - d_wall_us:14.1f} |", flush=True)
print(f"| (combine vjp-fwd)| {cv_wall_us - c_wall_us:14.1f} |", flush=True)
print("", flush=True)
print(
"[ep_bench] kernel breakout: see nsys nvtx_kern_sum output below "
"(produced by run_ep_bench.sh --nsys).",
flush=True,
)

# Under nsys: force cudaDeviceReset() to drain CUPTI's in-process kernel
# records into the .nsys-rep, then os._exit to skip JAX's coord-service
# watchdog. The reset crashes during NCCL EP context teardown, so we only
# take this path when the launcher opts in via EP_BENCH_FLUSH_CUPTI=1.
if os.environ.get("EP_BENCH_FLUSH_CUPTI", "0") == "1":
try:
import ctypes

cudart = ctypes.CDLL("libcudart.so")
cudart.cudaDeviceSynchronize()
cudart.cudaDeviceReset()
except Exception:
pass
time.sleep(0.5)
sys.stdout.flush()
sys.stderr.flush()
os._exit(0)


if __name__ == "__main__":
main()
Loading
Loading