Skip to content

feat(amd): EP intra-node normal and low-latency kernels with mori shmem#164

Merged
XG-zheng merged 3 commits into
ByteDance-Seed:mainfrom
jhchouuu:mori_ep_intranode
Apr 22, 2026
Merged

feat(amd): EP intra-node normal and low-latency kernels with mori shmem#164
XG-zheng merged 3 commits into
ByteDance-Seed:mainfrom
jhchouuu:mori_ep_intranode

Conversation

@jhchouuu

@jhchouuu jhchouuu commented Mar 27, 2026

Copy link
Copy Markdown
Collaborator

Summary

  1. Implement EP intra-node dispatch/combine kernels using mori shmem P2P (putmem_signal_warp) on AMD MI325X
  2. Add Low Latency EP v1 (raw all-to-all) and v2 (online FP8 quant + combine with topk weighted reduce)
  3. Fix shfl_up/shfl_down_sync implementation and golden reference calculation in test_language_extra.py
  4. Fix mixed-bitwidth ld/st implementation and add kernel test coverage
  5. Update mori submodule to main with JIT bitcode compilation, replacing manual hipcc/llvm-link build
  6. Simplify build_mori_shmem.sh to use mori JIT (mori.ir.bitcode.find_bitcode())
  7. Add AlgoBW and BusBW metrics to EP A2A benchmark output
  8. Add CI tests for EP A2A (correctness + perf), LL v2 (correctness + perf M=64/128)

Co-authored-by: Wu, Yutong yutong.wu@amd.com

- Implement EP intra-node dispatch/combine kernels using mori shmem P2P (putmem_signal_warp) on AMD MI325X
- Add Low Latency EP v1 (raw all-to-all) and v2 (online FP8 quant + combine with topk weighted reduce)
- Fix shfl_up/shfl_down_sync implementation and golden reference calculation in test_language_extra.py
- Fix mixed-bitwidth ld/st implementation and add kernel test coverage
- Update mori submodule to main with JIT bitcode compilation, replacing manual hipcc/llvm-link build
- Simplify `build_mori_shmem.sh` to use mori JIT (`mori.ir.bitcode.find_bitcode()`)
- Add AlgoBW and BusBW metrics to EP A2A benchmark output
- Add CI tests for EP A2A (correctness + perf), LL v2 (correctness + perf M=64/128)

---------

Co-authored-by: Wu, Yutong <yutong.wu@amd.com>
Copilot AI review requested due to automatic review settings March 27, 2026 05:53
@jhchouuu

Copy link
Copy Markdown
Collaborator Author

Intranode Performance (MI325X, 8 GPU, N=7168, G=256, topk=8, bench_iters=10, dispatch_grid=512, combine_grid=304):

PT = PyTorch (all_to_all baseline), TD = Triton-dist, speedup = PT / TD

Default mode

M PT disp PT comb TD disp TD comb disp comb total
4096 2.19ms 3.31ms 1.80ms 1.92ms 1.22x 1.72x 1.48x
3780 1.99ms 3.05ms 1.62ms 1.52ms 1.23x 2.01x 1.60x
3206 2.06ms 3.12ms 1.66ms 1.59ms 1.25x 1.96x 1.59x
2638 2.06ms 3.13ms 1.65ms 1.70ms 1.24x 1.84x 1.55x
2395 1.93ms 3.04ms 1.55ms 1.52ms 1.25x 2.00x 1.62x
2264 1.80ms 2.83ms 1.47ms 1.37ms 1.22x 2.06x 1.63x

enable-local-combine mode

M PT disp PT comb TD disp TD comb disp comb total
4096 2.18ms 3.29ms 1.80ms 2.37ms 1.21x 1.38x 1.31x
3780 1.98ms 3.07ms 1.62ms 1.92ms 1.22x 1.59x 1.43x
3206 2.07ms 3.16ms 1.66ms 2.00ms 1.25x 1.58x 1.43x
2638 2.05ms 3.14ms 1.65ms 2.02ms 1.24x 1.55x 1.41x
2395 1.92ms 2.94ms 1.55ms 1.87ms 1.24x 1.58x 1.42x
2264 1.79ms 2.79ms 1.48ms 1.74ms 1.21x 1.60x 1.41x

with-scatter-indices mode

