Skip to content

Commit cdd0769

Browse files
Merge pull request #566 from ChrisRackauckas-Claude/fix-threaded-voa-broadcast-v3
Backport: Fix threaded FastBroadcast for VectorOfArray
2 parents 40a6231 + da51fbc commit cdd0769

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "3.52.0"
4+
version = "3.53.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/RecursiveArrayToolsFastBroadcastExt.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module RecursiveArrayToolsFastBroadcastExt
22

33
using RecursiveArrayTools
44
using FastBroadcast
5-
using FastBroadcast: Serial
5+
using FastBroadcast: Serial, Threaded
66
using StaticArraysCore
77

88
const AbstractVectorOfSArray = AbstractVectorOfArray{
@@ -27,4 +27,32 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
2727
return dst
2828
end
2929

30+
@inline function FastBroadcast.fast_materialize!(
31+
::Threaded, dst::AbstractVectorOfSArray,
32+
bc::Broadcast.Broadcasted{S}
33+
) where {S}
34+
if FastBroadcast.use_fast_broadcast(S)
35+
Threads.@threads for i in 1:length(dst.u)
36+
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
37+
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
38+
unpacked[j]
39+
for j in eachindex(unpacked)
40+
)
41+
end
42+
else
43+
Broadcast.materialize!(dst, bc)
44+
end
45+
return dst
46+
end
47+
48+
# Fallback for non-SArray VectorOfArray: the generic threaded path splits
49+
# along the last axis via views, which does not correctly partition work for
50+
# VectorOfArray. Fall back to serial broadcasting.
51+
@inline function FastBroadcast.fast_materialize!(
52+
::Threaded, dst::AbstractVectorOfArray,
53+
bc::Broadcast.Broadcasted
54+
)
55+
return FastBroadcast.fast_materialize!(Serial(), dst, bc)
56+
end
57+
3058
end # module

test/interface_tests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,18 @@ f3!(z, zz)
284284
@test z == VA[fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]
285285
@test (@allocated f3!(z, zz)) == 0
286286

287+
# Test threaded FastBroadcast with VectorOfArray of StaticArrays (issue #564)
288+
@testset "Threaded @.. with VectorOfArray{SArray}" begin
289+
u_t = VectorOfArray(fill(SVector(1.0, 1.0), 2, 2))
290+
v_t = copy(u_t)
291+
@.. thread = true v_t = v_t + u_t
292+
@test all(x -> x == SVector(2.0, 2.0), v_t.u)
293+
294+
# Test that repeated threaded application accumulates correctly
295+
@.. thread = true v_t = v_t + u_t
296+
@test all(x -> x == SVector(3.0, 3.0), v_t.u)
297+
end
298+
287299
struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
288300
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
289301
end

0 commit comments

Comments
 (0)