diff --git a/src/PartitionedArrays.jl b/src/PartitionedArrays.jl index 7d3e5d26..163c5578 100644 --- a/src/PartitionedArrays.jl +++ b/src/PartitionedArrays.jl @@ -55,6 +55,8 @@ export exchange export exchange! export allocate_exchange export find_rcv_ids_gather_scatter +export setup_non_blocking_reduction +export non_blocking_reduction include("primitives.jl") export DebugArray @@ -144,6 +146,8 @@ export SplitVector export split_vector export split_vector_blocks export pvector_from_split_blocks +export setup_non_blocking_dot +export non_blocking_dot include("p_vector.jl") export SplitMatrix diff --git a/src/mpi_array.jl b/src/mpi_array.jl index 72135536..7b3d0d32 100644 --- a/src/mpi_array.jl +++ b/src/mpi_array.jl @@ -500,6 +500,38 @@ function reduction_impl(op,a::MPIArray,destination;init=nothing) MPIArray(b_item,comm,size(a)) end +function setup_non_blocking_reduction_impl(a::MPIArray, ::Type{T}) where T + request = MPI.UnsafeRequest() # Single reduction request + buffer = Ref{T}() + return (request = request, recvbuf = buffer) +end + +function non_blocking_reduction_impl(op, a::MPIArray, setup, destination=:all; init=nothing) + @assert destination === :all + T = eltype(a) + comm = a.comm + opr = MPI.Op(op, T) + + sendbuf = Ref(a.item) + recvbuf = setup.recvbuf + request = setup.request + rbuf = MPI.RBuffer(sendbuf, recvbuf) + + state = (sendbuf, recvbuf, request) + + GC.@preserve state MPI.API.MPI_Iallreduce(rbuf.senddata, rbuf.recvdata, rbuf.count, rbuf.datatype, opr, comm, request) + + + @fake_async begin + GC.@preserve state MPI.Wait(request) + b_item = recvbuf[] + if init !== nothing + b_item = op(b_item,init) + end + MPIArray(b_item,comm,size(a)) + end +end + function Base.reduce(op,a::MPIArray;kwargs...) r = reduction(op,a;destination=:all,kwargs...) r.item diff --git a/src/p_vector.jl b/src/p_vector.jl index 61b02ef2..07ff7723 100644 --- a/src/p_vector.jl +++ b/src/p_vector.jl @@ -1191,6 +1191,22 @@ function LinearAlgebra.dot(a::PVector,b::PVector) sum(c) end +function setup_non_blocking_dot(a::PVector, b::PVector) + partials = map(own_values(a), own_values(b)) do mya, myb + zero(eltype(mya)) + zero(eltype(myb)) + end + setup_non_blocking_reduction(partials) +end + +function non_blocking_dot(a::PVector, b::PVector, setup) + partials = map(dot, own_values(a), own_values(b)) + t = non_blocking_reduction(+, partials, setup, destination=:all, init=zero(eltype(a)) + zero(eltype(b))) + @fake_async begin + getany(fetch(t)) + end +end + + function LinearAlgebra.rmul!(a::PVector,v::Number) map(partition(a)) do l rmul!(l,v) diff --git a/src/primitives.jl b/src/primitives.jl index 25b1869f..86726137 100644 --- a/src/primitives.jl +++ b/src/primitives.jl @@ -708,6 +708,24 @@ end # b #end +function non_blocking_reduction(op,a,setup= setup_non_blocking_reduction(a);destination=MAIN,kwargs...) + non_blocking_reduction_impl(op,a,setup,destination;kwargs...) +end + +function setup_non_blocking_reduction(a) + setup_non_blocking_reduction_impl(a,eltype(a)) +end + +function setup_non_blocking_reduction_impl(a::AbstractArray, ::Type{T}) where T + return nothing +end + +function non_blocking_reduction_impl(op, a::AbstractArray, setup, destination=:all; init=nothing) + @fake_async begin + reduction_impl(op, a, destination; init=init) + end +end + """ struct ExchangeGraph{A} diff --git a/test/p_vector_tests.jl b/test/p_vector_tests.jl index b3e21797..6af6ded6 100644 --- a/test/p_vector_tests.jl +++ b/test/p_vector_tests.jl @@ -90,6 +90,11 @@ function p_vector_tests(distribute) @test sqrt(a⋅a) ≈ norm(a) @test euclidean(a,a) + 1 ≈ 1 + # Quick Test non_blocking_dot + setup = setup_non_blocking_dot(a,b) + t = non_blocking_dot(a,b,setup) + @test fetch(t) ≈ a⋅b + n = 10 parts = rank row_partition = map(parts) do part