M PT disp PT comb TD disp TD comb disp comb total
4096 2.20ms 3.30ms 1.80ms 1.93ms 1.22x 1.71x 1.47x
3780 1.97ms 3.08ms 1.62ms 1.52ms 1.22x 2.03x 1.60x
3206 2.08ms 3.17ms 1.65ms 1.59ms 1.26x 2.00x 1.62x
2638 2.07ms 3.33ms 1.65ms 1.72ms 1.25x 1.94x 1.60x
2395 1.94ms 2.98ms 1.55ms 1.52ms 1.25x 1.96x 1.59x
2264 1.81ms 2.78ms 1.47ms 1.36ms 1.23x 2.04x 1.62x

@jhchouuu

Copy link
Copy Markdown
Collaborator Author

EP Low Latency v2 Performance (MI325X, 8 GPU, N=7168, G=256, topk=8, 100 iters)

M=64

Tokens PT dispatch PT combine PT total TD dispatch TD combine TD total disp comb total
64 0.692ms 0.526ms 1.218ms 0.103ms 0.421ms 0.525ms 6.69x 1.25x 2.32x
60 0.671ms 0.547ms 1.218ms 0.100ms 0.386ms 0.486ms 6.72x 1.42x 2.51x
58 0.674ms 0.564ms 1.238ms 0.097ms 0.365ms 0.462ms 6.95x 1.55x 2.68x
38 0.687ms 0.556ms 1.243ms 0.098ms 0.365ms 0.463ms 7.00x 1.52x 2.68x
52 0.678ms 0.529ms 1.208ms 0.098ms 0.375ms 0.473ms 6.91x 1.41x 2.55x
55 0.667ms 0.575ms 1.243ms 0.089ms 0.345ms 0.435ms 7.47x 1.67x 2.86x

M=128

Tokens PT dispatch PT combine PT total TD dispatch TD combine TD total disp comb total
128 0.677ms 0.555ms 1.232ms 0.153ms 0.662ms 0.816ms 4.41x 0.84x 1.51x
64 0.672ms 0.552ms 1.224ms 0.134ms 0.540ms 0.674ms 5.02x 1.02x 1.82x
120 0.671ms 0.585ms 1.256ms 0.149ms 0.604ms 0.752ms 4.52x 0.97x 1.67x
118 0.681ms 0.591ms 1.272ms 0.147ms 0.619ms 0.766ms 4.64x 0.95x 1.66x
89 0.683ms 0.602ms 1.285ms 0.145ms 0.605ms 0.750ms 4.70x 0.99x 1.71x
78 0.675ms 0.584ms 1.259ms 0.136ms 0.574ms 0.710ms 4.97x 1.02x 1.77x

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds AMD ROCm intra-node Expert Parallel (EP) all-to-all kernels (including low-latency variants) built on MoRI SHMEM, updates MoRI integration to rely on JIT bitcode discovery/compilation, and expands AMD test + CI coverage for correctness and performance.

Changes:

  • Add AMD EP intra-node all-to-all layer + kernels, plus low-latency EP v2 dispatch/combine implementation.
  • Switch MoRI SHMEM device bitcode handling to JIT-based discovery and update runtime lookup utilities.
  • Add/extend AMD tests (language extras, MoRI SHMEM API/BW, EP A2A + LL v2) and wire them into AMD CI.

Reviewed changes

