Skip to content

Commit 0825141

Browse files
committed
Test support for MPI_Reduce and MPI_Allreduce for AMDGPU
1 parent 275d21d commit 0825141

4 files changed

Lines changed: 35 additions & 7 deletions

File tree

test/mpi_support_test.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,30 @@ MPI.Init()
66
# (or a similar signal) when called, which cannot be handled in Julia in a portable way.
77

88
op = ARGS[1]
9-
if op == "Iallreduce"
10-
# Iallreduce is unsupported for CUDA with OpenMPI + UCX
11-
# See https://docs.open-mpi.org/en/main/tuning-apps/networking/cuda.html#which-mpi-apis-do-not-work-with-cuda-aware-ucx
9+
if op == "Allreduce"
10+
# Allreduce is unsupported for AMDGPU with UCX
11+
send_arr = ArrayType(zeros(Int, 1))
12+
recv_arr = ArrayType{Int}(undef, 1)
13+
synchronize()
14+
MPI.Allreduce!(send_arr, recv_arr, +, MPI.COMM_WORLD)
15+
16+
elseif op == "Iallreduce"
17+
# Iallreduce is unsupported for CUDA with OpenMPI 5 + UCX
1218
send_arr = ArrayType(zeros(Int, 1))
1319
recv_arr = ArrayType{Int}(undef, 1)
1420
synchronize()
1521
req = MPI.Iallreduce!(send_arr, recv_arr, +, MPI.COMM_WORLD)
1622
MPI.Wait(req)
1723

24+
elseif op == "Reduce"
25+
# Reduce is unsupported for AMDGPU with UCX
26+
send_arr = ArrayType(zeros(Int, 1))
27+
recv_arr = ArrayType{Int}(undef, 1)
28+
synchronize()
29+
MPI.Reduce!(send_arr, recv_arr, +, MPI.COMM_WORLD; root=0)
30+
1831
elseif op == "Ireduce"
19-
# Iallreduce is unsupported for CUDA with OpenMPI + UCX
32+
# Ireduce is unsupported for CUDA with OpenMPI 5 + UCX
2033
send_arr = ArrayType(zeros(Int, 1))
2134
recv_arr = ArrayType{Int}(undef, 1)
2235
synchronize()

test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,17 @@ end
7474
function is_mpi_operation_supported(mpi_op, n=nprocs)
7575
test_file = joinpath(@__DIR__, "mpi_support_test.jl")
7676
cmd = `$(mpiexec()) -n $n $(Base.julia_cmd()) --startup-file=no $test_file $mpi_op`
77-
supported = success(run(ignorestatus(cmd)))
77+
cmd = ignorestatus(pipeline(cmd; stderr=devnull))
78+
process = run(cmd)
79+
supported = success(process)
7880
!supported && @warn "$mpi_op is unsupported with $backend_name"
81+
!supported && !Base.process_signaled(process) && @warn "support check for $mpi_op may be broken as process did not signal"
7982
return supported
8083
end
8184

8285
if ArrayType != Array # we expect that only GPU backends can have unsupported features
86+
ENV["JULIA_MPI_TEST_ALLREDUCE"] = is_mpi_operation_supported("Allreduce")
87+
ENV["JULIA_MPI_TEST_REDUCE"] = is_mpi_operation_supported("Reduce")
8388
ENV["JULIA_MPI_TEST_IALLREDUCE"] = is_mpi_operation_supported("Iallreduce")
8489
ENV["JULIA_MPI_TEST_IREDUCE"] = is_mpi_operation_supported("Ireduce")
8590
end

test/test_allreduce.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
include("common.jl")
22

3+
allreduce_supported = get(ENV, "JULIA_MPI_TEST_ALLREDUCE", "true") == "true"
4+
iallreduce_supported = get(ENV, "JULIA_MPI_TEST_IALLREDUCE", "true") == "true"
5+
if !allreduce_supported
6+
@warn "Skipping all tests in 'test_allreduce.jl' as reductions are unsupported"
7+
return
8+
end
9+
310
MPI.Init()
411

512
comm_size = MPI.Comm_size(MPI.COMM_WORLD)
@@ -13,8 +20,6 @@ else
1320
operators = [MPI.SUM, +, (x,y) -> 2x+y-x]
1421
end
1522

16-
iallreduce_supported = get(ENV, "JULIA_MPI_TEST_IALLREDUCE", "true") == "true"
17-
1823

1924
for T = [Int]
2025
for dims = [1, 2, 3]

test/test_reduce.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@ const can_do_closures =
99
Sys.ARCH !== :aarch64 &&
1010
!startswith(string(Sys.ARCH), "arm")
1111

12+
reduce_supported = get(ENV, "JULIA_MPI_TEST_REDUCE", "true") == "true"
1213
ireduce_supported = get(ENV, "JULIA_MPI_TEST_IREDUCE", "true") == "true"
14+
if !reduce_supported
15+
@warn "Skipping all tests in 'test_reduce.jl' as reductions are unsupported"
16+
return
17+
end
1318

1419
using DoubleFloats
1520

0 commit comments

Comments
 (0)