Skip to content

Commit 93d261e

Browse files
committed
issue/1021 - feat: support bf16 in infiniccl on moore_gpu_arch mp_31 with mccl
1 parent 718b18c commit 93d261e

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

src/infiniccl/moore/infiniccl_moore.cc

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ inline mcclDataType_t getMcclDtype(infiniDtype_t datatype) {
2323
return mcclFloat;
2424
case INFINI_DTYPE_F16:
2525
return mcclHalf;
26+
27+
#if MARCH_TYPE == 310
28+
case INFINI_DTYPE_BF16:
29+
return mcclBfloat16;
30+
#endif
31+
2632
default:
2733
std::abort();
2834
return mcclHalf;
@@ -83,9 +89,16 @@ infiniStatus_t allReduce(
8389
infinicclComm_t comm,
8490
infinirtStream_t stream) {
8591

86-
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
87-
return INFINI_STATUS_BAD_PARAM;
88-
}
92+
#if MARCH_TYPE == 310
93+
CHECK_DTYPE(datatype,
94+
INFINI_DTYPE_F32,
95+
INFINI_DTYPE_F16,
96+
INFINI_DTYPE_BF16);
97+
#else
98+
CHECK_DTYPE(datatype,
99+
INFINI_DTYPE_F32,
100+
INFINI_DTYPE_F16);
101+
#endif
89102

90103
CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype),
91104
getMcclRedOp(op), getMcclComm(comm), getMusaStream(stream)));

xmake.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ option("moore-gpu")
180180
set_description("Whether to compile implementations for Moore Threads GPU")
181181
option_end()
182182

183+
option("moore-gpu-arch")
184+
set_default("mp_31")
185+
set_showmenu(true)
186+
set_description("Set Moore GPU architecture (e.g. mp_31)")
187+
option_end()
188+
183189
if has_config("moore-gpu") then
184190
add_defines("ENABLE_MOORE_API")
185191
includes("xmake/moore.lua")

xmake/moore.lua

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,22 @@ rule("mu")
1616
local mcc = MUSA_ROOT .. "/bin/mcc"
1717
local includedirs = table.concat(target:get("includedirs"), " ")
1818

19-
local args = {"-c", sourcefile, "-o", objectfile, "-I" .. MUSA_ROOT .. "/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"}
19+
local args = {
20+
"-c", sourcefile,
21+
"-o", objectfile,
22+
"-I" .. MUSA_ROOT .. "/include",
23+
"-O3",
24+
"-fPIC",
25+
"-Wall",
26+
"-std=c++17",
27+
"-pthread"
28+
}
29+
local moore_gpu_arch = get_config("moore-gpu-arch")
30+
31+
if moore_gpu_arch == "mp_31" then
32+
table.insert(args, 1, "--cuda-gpu-arch=mp_31")
33+
end
34+
2035
for _, includedir in ipairs(target:get("includedirs")) do
2136
table.insert(args, "-I" .. includedir)
2237
end
@@ -76,6 +91,12 @@ target("infiniccl-moore")
7691
if has_config("ccl") then
7792
add_links("libmccl.so")
7893
add_files("../src/infiniccl/moore/*.cc")
94+
95+
-- The mp_31 Moore GPU arch supports mcclBfloat16 in MCCL, so define MARCH_TYPE=310 to enable the BF16 path
96+
if get_config("moore-gpu-arch") == "mp_31" then
97+
add_defines("MARCH_TYPE=310")
98+
add_cxxflags("-Wno-unused-function")
99+
end
79100
end
80101
set_languages("cxx17")
81102

0 commit comments

Comments
 (0)