Copilot reviewed 30 out of 30 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
scripts/launch_amd.sh Make Triton cache directory configurable via TRITON_CACHE_DIR.
scripts/build_mori_shmem.sh Build/install MoRI and obtain device bitcode via MoRI JIT find_bitcode().
python/triton_dist/utils.py Add MoRI SHMEM tensor/barrier helpers and JIT fallback for libdevice bitcode path.
python/triton_dist/test/common/test_language_extra.py Add mixed-bitwidth ld/st test coverage.
python/triton_dist/test/amd/test_mori_shmem_bw.py Adjust docstring to raw string form.
python/triton_dist/test/amd/test_mori_shmem_api.py Add put+signal+wait test path for MoRI SHMEM device API.
python/triton_dist/test/amd/test_language_extra.py Fix shuffle golden refs and improve failure diagnostics.
python/triton_dist/test/amd/test_ep_ll_a2a.py Add LL v2 correctness + perf driver.
python/triton_dist/test/amd/test_ep_a2a.py Add EP A2A correctness + perf driver and BW metrics.
python/triton_dist/test/amd/test_all_to_all.py Add all-to-all benchmark/verification driver for AMD.
python/triton_dist/test/amd/ep_a2a_utils.py Add FP8 quant/dequant + torch reference dispatch/combine for LL v2.
python/triton_dist/layers/amd/ep_ll_a2a_layer.py Add LL v2 layer wrapper around new dispatch/combine kernels.
python/triton_dist/layers/amd/ep_a2a_layer.py Add AMD intra-node EP A2A layer built on MoRI SHMEM symmetric buffers.
python/triton_dist/layers/amd/init.py Export AMD EP layers (best-effort imports for optional components).
python/triton_dist/language/extra/libshmem_device.py Extend MoRI SHMEM dispatch surface (putmem_signal qp_id, uint64 wait, cmp enums).
python/triton_dist/language/extra/hip/libmori_shmem_device.py Add MoRI SHMEM device externs (warp/block put, put+signal, barrier, wait shims).
python/triton_dist/language/extra/hip/language_extra.py Fix/load/store/atom symbol naming for mixed bitwidth; add AMD shuffle + warp atomic helper.
python/triton_dist/language/core.py Adjust extern elementwise arg checking/broadcasting behavior.
python/triton_dist/kernels/amd/low_latency_all_to_all_v2.py Implement LL v2 dispatch/combine kernels + context creation using MoRI SHMEM.
python/triton_dist/kernels/amd/low_latency_all_to_all.py Add LL all-to-all (v1-style) kernel and context using MoRI SHMEM.
python/triton_dist/kernels/amd/ep_a2a_intra_node.py Implement intra-node EP dispatch/combine kernels and split all-gather for AMD.
python/triton_dist/kernels/amd/ep_a2a.py Add AMD bincount helper and re-export intra-node EP kernels.
python/triton_dist/kernels/amd/common_ops.py Route barrier to mori_shmem vs rocshmem based on backend selection.
python/triton_dist/kernels/amd/init.py Export EP/LL kernels and make gemm helpers optional with warnings.
python/triton_dist/jit.py Only run MoRI SHMEM module init hook when the kernel module contains mori symbols.
python/triton_dist/amd_utils.py Add safer fallback if UUID→device mapping misses.
lib/Conversion/TritonDistributedToLLVM/AMD/BuiltinFuncToLLVMExt.cpp Update parsing to handle bitwidth-prefixed hip builtin names.
.gitmodules Update MoRI submodule tracking (remove custom branch pin).
.github/workflows/amd-ci.yml Add AMD CI coverage for EP A2A + LL v2 and expand MoRI SHMEM tests.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +34 to 54
try:
from .allgather_gemm import ag_gemm_intra_node, create_ag_gemm_intra_node_context
from .gemm_reduce_scatter import gemm_rs_intra_node, create_gemm_rs_intra_node_context
except ImportError as e:
import warnings
warnings.warn(f"allgather_gemm/gemm_reduce_scatter unavailable (pyrocshmem not installed): {e}")

__all__ = [
"ag_gemm_intra_node", "create_ag_gemm_intra_node_context", "gemm_rs_intra_node", "create_gemm_rs_intra_node_context"
"ag_gemm_intra_node",
"create_ag_gemm_intra_node_context",
"gemm_rs_intra_node",
"create_gemm_rs_intra_node_context",
"kernel_dispatch_token_intra_node",
"kernel_skipped_token_local_dispatch_intra_node",
"kernel_skipped_token_inplace_local_combine_intra_node",
"kernel_combine_token_intra_node",
"get_ag_splits_and_recv_offset_for_dispatch_intra_node",
"create_all_to_all_context",
"fast_all_to_all",
"all_to_all_post_process",
]

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module wraps allgather_gemm/gemm_reduce_scatter imports in a try/except, but __all__ always includes ag_gemm_intra_node, gemm_rs_intra_node, etc. If those optional imports fail, from triton_dist.kernels.amd import * will raise because the names won't exist. Consider only adding these symbols to __all__ when the import succeeds (or define stubs).

Copilot uses AI. Check for mistakes.
Comment thread scripts/launch_amd.sh
# export AMD_LOG_LEVEL=5 # for debug

mkdir -p triton_cache
mkdir -p ${TRITON_CACHE_DIR}

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mkdir -p ${TRITON_CACHE_DIR} is unquoted, so paths containing spaces or glob characters can break. Quote the variable (and consider using -- to guard against leading -).

Suggested change
mkdir -p ${TRITON_CACHE_DIR}
mkdir -p -- "${TRITON_CACHE_DIR}"

