Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/PartitionedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ export exchange
export exchange!
export allocate_exchange
export find_rcv_ids_gather_scatter
export setup_non_blocking_reduction
export non_blocking_reduction
include("primitives.jl")

export DebugArray
Expand Down Expand Up @@ -144,6 +146,8 @@ export SplitVector
export split_vector
export split_vector_blocks
export pvector_from_split_blocks
export setup_non_blocking_dot
export non_blocking_dot
include("p_vector.jl")

export SplitMatrix
Expand Down
32 changes: 32 additions & 0 deletions src/mpi_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,38 @@ function reduction_impl(op,a::MPIArray,destination;init=nothing)
MPIArray(b_item,comm,size(a))
end

# MPI backend: allocate the reusable state for a non-blocking reduction.
#
# Returns a named tuple with
#   - `request`: a single `MPI.UnsafeRequest` used to track the reduction, and
#   - `recvbuf`: a `Ref{T}` (not yet assigned) that receives the reduced value.
#
# `a` is used for dispatch only; `T` fixes the element type of the buffer.
function setup_non_blocking_reduction_impl(a::MPIArray, ::Type{T}) where T
    return (request = MPI.UnsafeRequest(), recvbuf = Ref{T}())
end

# MPI backend of `non_blocking_reduction`.
#
# Starts an `MPI_Iallreduce` of the local item `a.item` with operator `op` on
# the communicator `a.comm`, and returns the object produced by `@fake_async`.
# The deferred block waits for the request to complete, combines the reduced
# value with `init` (when `init !== nothing`) and wraps it in an `MPIArray`.
#
# Arguments
#   - `op`: reduction operator, converted with `MPI.Op(op, eltype(a))`.
#   - `a`: the distributed array whose local item is reduced.
#   - `setup`: object with the fields `request` and `recvbuf`, as returned by
#     `setup_non_blocking_reduction_impl(a, T)`.
#   - `destination`: only `:all` is supported; anything else trips the assertion.
#   - `init`: optional value applied once as `op(result, init)` after the
#     reduction has completed.
function non_blocking_reduction_impl(op, a::MPIArray, setup, destination=:all; init=nothing)
    @assert destination === :all
    T = eltype(a)
    comm = a.comm
    opr = MPI.Op(op, T)

    sendbuf = Ref(a.item)
    recvbuf = setup.recvbuf
    request = setup.request
    rbuf = MPI.RBuffer(sendbuf, recvbuf)

    # Everything MPI may touch while the operation is in flight. The operator
    # handle is included on purpose: for a user-defined `op` it may own the
    # callback MPI invokes, so it must stay reachable until the wait below has
    # returned. The deferred block captures this tuple, which keeps all of its
    # entries alive together.
    state = (sendbuf, recvbuf, request, opr)

    GC.@preserve state MPI.API.MPI_Iallreduce(
        rbuf.senddata, rbuf.recvdata, rbuf.count, rbuf.datatype, opr, comm, request
    )

    @fake_async begin
        GC.@preserve state MPI.Wait(request)
        b_item = recvbuf[]
        if init !== nothing
            b_item = op(b_item, init)
        end
        MPIArray(b_item, comm, size(a))
    end
end

function Base.reduce(op,a::MPIArray;kwargs...)
r = reduction(op,a;destination=:all,kwargs...)
r.item
Expand Down
16 changes: 16 additions & 0 deletions src/p_vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,22 @@ function LinearAlgebra.dot(a::PVector,b::PVector)
sum(c)
end

"""
    setup_non_blocking_dot(a::PVector, b::PVector)

Create the setup object to be passed as the third argument of
[`non_blocking_dot`](@ref) for the vectors `a` and `b`.

The setup is obtained from `setup_non_blocking_reduction`, called on an array of
placeholder values (one per part) whose type is that of the sum of the element
types of the own values of `a` and `b`.
"""
function setup_non_blocking_dot(a::PVector, b::PVector)
    # Only the element type of these placeholders matters to the setup.
    placeholders = map(
        (va, vb) -> zero(eltype(va)) + zero(eltype(vb)),
        own_values(a),
        own_values(b),
    )
    return setup_non_blocking_reduction(placeholders)
end

"""
    non_blocking_dot(a::PVector, b::PVector, setup)

Start the dot product of `a` and `b` without waiting for the result.

The per-part dot products of the own values are computed right away and handed
to `non_blocking_reduction` with `+` and `destination=:all`. The returned object
is meant to be passed to `fetch`, which yields the reduced value.
`setup` is the object returned by [`setup_non_blocking_dot`](@ref).
"""
function non_blocking_dot(a::PVector, b::PVector, setup)
    local_dots = map(dot, own_values(a), own_values(b))
    zero_init = zero(eltype(a)) + zero(eltype(b))
    pending = non_blocking_reduction(
        +, local_dots, setup; destination=:all, init=zero_init
    )
    @fake_async begin
        # Every part holds the same value after an all-reduction; pick any.
        getany(fetch(pending))
    end
end


function LinearAlgebra.rmul!(a::PVector,v::Number)
map(partition(a)) do l
rmul!(l,v)
Expand Down
18 changes: 18 additions & 0 deletions src/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,24 @@ end
# b
#end

"""
    non_blocking_reduction(op, a[, setup]; destination=MAIN, kwargs...)

Start a reduction of the values in `a` with the operator `op` and return an
object to be passed to `fetch` in order to get the result.

`setup` defaults to `setup_non_blocking_reduction(a)`; pass a previously built
setup to reuse it across calls. `destination` and any extra keyword arguments
(e.g. `init`) are forwarded to the backend-specific
`non_blocking_reduction_impl`.

NOTE(review): the `MPIArray` backend asserts `destination === :all`, so the
default `destination=MAIN` looks unusable with it -- confirm with callers.
"""
function non_blocking_reduction(op,a,setup= setup_non_blocking_reduction(a);destination=MAIN,kwargs...)
    return non_blocking_reduction_impl(op, a, setup, destination; kwargs...)
end

"""
    setup_non_blocking_reduction(a)

Build the setup object consumed by [`non_blocking_reduction`](@ref) for arrays
like `a`. The work is delegated to the backend-specific
`setup_non_blocking_reduction_impl`, which also receives `eltype(a)`.
"""
function setup_non_blocking_reduction(a)
    T = eltype(a)
    return setup_non_blocking_reduction_impl(a, T)
end

# Generic fallback: plain arrays need no setup state, hence `nothing`.
setup_non_blocking_reduction_impl(a::AbstractArray, ::Type{T}) where {T} = nothing

# Generic fallback of `non_blocking_reduction`: wrap the ordinary
# `reduction_impl` call in `@fake_async` and return the resulting object.
# `setup` is accepted for interface compatibility and is not used here.
function non_blocking_reduction_impl(op, a::AbstractArray, setup, destination=:all; init=nothing)
    return @fake_async begin
        reduction_impl(op, a, destination; init=init)
    end
end

"""
struct ExchangeGraph{A}

Expand Down
5 changes: 5 additions & 0 deletions test/p_vector_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ function p_vector_tests(distribute)
@test sqrt(a⋅a) ≈ norm(a)
@test euclidean(a,a) + 1 ≈ 1

# Quick Test non_blocking_dot
setup = setup_non_blocking_dot(a,b)
t = non_blocking_dot(a,b,setup)
@test fetch(t) ≈ a⋅b

n = 10
parts = rank
row_partition = map(parts) do part
Expand Down
Loading