Skip to content

Commit 2b4c24b

Browse files
jhchouuu and isytwu authored
feat(amd): EP intra-node normal and low-latency kernels with mori shmem (#164)
* feat(amd): EP intra-node normal and low-latency kernels with mori shmem
  - Implement EP intra-node dispatch/combine kernels using mori shmem P2P (putmem_signal_warp) on AMD MI325X
  - Add Low Latency EP v1 (raw all-to-all) and v2 (online FP8 quant + combine with topk weighted reduce)
  - Fix shfl_up/shfl_down_sync implementation and golden reference calculation in test_language_extra.py
  - Fix mixed-bitwidth ld/st implementation and add kernel test coverage
  - Update mori submodule to main with JIT bitcode compilation, replacing manual hipcc/llvm-link build
  - Simplify `build_mori_shmem.sh` to use mori JIT (`mori.ir.bitcode.find_bitcode()`)
  - Add AlgoBW and BusBW metrics to EP A2A benchmark output
  - Add CI tests for EP A2A (correctness + perf), LL v2 (correctness + perf M=64/128)
  ---------
  Co-authored-by: Wu, Yutong <yutong.wu@amd.com>
* fix(ci): use non-recursive submodule checkout to avoid pulling mori's nested submodules
  ---------
  Co-authored-by: Wu, Yutong <yutong.wu@amd.com>
1 parent bec05d7 commit 2b4c24b

30 files changed

Lines changed: 4607 additions & 209 deletions

.github/workflows/amd-ci.yml

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
- name: Checkout
3333
uses: actions/checkout@v4
3434
with:
35-
submodules: 'recursive'
35+
submodules: 'true'
3636
- name: Build rocmshmem bind
3737
run: |
3838
bash ./shmem/rocshmem_bind/build.sh
@@ -78,7 +78,7 @@ jobs:
7878
- name: Checkout
7979
uses: actions/checkout@v4
8080
with:
81-
submodules: 'recursive'
81+
submodules: 'true'
8282
- name: Build rocmshmem bind
8383
run: |
8484
bash ./shmem/rocshmem_bind/build.sh
@@ -124,14 +124,26 @@ jobs:
124124
- name: Checkout
125125
uses: actions/checkout@v4
126126
with:
127-
submodules: 'recursive'
127+
submodules: 'true'
128128
- name: Build triton-distributed
129129
run: |
130130
pip3 install -e python --verbose --no-build-isolation --use-pep517
131131
- name: Build mori shmem
132132
run: |
133-
bash ./scripts/build_mori_shmem.sh
134-
- name: Mori SHMEM API tests
133+
bash ./scripts/build_mori_shmem.sh
134+
- name: MoRI SHMEM API tests
135135
run: |
136136
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_mori_shmem_api.py
137-
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_mori_shmem_bw.py
137+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_mori_shmem_bw.py
138+
- name: EP A2A intra-node tests
139+
run: |
140+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ep_a2a.py --check
141+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ep_a2a.py --check --with-scatter-indices
142+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ep_a2a.py --check --enable-local-combine
143+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ep_a2a.py --rounds 3 --bench_iters 10
144+
- name: EP Low Latency v2 tests
145+
run: |
146+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ep_ll_a2a.py --check -M 64 -N 7168 -G 256 --topk 8
147+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ep_ll_a2a.py -M 64 -N 7168 -G 256 --topk 8 --rounds 3
148+
bash ./scripts/launch_amd.sh ./python/triton_dist/test/amd/test_ep_ll_a2a.py -M 128 -N 7168 -G 256 --topk 8 --rounds 3
149+

.gitmodules

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
[submodule "3rdparty/mori"]
55
path = 3rdparty/mori
66
url = https://github.com/ROCm/mori.git
7-
branch = jiahzhou/triton_dis_support
87
[submodule "3rdparty/cutlass"]
98
path = 3rdparty/cutlass
109
url = https://github.com/NVIDIA/cutlass.git

3rdparty/mori

Submodule mori updated 380 files

lib/Conversion/TritonDistributedToLLVM/AMD/BuiltinFuncToLLVMExt.cpp

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
118118
llvm_unreachable("unknown scope string");
119119
};
120120