Copilot uses AI. Check for mistakes.
Comment on lines +582 to +583
"sync_grid",
"smid",

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__all__ contains duplicate entries ("sync_grid" and "smid" are listed twice). This is harmless at runtime but makes exports harder to audit; deduplicate the list.

Suggested change
"sync_grid",
"smid",

Copilot uses AI. Check for mistakes.
Comment on lines +517 to +525
# Cast scatter indices from uint64 to int32
scatter_idx_i32 = ep_a2a_layout_desc.token_dst_scatter_idx.to(torch.int32)
combine_grid = (min(ep_a2a_layout_desc.num_dispatch_token_cur_rank, self.combine_grid_size), )
kernel_combine_token_intra_node[combine_grid](
ep_a2a_layout_desc.num_dispatch_token_cur_rank,
input,
combine_intra_node_out_buf,
ep_a2a_layout_desc.topk_indices_tensor,
scatter_idx_i32,

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cast creates a new tensor when token_dst_scatter_idx was allocated as uint64 (the --with-scatter-indices off path). Since this runs on every combine(), it can add noticeable overhead for large max_tokens. Prefer allocating token_dst_scatter_idx as int32 in the dispatch path (or only casting when dtype != int32).

Suggested change
# Cast scatter indices from uint64 to int32
scatter_idx_i32 = ep_a2a_layout_desc.token_dst_scatter_idx.to(torch.int32)
combine_grid = (min(ep_a2a_layout_desc.num_dispatch_token_cur_rank, self.combine_grid_size), )
kernel_combine_token_intra_node[combine_grid](
ep_a2a_layout_desc.num_dispatch_token_cur_rank,
input,
combine_intra_node_out_buf,
ep_a2a_layout_desc.topk_indices_tensor,
scatter_idx_i32,
# Cast scatter indices from uint64 to int32 only if needed to avoid extra allocations
scatter_idx = ep_a2a_layout_desc.token_dst_scatter_idx
if scatter_idx.dtype != torch.int32:
scatter_idx = scatter_idx.to(torch.int32)
combine_grid = (min(ep_a2a_layout_desc.num_dispatch_token_cur_rank, self.combine_grid_size), )
kernel_combine_token_intra_node[combine_grid](
ep_a2a_layout_desc.num_dispatch_token_cur_rank,
input,
combine_intra_node_out_buf,
ep_a2a_layout_desc.topk_indices_tensor,
scatter_idx,

Copilot uses AI. Check for mistakes.

if ! cp -f "$TEMP_DIR/libmori_shmem_device.bc" "$DST_PATH/"; then
echo "Error: Mori bitcode copy failed." >&2
BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())" 2>&1 | tail -1)

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BC_PATH is captured via python3 ... 2>&1 | tail -1. This can mask the real Python error/traceback (and the pipeline exit status comes from tail, not python3), making failures harder to debug. Prefer capturing stdout directly (no tail) and failing fast on a non-zero Python exit code while printing the full error output.

Suggested change
BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())" 2>&1 | tail -1)
BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())") || {
echo "Error: JIT bitcode compilation failed (python3 returned non-zero exit status)." >&2
exit 1
}

