Skip to content

Commit d3176a3

Browse files
Fix threaded FastBroadcast for VectorOfArray (#564)
The generic threaded path in FastBroadcast splits work along the last axis via views, which does not correctly partition VectorOfArray — each thread ended up operating on the full array. Add Threaded dispatch methods that iterate over inner arrays directly using Threads.@threads for SArray VectorOfArrays, and fall back to Serial for other types. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 40a6231 commit d3176a3

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

ext/RecursiveArrayToolsFastBroadcastExt.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module RecursiveArrayToolsFastBroadcastExt
22

33
using RecursiveArrayTools
44
using FastBroadcast
5-
using FastBroadcast: Serial
5+
using FastBroadcast: Serial, Threaded
66
using StaticArraysCore
77

88
const AbstractVectorOfSArray = AbstractVectorOfArray{
@@ -27,4 +27,32 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
2727
return dst
2828
end
2929

30+
@inline function FastBroadcast.fast_materialize!(
31+
::Threaded, dst::AbstractVectorOfSArray,
32+
bc::Broadcast.Broadcasted{S}
33+
) where {S}
34+
if FastBroadcast.use_fast_broadcast(S)
35+
Threads.@threads for i in 1:length(dst.u)
36+
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
37+
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
38+
unpacked[j]
39+
for j in eachindex(unpacked)
40+
)
41+
end
42+
else
43+
Broadcast.materialize!(dst, bc)
44+
end
45+
return dst
46+
end
47+
48+
# Fallback for non-SArray VectorOfArray: the generic threaded path splits
49+
# along the last axis via views, which does not correctly partition work for
50+
# VectorOfArray. Fall back to serial broadcasting.
51+
@inline function FastBroadcast.fast_materialize!(
52+
::Threaded, dst::AbstractVectorOfArray,
53+
bc::Broadcast.Broadcasted
54+
)
55+
return FastBroadcast.fast_materialize!(Serial(), dst, bc)
56+
end
57+
3058
end # module

test/interface_tests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,18 @@ f3!(z, zz)
284284
@test z == VA[fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]
285285
@test (@allocated f3!(z, zz)) == 0
286286

287+
# Test threaded FastBroadcast with VectorOfArray of StaticArrays (issue #564)
288+
@testset "Threaded @.. with VectorOfArray{SArray}" begin
289+
u_t = VectorOfArray(fill(SVector(1.0, 1.0), 2, 2))
290+
v_t = copy(u_t)
291+
@.. thread = true v_t = v_t + u_t
292+
@test all(x -> x == SVector(2.0, 2.0), v_t.u)
293+
294+
# Test that repeated threaded application accumulates correctly
295+
@.. thread = true v_t = v_t + u_t
296+
@test all(x -> x == SVector(3.0, 3.0), v_t.u)
297+
end
298+
287299
struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
288300
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
289301
end

0 commit comments

Comments
 (0)