121+
auto skipBitwidthPrefix =
122+
[](const SmallVector<StringRef> &parts) -> size_t {
123+
if (!parts.empty() &&
124+
llvm::all_of(parts[0], [](char c) { return std::isdigit(c); }))
125+
return 1;
126+
return 0;
127+
};
128+
121129
auto operands = callOp.getOperands();
122130
auto result = callOp.getResult();
123131

@@ -193,10 +201,12 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
193201
if (auto maybeParts =
194202
matchPrefixAndSplitRemainder(calleeName, "__triton_hip_load_")) {
195203
auto parts = maybeParts.value();
196-
assert(parts.size() == 2 &&
197-
"expected load function to have 2 parts after prefix");
198-
LLVM::AtomicOrdering memOrder = strToMemoryOrder(parts[0]);
199-
auto scopeStr = strToScope(parts[1]);
204+
size_t idx = skipBitwidthPrefix(parts);
205+
assert(parts.size() - idx == 2 &&
206+
"expected load function to have 2 parts (memOrder, scope) after "
207+
"optional bitwidth prefix");
208+
LLVM::AtomicOrdering memOrder = strToMemoryOrder(parts[idx]);
209+
auto scopeStr = strToScope(parts[idx + 1]);
200210
assert(operands.size() == 1 && "expected load to have 1 operand");
201211

202212
replacementOp = buildAtomicLoad(operands[0], memOrder, scopeStr);
@@ -206,10 +216,12 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
206216
else if (auto maybeParts = matchPrefixAndSplitRemainder(
207217
calleeName, "__triton_hip_store_")) {
208218
auto parts = maybeParts.value();
209-
assert(parts.size() == 2 &&
210-
"expected store function to have 2 parts after prefix");
211-
LLVM::AtomicOrdering memOrder = strToMemoryOrder(parts[0]);
212-
auto scopeStr = strToScope(parts[1]);
219+
size_t idx = skipBitwidthPrefix(parts);
220+
assert(parts.size() - idx == 2 &&
221+
"expected store function to have 2 parts (memOrder, scope) after "
222+
"optional bitwidth prefix");
223+
LLVM::AtomicOrdering memOrder = strToMemoryOrder(parts[idx]);
224+
auto scopeStr = strToScope(parts[idx + 1]);
213225
assert(operands.size() == 2 && "expected store to have 2 operands");
214226
buildAtomicStore(operands[1], operands[0], memOrder, scopeStr);
215227
rewriter.eraseOp(callOp);
@@ -220,11 +232,13 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
220232
else if (auto maybeParts = matchPrefixAndSplitRemainder(
221233
calleeName, "__triton_hip_atom_add_")) {
222234
auto parts = maybeParts.value();
223-
assert(parts.size() == 2 &&
224-
"expected atomic add function to have 2 parts after prefix");
235+
size_t idx = skipBitwidthPrefix(parts);
236+
assert(parts.size() - idx == 2 &&
237+
"expected atomic add function to have 2 parts (memOrder, scope) "
238+
"after optional bitwidth prefix");
225239
assert(operands.size() == 2 && "expected atomic add to have 2 operands");
226-
LLVM::AtomicOrdering memOrder = strToMemoryOrder(parts[0]);
227-
auto scopeStr = strToScope(parts[1]);
240+
LLVM::AtomicOrdering memOrder = strToMemoryOrder(parts[idx]);
241+
auto scopeStr = strToScope(parts[idx + 1]);
228242
replacementOp =
229243
buildAtomicFetchAdd(operands[0], operands[1], memOrder, scopeStr);
230244
}
@@ -233,12 +247,15 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
233247
else if (auto maybeParts = matchPrefixAndSplitRemainder(
234248
calleeName, "__triton_hip_atom_cas_")) {
235249
auto parts = maybeParts.value();
236-
assert(parts.size() == 3 &&
237-
"expected atomic cas function to have 3 parts after prefix");
250+
size_t idx = skipBitwidthPrefix(parts);
251+
assert(parts.size() - idx == 3 &&
252+
"expected atomic cas function to have 3 parts "
253+
"(successOrder, failureOrder, scope) after optional bitwidth "
254+
"prefix");
238255
assert(operands.size() == 3 && "expected atomic cas to have 3 operands");
239-
LLVM::AtomicOrdering successOrdering = strToMemoryOrder(parts[0]);
240-
LLVM::AtomicOrdering failureOrdering = strToMemoryOrder(parts[1]);
241-
auto scopeStr = strToScope(parts[2]);
256+
LLVM::AtomicOrdering successOrdering = strToMemoryOrder(parts[idx]);
257+
LLVM::AtomicOrdering failureOrdering = strToMemoryOrder(parts[idx + 1]);
258+
auto scopeStr = strToScope(parts[idx + 2]);
242259
replacementOp = buildAtomicCompareExchangeStrong(
243260
operands[0], operands[1], operands[2], successOrdering,
244261
failureOrdering, scopeStr);

python/triton_dist/amd_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,14 @@ def _get_amdsmi_device_index(device_id: int | None):
276276
uuid = _get_gpu_uuid(device_id)
277277

278278
uuid_map = {get_uuid_by_physical_device_id(i)[-12:]: i for i in range(get_physical_device_count())}
279-
return uuid_map[uuid[-12:]]
279+
# TODO-rocm fix error
280+
uuid_tail = uuid[-12:]
281+
if uuid_tail not in uuid_map:
282+
warnings.warn(f"UUID mapping miss in _get_amdsmi_device_index: device_id={device_id}, "
283+
f"uuid_tail={uuid_tail}, available_tails={sorted(uuid_map.keys())}. "
284+
f"Fallback to logical device_id.")
285+
return device_id
286+
return uuid_map[uuid_tail]
280287

281288

282289
def get_physical_device_count():

python/triton_dist/jit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def shmem_kernel_module_init_hook(*args, **kwargs) -> None:
8888
else:
8989
hip.hipGetLastError() # Discard the last error
9090
elif backend == 'mori_shmem':
91-
# Initialize mori_shmem device symbols in this kernel module
92-
import mori.shmem as mori_shmem
93-
mori_shmem.shmem_module_init(kernel_module)
91+
if "mori_shmem" in kernel.asm.get('llir', ''):
92+
import mori.shmem as mori_shmem
93+
mori_shmem.shmem_module_init(kernel_module)
9494
elif is_maca():
9595
if "mxshmem" in kernel.asm['ttir']:
9696
import pymxshmem

python/triton_dist/kernels/amd/__init__.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,33 @@
2222
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
2323
#
2424
################################################################################
25-
from .allgather_gemm import ag_gemm_intra_node, create_ag_gemm_intra_node_context
26-
from .gemm_reduce_scatter import gemm_rs_intra_node, create_gemm_rs_intra_node_context
25+
from .ep_a2a_intra_node import (
26+
kernel_dispatch_token_intra_node,
27+
kernel_skipped_token_local_dispatch_intra_node,
28+
kernel_skipped_token_inplace_local_combine_intra_node,
29+
kernel_combine_token_intra_node,
30+
get_ag_splits_and_recv_offset_for_dispatch_intra_node,
31+
)
32+
from .low_latency_all_to_all import create_all_to_all_context, fast_all_to_all, all_to_all_post_process
33+
34+
try:
35+
from .allgather_gemm import ag_gemm_intra_node, create_ag_gemm_intra_node_context
36+
from .gemm_reduce_scatter import gemm_rs_intra_node, create_gemm_rs_intra_node_context
37+
except ImportError as e:
38+
import warnings
39+
warnings.warn(f"allgather_gemm/gemm_reduce_scatter unavailable (pyrocshmem not installed): {e}")
2740

2841
__all__ = [
29-
"ag_gemm_intra_node", "create_ag_gemm_intra_node_context", "gemm_rs_intra_node", "create_gemm_rs_intra_node_context"
42+
"ag_gemm_intra_node",
43+
"create_ag_gemm_intra_node_context",
44+
"gemm_rs_intra_node",
45+
"create_gemm_rs_intra_node_context",
46+
"kernel_dispatch_token_intra_node",
47+
"kernel_skipped_token_local_dispatch_intra_node",
48+
"kernel_skipped_token_inplace_local_combine_intra_node",
49+
"kernel_combine_token_intra_node",
50+
"get_ag_splits_and_recv_offset_for_dispatch_intra_node",
51+
"create_all_to_all_context",
52+
"fast_all_to_all",
53+
"all_to_all_post_process",
3054
]

python/triton_dist/kernels/amd/common_ops.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131

3232
from triton_dist.language.extra.hip.language_extra import load, atomic_add, sync_grid, atomic_cas, tid, __syncthreads
3333
from hip import hip
34-
from triton_dist.utils import HIP_CHECK, rocshmem_barrier_all_on_stream
34+
from triton_dist.utils import (
35+
HIP_CHECK,
36+
get_shmem_backend,
37+
mori_shmem_barrier_all_on_stream,
38+
rocshmem_barrier_all_on_stream,
39+
)
3540

3641

3742
@triton.jit
@@ -175,9 +180,9 @@ def barrier_all_kernel(rank, num_ranks, comm_buf_ptr):
175180

176181

177182
def barrier_all_on_stream(stream: Optional[torch.cuda.Stream] = None):
178-
'''
179-
call rocshmem barrier api
180-
'''
183+
"""Call shmem barrier on stream: mori_shmem when backend is mori_shmem, else rocshmem."""
184+
if get_shmem_backend() == "mori_shmem":
185+
return mori_shmem_barrier_all_on_stream(stream)
181186
return rocshmem_barrier_all_on_stream(stream)
182187

183188

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
################################################################################
2+
#
3+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining
6+
# a copy of this software and associated documentation files
7+
# (the "Software"), to deal in the Software without restriction,
8+
# including without limitation the rights to use, copy, modify, merge,
9+
# publish, distribute, sublicense, and/or sell copies of the Software,
10+
# and to permit persons to whom the Software is furnished to do so,
11+
# subject to the following conditions:
12+
#
13+
# The above copyright notice and this permission notice shall be
14+
# included in all copies or substantial portions of the Software.
15+
#
16+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20+
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
21+
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
22+
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23+
#
24+
################################################################################
25+
"""
26+
AMD EP A2A kernels and helpers.
27+
Provides bincount and re-exports intra-node kernels from ep_a2a_intra_node.
28+
"""
29+
30+
import torch
31+
import triton.language as tl
32+
import triton_dist
33+
from triton_dist.language.extra.hip.language_extra import tid, ld, atomic_add
34+
from triton_dist.language.extra.language_extra import threads_per_warp
35+
36+
37+
@triton_dist.jit(do_not_specialize=["n", "length", "num_sm"])
38+
def kernel_bincount(n, input, output, length, num_sm, num_warps: tl.constexpr):
39+
"""
40+
GPU bincount: count occurrences of each index in [0, length). AMD version using tid(0)
41+
and fixed threads_per_block (no simt_exec_region). Same semantics as nvidia/ep_a2a.py.
42+
"""
43+
pid = tl.program_id(0)
44+
num_pid = tl.num_programs(0)
45+
thread_idx = tid(0)
46+
threads_per_block = num_warps * threads_per_warp()
47+
for i in range(pid * threads_per_block + thread_idx, n, num_pid * threads_per_block):
48+
val = ld(input + i)
49+
if val < length:
50+
atomic_add(output + val, 1, scope="agent", semantic="relaxed")
51+
52+
53+
def bincount(input_tensor, length, output=None, output_dtype=torch.int32, num_sm=16, num_warps=8):
54+
"""GPU bincount for AMD (no AOT). input_tensor: 1D int32 on device; output: length elements."""
55+
if output is None:
56+
output = torch.zeros(length, dtype=output_dtype, device=input_tensor.device)
57+
assert input_tensor.dim() == 1 and input_tensor.is_contiguous()
58+
assert output.size(0) >= length and output.dtype == output_dtype
59+
n = input_tensor.size(0)
60+
grid = (num_sm, )
61+
kernel_bincount[grid](n, input_tensor, output, length, num_sm, num_warps=num_warps)
62+
return output
63+
64+
65+
# Re-export intra-node kernels and helpers so layer can import from this module only.
66+
from triton_dist.kernels.amd.ep_a2a_intra_node import (
67+
kernel_combine_token_intra_node,
68+
kernel_dispatch_token_intra_node,
69+
get_ag_splits_and_recv_offset_for_dispatch_intra_node,
70+
kernel_skipped_token_local_dispatch_intra_node,
71+
kernel_skipped_token_inplace_local_combine_intra_node,
72+
)
73+
74+
__all__ = [
75+
"kernel_bincount",
76+
"bincount",
77+
"kernel_combine_token_intra_node",
78+
"kernel_dispatch_token_intra_node",
79+
"get_ag_splits_and_recv_offset_for_dispatch_intra_node",
80+
"kernel_skipped_token_local_dispatch_intra_node",
81+
"kernel_skipped_token_inplace_local_combine_intra_node",
82+
]

0 commit comments

Comments
 (0)