Copilot uses AI. Check for mistakes.
torch.distributed.barrier()
self.dispatch_output_buf = mori_shmem_create_tensor([dispatch_recv_tokens, self.ep_config.hidden],
self.ep_config.token_dtype)
self.weight_recv_buf = mori_shmem_create_tensor([dispatch_recv_tokens, self.ep_config.topk],

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weight_recv_buf is allocated as 1-D in create() (one weight per dispatched token), but reallocated here as a 2-D [dispatch_recv_tokens, topk] tensor. This will break pointer arithmetic in kernel_dispatch_token_intra_node (which does weight_recv_buf + store_idx) and downstream copies. Reallocate weight_recv_buf with the same 1-D shape as the initial allocation.

Suggested change
self.weight_recv_buf = mori_shmem_create_tensor([dispatch_recv_tokens, self.ep_config.topk],
# Keep weight_recv_buf 1-D (one weight per dispatched token) to match initial allocation
# and kernel_dispatch_token_intra_node pointer arithmetic.
self.weight_recv_buf = mori_shmem_create_tensor([dispatch_recv_tokens],

Copilot uses AI. Check for mistakes.
Comment on lines +158 to +181
self.send_buf = mori_shmem_create_tensor((max_m, hidden), dtype)
self.recv_buf = mori_shmem_create_tensor((WORLD_SIZE * max_m * 2, hidden), dtype)
self.scale_send_buf = mori_shmem_create_tensor((max_m, ), scale_dtype)
self.scale_recv_buf = mori_shmem_create_tensor((WORLD_SIZE * max_m * 2, ), scale_dtype)
self.split_send_buf = mori_shmem_create_tensor((num_tot_experts, ), torch.int32)
self.split_recv_buf = mori_shmem_create_tensor((num_tot_experts * 2, ), torch.int32)
self.signal_buf = mori_shmem_create_tensor((WORLD_SIZE * 2, ), MORI_SHMEM_SIGNAL_DTYPE)

self.max_m = max_m
self.hidden = hidden
self.dtype = dtype
self.scale_dtype = scale_dtype
self.ele_size = dtype_size_in_bytes(self.dtype)
self.scale_ele_size = dtype_size_in_bytes(self.scale_dtype)

self.num_tot_experts = num_tot_experts
self.experts_per_rank = experts_per_rank

self.WORLD_SIZE = WORLD_SIZE
self.rank = rank

# start from 1, becase the initial values of signal buffer is 0
self.call_count = 1
self.MOD_VALUE = 1000000

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

signal_buf is used for signal_wait_until(..., call_count) and call_count starts from 1 assuming the signal buffer is initialized to 0, but the buffer is never zero-initialized after allocation. If mori_shmem_create_tensor returns uninitialized memory, this can lead to races or deadlocks. Initialize self.signal_buf to 0 (and consider a barrier after init) before the first kernel launch.

Copilot uses AI. Check for mistakes.
Copilot AI review requested due to automatic review settings April 13, 2026 02:48

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 30 out of 30 changed files in this pull request and generated 4 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +512 to +519
def __shfl_up_sync_i32(value, delta):
"""Shuffle up: each lane reads from (laneid - delta), clamped to 0."""
_lid = laneid()
src_lane = _lid - tl.cast(delta, tl.int32)
if src_lane < 0:
src_lane = _lid
byte_offset = src_lane * 4
return _ds_bpermute_b32(tl.cast(value, tl.int32), byte_offset)
Comment on lines +523 to +531
def __shfl_down_sync_i32(value, delta):
"""Shuffle down: each lane reads from (laneid + delta), clamped to 63."""
WARP_SIZE: tl.constexpr = 64
_lid = laneid()
src_lane = _lid + tl.cast(delta, tl.int32)
if src_lane >= WARP_SIZE:
src_lane = _lid
byte_offset = src_lane * 4
return _ds_bpermute_b32(tl.cast(value, tl.int32), byte_offset)
Comment on lines +558 to +561
_laneid = laneid()
x = tl.cast(0, barrier_ptr.dtype.element_ty)
if _laneid == 0:
x = atomic_add(barrier_ptr, value, scope, semantic)
Comment on lines +34 to +46
try:
from .allgather_gemm import ag_gemm_intra_node, create_ag_gemm_intra_node_context
from .gemm_reduce_scatter import gemm_rs_intra_node, create_gemm_rs_intra_node_context
except ImportError as e:
import warnings
warnings.warn(f"allgather_gemm/gemm_reduce_scatter unavailable (pyrocshmem not installed): {e}")

__all__ = [
"ag_gemm_intra_node", "create_ag_gemm_intra_node_context", "gemm_rs_intra_node", "create_gemm_rs_intra_node_context"
"ag_gemm_intra_node",
"create_ag_gemm_intra_node_context",
"gemm_rs_intra_node",
"create_gemm_rs_intra_node_context",
"kernel_dispatch_token_intra_node",
@@ -0,0 +1,233 @@
################################################################################

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks similar to ep_utils on the nvidia side. can we directly reuse the torch ref for AMD?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will attempt and complete this task in the subsequent steps.


if ! cp -f "$TEMP_DIR/libmori_shmem_device.bc" "$DST_PATH/"; then
echo "Error: Mori bitcode copy failed." >&2
BC_PATH=$(python3 -c "from mori.ir.bitcode import find_bitcode; print(find_bitcode())" 2>&1 | tail -1)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mori already support the jit bitcode feature?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, mori now support jit bitcode compile

@XG-zheng XG-zheng merged commit 2b4c24b into ByteDance-Seed:main Apr 22, 2026
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants