feat(amd): EP intra-node normal and low-latency kernels with mori shmem #164
- Implement EP intra-node dispatch/combine kernels using mori shmem P2P (putmem_signal_warp) on AMD MI325X
- Add Low Latency EP v1 (raw all-to-all) and v2 (online FP8 quant + combine with topk weighted reduce)
- Fix shfl_up/shfl_down_sync implementation and golden reference calculation in test_language_extra.py
- Fix mixed-bitwidth ld/st implementation and add kernel test coverage
- Update mori submodule to main with JIT bitcode compilation, replacing manual hipcc/llvm-link build
- Simplify `build_mori_shmem.sh` to use mori JIT (`mori.ir.bitcode.find_bitcode()`)
- Add AlgoBW and BusBW metrics to EP A2A benchmark output
- Add CI tests for EP A2A (correctness + perf), LL v2 (correctness + perf M=64/128)

Co-authored-by: Wu, Yutong <yutong.wu@amd.com>
Intranode Performance (MI325X, 8 GPU, N=7168, G=256, topk=8, bench_iters=10, dispatch_grid=512, combine_grid=304)
PT = PyTorch (all_to_all baseline), TD = Triton-dist, speedup = PT / TD
[Result tables not recovered: default mode, enable-local-combine mode, with-scatter-indices mode]

EP Low Latency v2 Performance (MI325X, 8 GPU, N=7168, G=256, topk=8, 100 iters)
[Result tables not recovered: M=64, M=128]
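The benchmark defines speedup as PT / TD, and the PR adds AlgoBW and BusBW to the EP A2A output. The formulas the benchmark actually uses are not shown in this thread. As a rough sketch, assuming the nccl-tests convention for all-to-all (bus bandwidth is algorithm bandwidth scaled by (n - 1) / n), the three metrics could be computed as below. The function names are hypothetical.

```python
def speedup(pt_ms: float, td_ms: float) -> float:
    """Speedup of Triton-dist over the PyTorch baseline (PT / TD)."""
    return pt_ms / td_ms


def algo_bw_gbps(bytes_per_rank: int, elapsed_ms: float) -> float:
    """Algorithm bandwidth in GB/s: payload bytes divided by elapsed time."""
    return bytes_per_rank / (elapsed_ms * 1e-3) / 1e9


def bus_bw_gbps(algo_bw: float, world_size: int) -> float:
    """Bus bandwidth, ASSUMING the nccl-tests all-to-all factor (n - 1) / n."""
    return algo_bw * (world_size - 1) / world_size
```

With 8 ranks the assumed factor is 7/8, so BusBW sits slightly below AlgoBW; the gap reflects the share of the payload that stays on the local rank.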
Pull request overview
This PR adds AMD ROCm intra-node Expert Parallel (EP) all-to-all kernels (including low-latency variants) built on MoRI SHMEM, updates MoRI integration to rely on JIT bitcode discovery/compilation, and expands AMD test + CI coverage for correctness and performance.
Changes:
- Add AMD EP intra-node all-to-all layer + kernels, plus low-latency EP v2 dispatch/combine implementation.
- Switch MoRI SHMEM device bitcode handling to JIT-based discovery and update runtime lookup utilities.
- Add/extend AMD tests (language extras, MoRI SHMEM API/BW, EP A2A + LL v2) and wire them into AMD CI.
Reviewed changes
Copilot reviewed 30 out of 30 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| scripts/launch_amd.sh | Make Triton cache directory configurable via TRITON_CACHE_DIR. |
| scripts/build_mori_shmem.sh | Build/install MoRI and obtain device bitcode via MoRI JIT find_bitcode(). |
| python/triton_dist/utils.py | Add MoRI SHMEM tensor/barrier helpers and JIT fallback for libdevice bitcode path. |
| python/triton_dist/test/common/test_language_extra.py | Add mixed-bitwidth ld/st test coverage. |
| python/triton_dist/test/amd/test_mori_shmem_bw.py | Adjust docstring to raw string form. |
| python/triton_dist/test/amd/test_mori_shmem_api.py | Add put+signal+wait test path for MoRI SHMEM device API. |
| python/triton_dist/test/amd/test_language_extra.py | Fix shuffle golden refs and improve failure diagnostics. |
| python/triton_dist/test/amd/test_ep_ll_a2a.py | Add LL v2 correctness + perf driver. |
| python/triton_dist/test/amd/test_ep_a2a.py | Add EP A2A correctness + perf driver and BW metrics. |
| python/triton_dist/test/amd/test_all_to_all.py | Add all-to-all benchmark/verification driver for AMD. |
| python/triton_dist/test/amd/ep_a2a_utils.py | Add FP8 quant/dequant + torch reference dispatch/combine for LL v2. |
| python/triton_dist/layers/amd/ep_ll_a2a_layer.py | Add LL v2 layer wrapper around new dispatch/combine kernels. |
| python/triton_dist/layers/amd/ep_a2a_layer.py | Add AMD intra-node EP A2A layer built on MoRI SHMEM symmetric buffers. |
| python/triton_dist/layers/amd/__init__.py | Export AMD EP layers (best-effort imports for optional components). |
| python/triton_dist/language/extra/libshmem_device.py | Extend MoRI SHMEM dispatch surface (putmem_signal qp_id, uint64 wait, cmp enums). |
| python/triton_dist/language/extra/hip/libmori_shmem_device.py | Add MoRI SHMEM device externs (warp/block put, put+signal, barrier, wait shims). |
| python/triton_dist/language/extra/hip/language_extra.py | Fix load/store/atom symbol naming for mixed bitwidth; add AMD shuffle + warp atomic helper. |
| python/triton_dist/language/core.py | Adjust extern elementwise arg checking/broadcasting behavior. |
| python/triton_dist/kernels/amd/low_latency_all_to_all_v2.py | Implement LL v2 dispatch/combine kernels + context creation using MoRI SHMEM. |
| python/triton_dist/kernels/amd/low_latency_all_to_all.py | Add LL all-to-all (v1-style) kernel and context using MoRI SHMEM. |
| python/triton_dist/kernels/amd/ep_a2a_intra_node.py | Implement intra-node EP dispatch/combine kernels and split all-gather for AMD. |
| python/triton_dist/kernels/amd/ep_a2a.py | Add AMD bincount helper and re-export intra-node EP kernels. |
| python/triton_dist/kernels/amd/common_ops.py | Route barrier to mori_shmem vs rocshmem based on backend selection. |
| python/triton_dist/kernels/amd/__init__.py | Export EP/LL kernels and make gemm helpers optional with warnings. |
| python/triton_dist/jit.py | Only run MoRI SHMEM module init hook when the kernel module contains mori symbols. |
| python/triton_dist/amd_utils.py | Add safer fallback if UUID→device mapping misses. |
| lib/Conversion/TritonDistributedToLLVM/AMD/BuiltinFuncToLLVMExt.cpp | Update parsing to handle bitwidth-prefixed hip builtin names. |
| .gitmodules | Update MoRI submodule tracking (remove custom branch pin). |
| .github/workflows/amd-ci.yml | Add AMD CI coverage for EP A2A + LL v2 and expand MoRI SHMEM tests. |
| try: | ||
| from .allgather_gemm import ag_gemm_intra_node, create_ag_gemm_intra_node_context | ||
| from .gemm_reduce_scatter import gemm_rs_intra_node, create_gemm_rs_intra_node_context | ||
| except ImportError as e: | ||
| import warnings | ||
| warnings.warn(f"allgather_gemm/gemm_reduce_scatter unavailable (pyrocshmem not installed): {e}") | ||
|
|
||
| __all__ = [ | ||
| "ag_gemm_intra_node", "create_ag_gemm_intra_node_context", "gemm_rs_intra_node", "create_gemm_rs_intra_node_context" | ||
| "ag_gemm_intra_node", | ||
| "create_ag_gemm_intra_node_context", | ||
| "gemm_rs_intra_node", | ||
| "create_gemm_rs_intra_node_context", | ||
| "kernel_dispatch_token_intra_node", | ||
| "kernel_skipped_token_local_dispatch_intra_node", | ||
| "kernel_skipped_token_inplace_local_combine_intra_node", | ||
| "kernel_combine_token_intra_node", | ||
| "get_ag_splits_and_recv_offset_for_dispatch_intra_node", | ||
| "create_all_to_all_context", | ||
| "fast_all_to_all", | ||
| "all_to_all_post_process", | ||
| ] |
The module wraps allgather_gemm/gemm_reduce_scatter imports in a try/except, but __all__ always includes ag_gemm_intra_node, gemm_rs_intra_node, etc. If those optional imports fail, from triton_dist.kernels.amd import * will raise because the names won't exist. Consider only adding these symbols to __all__ when the import succeeds (or define stubs).
| # export AMD_LOG_LEVEL=5 # for debug | ||
|
|
||
| mkdir -p triton_cache | ||
| mkdir -p ${TRITON_CACHE_DIR} |
mkdir -p ${TRITON_CACHE_DIR} is unquoted, so paths containing spaces or glob characters can break. Quote the variable (and consider using -- to guard against leading -).
| mkdir -p ${TRITON_CACHE_DIR} | |
| mkdir -p -- "${TRITON_CACHE_DIR}" |
| "sync_grid", | ||
| "smid", |
__all__ contains duplicate entries ("sync_grid" and "smid" are listed twice). This is harmless at runtime but makes exports harder to audit; deduplicate the list.
| "sync_grid", | |
| "smid", |
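Duplicates like these are easy to catch mechanically. A minimal sketch of a check that could run in a unit test follows; the helper names and the sample list are hypothetical.

```python
from collections import Counter


def find_duplicate_exports(names):
    """Return names that appear more than once, in first-seen order."""
    counts = Counter(names)
    return [name for name in dict.fromkeys(names) if counts[name] > 1]


def dedupe_exports(names):
    """Drop repeated names while preserving the original order."""
    return list(dict.fromkeys(names))


# Hypothetical export list with the two repeated entries from the comment.
exports = ["laneid", "sync_grid", "smid", "sync_grid", "smid"]
duplicates = find_duplicate_exports(exports)
cleaned = dedupe_exports(exports)
```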
| # Cast scatter indices from uint64 to int32 | ||
| scatter_idx_i32 = ep_a2a_layout_desc.token_dst_scatter_idx.to(torch.int32) | ||
| combine_grid = (min(ep_a2a_layout_desc.num_dispatch_token_cur_rank, self.combine_grid_size), ) | ||
| kernel_combine_token_intra_node[combine_grid]( | ||
| ep_a2a_layout_desc.num_dispatch_token_cur_rank, | ||
| input, | ||
| combine_intra_node_out_buf, | ||
| ep_a2a_layout_desc.topk_indices_tensor, | ||
| scatter_idx_i32, |
This cast creates a new tensor when token_dst_scatter_idx was allocated as uint64 (the --with-scatter-indices off path). Since this runs on every combine(), it can add noticeable overhead for large max_tokens. Prefer allocating token_dst_scatter_idx as int32 in the dispatch path (or only casting when dtype != int32).
| # Cast scatter indices from uint64 to int32 | |
| scatter_idx_i32 = ep_a2a_layout_desc.token_dst_scatter_idx.to(torch.int32) | |
| combine_grid = (min(ep_a2a_layout_desc.num_dispatch_token_cur_rank, self.combine_grid_size), ) | |
| kernel_combine_token_intra_node[combine_grid]( | |
| ep_a2a_layout_desc.num_dispatch_token_cur_rank, | |
| input, | |
| combine_intra_node_out_buf, | |
| ep_a2a_layout_desc.topk_indices_tensor, | |
| scatter_idx_i32, | |
| # Cast scatter indices from uint64 to int32 only if needed to avoid extra allocations | |
| scatter_idx = ep_a2a_layout_desc.token_dst_scatter_idx | |
| if scatter_idx.dtype != torch.int32: | |
| scatter_idx = scatter_idx.to(torch.int32) | |
| combine_grid = (min(ep_a2a_layout_desc.num_dispatch_token_cur_rank, self.combine_grid_size), ) | |
| kernel_combine_token_intra_node[combine_grid]( | |
| ep_a2a_layout_desc.num_dispatch_token_cur_rank, | |
| input, | |
| combine_intra_node_out_buf, | |
| ep_a2a_layout_desc.topk_indices_tensor, | |
| scatter_idx, |
|
|
||
| if ! cp -f "$TEMP_DIR/libmori_shmem_device.bc" "$DST_PATH/"; then | ||
| echo "Error: Mori bitcode copy failed." >&2 | ||
| BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())" 2>&1 | tail -1) |
BC_PATH is captured via python3 ... 2>&1 | tail -1. This can mask the real Python error/traceback (and the pipeline exit status comes from tail, not python3), making failures harder to debug. Prefer capturing stdout directly (no tail) and failing fast on a non-zero Python exit code while printing the full error output.
| BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())" 2>&1 | tail -1) | |
| BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())") || { | |
| echo "Error: JIT bitcode compilation failed (python3 returned non-zero exit status)." >&2 | |
| exit 1 | |
| } |
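The same idea (keep stdout and stderr separate, and fail on a non-zero exit) can be sketched in Python with `subprocess.run`. This is an illustration of the pattern only: `run_and_capture_stdout` is a hypothetical helper, and in the build script the code string would be the `find_bitcode()` one-liner from the diff.

```python
import subprocess
import sys


def run_and_capture_stdout(code: str) -> str:
    """Run `python -c code`, return its stdout, and raise on a non-zero exit.

    stderr is not merged into stdout, so a traceback can never be mistaken
    for the result; it is attached to the raised exception instead.
    """
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"child exited with {result.returncode}:\n{result.stderr}"
        )
    return result.stdout.strip()
```

Because the exit status is checked directly, there is no pipeline stage such as `tail` that can hide a failing child process.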
| torch.distributed.barrier() | ||
| self.dispatch_output_buf = mori_shmem_create_tensor([dispatch_recv_tokens, self.ep_config.hidden], | ||
| self.ep_config.token_dtype) | ||
| self.weight_recv_buf = mori_shmem_create_tensor([dispatch_recv_tokens, self.ep_config.topk], |
weight_recv_buf is allocated as 1-D in create() (one weight per dispatched token), but reallocated here as a 2-D [dispatch_recv_tokens, topk] tensor. This will break pointer arithmetic in kernel_dispatch_token_intra_node (which does weight_recv_buf + store_idx) and downstream copies. Reallocate weight_recv_buf with the same 1-D shape as the initial allocation.
| self.weight_recv_buf = mori_shmem_create_tensor([dispatch_recv_tokens, self.ep_config.topk], | |
| # Keep weight_recv_buf 1-D (one weight per dispatched token) to match initial allocation | |
| # and kernel_dispatch_token_intra_node pointer arithmetic. | |
| self.weight_recv_buf = mori_shmem_create_tensor([dispatch_recv_tokens], |
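To see why the shape matters, consider a writer that stores one weight per token at a flat offset (modelling `weight_recv_buf + store_idx`) and a reader that interprets the same storage as `[num_tokens, topk]`. The sketch below uses plain Python lists rather than symmetric buffers, and the sizes and helper name are hypothetical.

```python
def flat_index_2d(row: int, col: int, num_cols: int) -> int:
    """Offset of element [row, col] in a contiguous row-major 2-D buffer."""
    return row * num_cols + col


num_tokens, topk = 4, 2
weights = [0.1, 0.2, 0.3, 0.4]  # one weight per dispatched token

# Writer: models `weight_recv_buf + store_idx`, i.e. a flat 1-D offset.
storage = [0.0] * (num_tokens * topk)
for store_idx, w in enumerate(weights):
    storage[store_idx] = w

# Reader A treats the buffer as 1-D [num_tokens]: indices line up.
read_1d = [storage[i] for i in range(num_tokens)]

# Reader B treats the same storage as 2-D [num_tokens, topk], reads column 0.
read_2d_col0 = [storage[flat_index_2d(i, 0, topk)] for i in range(num_tokens)]
```

Reader A recovers the weights in order, while reader B strides by `topk` and picks up the wrong elements, which is the kind of mismatch the comment warns about for downstream copies.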
| self.send_buf = mori_shmem_create_tensor((max_m, hidden), dtype) | ||
| self.recv_buf = mori_shmem_create_tensor((WORLD_SIZE * max_m * 2, hidden), dtype) | ||
| self.scale_send_buf = mori_shmem_create_tensor((max_m, ), scale_dtype) | ||
| self.scale_recv_buf = mori_shmem_create_tensor((WORLD_SIZE * max_m * 2, ), scale_dtype) | ||
| self.split_send_buf = mori_shmem_create_tensor((num_tot_experts, ), torch.int32) | ||
| self.split_recv_buf = mori_shmem_create_tensor((num_tot_experts * 2, ), torch.int32) | ||
| self.signal_buf = mori_shmem_create_tensor((WORLD_SIZE * 2, ), MORI_SHMEM_SIGNAL_DTYPE) | ||
|
|
||
| self.max_m = max_m | ||
| self.hidden = hidden | ||
| self.dtype = dtype | ||
| self.scale_dtype = scale_dtype | ||
| self.ele_size = dtype_size_in_bytes(self.dtype) | ||
| self.scale_ele_size = dtype_size_in_bytes(self.scale_dtype) | ||
|
|
||
| self.num_tot_experts = num_tot_experts | ||
| self.experts_per_rank = experts_per_rank | ||
|
|
||
| self.WORLD_SIZE = WORLD_SIZE | ||
| self.rank = rank | ||
|
|
||
| # start from 1, because the initial value of the signal buffer is 0 | ||
| self.call_count = 1 | ||
| self.MOD_VALUE = 1000000 |
signal_buf is used for signal_wait_until(..., call_count) and call_count starts from 1 assuming the signal buffer is initialized to 0, but the buffer is never zero-initialized after allocation. If mori_shmem_create_tensor returns uninitialized memory, this can lead to races or deadlocks. Initialize self.signal_buf to 0 (and consider a barrier after init) before the first kernel launch.
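The hazard can be shown with a small host-side model. The sketch assumes the wait is an equality comparison against `call_count` and that a peer's put-with-signal overwrites the slot; both are assumptions for illustration, and the function names are hypothetical.

```python
def wait_would_pass(signal_value: int, expected: int) -> bool:
    """Models a wait-until-equal check on one signal slot (assumed semantics)."""
    return signal_value == expected


def simulate_first_call(initial_signal: int, peer_has_written: bool) -> bool:
    """Return True if the receiver's wait passes on the first call."""
    call_count = 1  # the context starts counting from 1
    signal = initial_signal
    if peer_has_written:
        signal = call_count  # assumed: the peer's signal sets the slot
    return wait_would_pass(signal, call_count)


# Zero-initialized slot: the wait blocks until the peer has written.
zeroed_blocks = not simulate_first_call(0, peer_has_written=False)

# Stale slot that happens to hold 1: the wait passes before any data arrives.
stale_passes_early = simulate_first_call(1, peer_has_written=False)
```

Zeroing the buffer once after allocation, followed by a barrier so that no peer signals before the zeroing completes, removes the second case.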
# Conflicts: # 3rdparty/mori
… nested submodules
Pull request overview
Copilot reviewed 30 out of 30 changed files in this pull request and generated 4 comments.
| def __shfl_up_sync_i32(value, delta): | ||
| """Shuffle up: each lane reads from (laneid - delta); lanes whose source would be below lane 0 keep their own value.""" | ||
| _lid = laneid() | ||
| src_lane = _lid - tl.cast(delta, tl.int32) | ||
| if src_lane < 0: | ||
| src_lane = _lid | ||
| byte_offset = src_lane * 4 | ||
| return _ds_bpermute_b32(tl.cast(value, tl.int32), byte_offset) |
| def __shfl_down_sync_i32(value, delta): | ||
| """Shuffle down: each lane reads from (laneid + delta); lanes whose source would be past lane 63 keep their own value.""" | ||
| WARP_SIZE: tl.constexpr = 64 | ||
| _lid = laneid() | ||
| src_lane = _lid + tl.cast(delta, tl.int32) | ||
| if src_lane >= WARP_SIZE: | ||
| src_lane = _lid | ||
| byte_offset = src_lane * 4 | ||
| return _ds_bpermute_b32(tl.cast(value, tl.int32), byte_offset) |
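Both helpers fall back to the lane's own value when the source lane is out of range (the `src_lane = _lid` branch). A host-side golden reference with the same behaviour might look like the sketch below; the function names are hypothetical and this is not the reference used in test_language_extra.py.

```python
WARP_SIZE = 64  # wavefront size used by the shuffle-down snippet


def shfl_up_ref(values, delta):
    """Lane i reads lane i - delta; lanes with no such source keep their value."""
    return [values[i - delta] if i - delta >= 0 else values[i]
            for i in range(len(values))]


def shfl_down_ref(values, delta):
    """Lane i reads lane i + delta; lanes with no such source keep their value."""
    n = len(values)
    return [values[i + delta] if i + delta < n else values[i]
            for i in range(n)]


lanes = list(range(WARP_SIZE))
up_by_one = shfl_up_ref(lanes, 1)
down_by_one = shfl_down_ref(lanes, 1)
```

With `delta = 1`, lane 0 keeps its own value on the way up and lane 63 keeps its own value on the way down, which is the edge case a golden reference has to reproduce.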
| _laneid = laneid() | ||
| x = tl.cast(0, barrier_ptr.dtype.element_ty) | ||
| if _laneid == 0: | ||
| x = atomic_add(barrier_ptr, value, scope, semantic) |
| @@ -0,0 +1,233 @@ | |||
| ################################################################################ | |||
This looks similar to ep_utils on the NVIDIA side. Can we directly reuse the torch reference for AMD?
We will try to do this in a follow-up.
|
|
||
| if ! cp -f "$TEMP_DIR/libmori_shmem_device.bc" "$DST_PATH/"; then | ||
| echo "Error: Mori bitcode copy failed." >&2 | ||
| BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())" 2>&1 | tail -1) |
Does mori already support the JIT bitcode feature?
Yes, mori now supports JIT bitcode compilation.