Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion ext/RecursiveArrayToolsFastBroadcastExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module RecursiveArrayToolsFastBroadcastExt

using RecursiveArrayTools
using FastBroadcast
using FastBroadcast: Serial
using FastBroadcast: Serial, Threaded
using StaticArraysCore

const AbstractVectorOfSArray = AbstractVectorOfArray{
Expand All @@ -27,4 +27,32 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
return dst
end

@inline function FastBroadcast.fast_materialize!(
::Threaded, dst::AbstractVectorOfSArray,
bc::Broadcast.Broadcasted{S}
) where {S}
if FastBroadcast.use_fast_broadcast(S)
Threads.@threads for i in 1:length(dst.u)
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
unpacked[j]
for j in eachindex(unpacked)
)
end
else
Broadcast.materialize!(dst, bc)
end
return dst
end

# Fallback for non-SArray VectorOfArray: the generic threaded path splits
# along the last axis via views, which does not correctly partition work for
# VectorOfArray. Fall back to serial broadcasting.
@inline function FastBroadcast.fast_materialize!(
::Threaded, dst::AbstractVectorOfArray,
bc::Broadcast.Broadcasted
)
return FastBroadcast.fast_materialize!(Serial(), dst, bc)
end

end # module
12 changes: 12 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,18 @@ f3!(z, zz)
@test z == VA[fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]
@test (@allocated f3!(z, zz)) == 0

# Test threaded FastBroadcast with VectorOfArray of StaticArrays (issue #564)
@testset "Threaded @.. with VectorOfArray{SArray}" begin
u_t = VectorOfArray(fill(SVector(1.0, 1.0), 2, 2))
v_t = copy(u_t)
@.. thread = true v_t = v_t + u_t
@test all(x -> x == SVector(2.0, 2.0), v_t.u)

# Test that repeated threaded application accumulates correctly
@.. thread = true v_t = v_t + u_t
@test all(x -> x == SVector(3.0, 3.0), v_t.u)
end

struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
end
Expand Down
Loading