Skip to content

Commit 110f2aa

Browse files
Merge pull request #567 from ChrisRackauckas-Claude/polyester-threading-ext
Add Polyester threading extension for FastBroadcast VectorOfArray
2 parents 3b56aa8 + 4335a0d commit 110f2aa

File tree

4 files changed

+89
-21
lines changed

4 files changed

+89
-21
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2121
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
2222
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2323
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
24+
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
2425
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2526
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2627
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -32,6 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3233
[extensions]
3334
RecursiveArrayToolsCUDAExt = "CUDA"
3435
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
36+
RecursiveArrayToolsFastBroadcastPolyesterExt = ["FastBroadcast", "Polyester"]
3537
RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
3638
RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions"
3739
RecursiveArrayToolsMeasurementsExt = "Measurements"
@@ -59,6 +61,7 @@ Measurements = "2.11"
5961
MonteCarloMeasurements = "1.2"
6062
NLsolve = "4.5"
6163
Pkg = "1"
64+
Polyester = "0.7.16"
6265
PrecompileTools = "1.2.1"
6366
Random = "1"
6467
RecipesBase = "1.3.4"
@@ -86,6 +89,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
8689
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
8790
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
8891
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
92+
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
8993
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9094
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
9195
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -98,4 +102,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
98102
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
99103

100104
[targets]
101-
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
105+
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]

ext/RecursiveArrayToolsFastBroadcastExt.jl

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,25 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
2727
return dst
2828
end
2929

30-
@inline function FastBroadcast.fast_materialize!(
31-
::Threaded, dst::AbstractVectorOfSArray,
32-
bc::Broadcast.Broadcasted{S}
33-
) where {S}
34-
if FastBroadcast.use_fast_broadcast(S)
35-
Threads.@threads for i in 1:length(dst.u)
36-
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
37-
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
38-
unpacked[j]
39-
for j in eachindex(unpacked)
40-
)
41-
end
42-
else
43-
Broadcast.materialize!(dst, bc)
44-
end
45-
return dst
46-
end
47-
48-
# Fallback for non-SArray VectorOfArray: the generic threaded path splits
49-
# along the last axis via views, which does not correctly partition work for
30+
# Fallback for non-SArray VectorOfArray: the generic threaded path splits along
31+
# the last axis via views, which does not correctly partition work for
5032
# VectorOfArray. Fall back to serial broadcasting.
33+
# For SArray VectorOfArray, throw an informative error telling the user to
34+
# load Polyester.jl for threaded broadcasting.
5135
# Threaded `fast_materialize!` fallback for any `AbstractVectorOfArray`
# destination. Raises an error when `dst isa AbstractVectorOfSArray`;
# otherwise re-dispatches the same `dst`/`bc` pair with `Serial()`.
@inline function FastBroadcast.fast_materialize!(
5236
::Threaded, dst::AbstractVectorOfArray,
5337
bc::Broadcast.Broadcasted
5438
)
39+
# When Polyester is loaded, RecursiveArrayToolsFastBroadcastPolyesterExt
40+
# defines more-specific methods for AbstractVectorOfSArray, so reaching
41+
# this method with an SArray VoA means Polyester is not loaded.
42+
if dst isa AbstractVectorOfSArray
43+
# `error` throws, so the serial call below is never reached on this branch.
error(
44+
"Threaded FastBroadcast on VectorOfArray{SArray} requires Polyester.jl. " *
45+
"Add `using Polyester` to enable threaded broadcasting, or use " *
46+
"`@.. thread=false` for serial broadcasting."
47+
)
48+
end
5549
# Non-SArray storage: hand off to the `Serial()` method and return its result.
return FastBroadcast.fast_materialize!(Serial(), dst, bc)
5650
end
5751

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Package extension of RecursiveArrayTools, declared in Project.toml with the
# triggers FastBroadcast and Polyester. It defines `Threaded`
# `FastBroadcast.fast_materialize!` methods for `AbstractVectorOfSArray`
# destinations whose per-element loop runs under `@batch`.
module RecursiveArrayToolsFastBroadcastPolyesterExt
2+
3+
using RecursiveArrayTools
4+
using FastBroadcast
5+
# NOTE(review): `Serial` is imported here but not referenced anywhere in this
# module; only `Threaded` is used.
using FastBroadcast: Serial, Threaded
6+
using Polyester
7+
using StaticArraysCore
8+
9+
# Alias for every `AbstractVectorOfArray{T, N, A}` whose third parameter `A`
# is an `AbstractVector` with `StaticArraysCore.SArray` elements.
const AbstractVectorOfSArray = AbstractVectorOfArray{
10+
T, N, <:AbstractVector{<:StaticArraysCore.SArray},
11+
} where {T, N}
12+
13+
# Shared implementation behind both `Threaded` methods below.
#
# When `FastBroadcast.use_fast_broadcast(S)` is true, each index `i` of
# `dst.u` is handled independently: `unpack_voa(bc, i)` is evaluated, a value
# of type `similar_type(dst.u[i])` is built from its entries, and the result
# is stored in `dst.u[i]`. Otherwise the call is forwarded to
# `Broadcast.materialize!`. Returns `dst` in both cases.
@inline function _polyester_fast_materialize!(
14+
dst::AbstractVectorOfSArray,
15+
bc::Broadcast.Broadcasted{S}
16+
) where {S}
17+
if FastBroadcast.use_fast_broadcast(S)
18+
# `@batch` presumably comes from Polyester (the only package loaded here
# that is not otherwise used) — confirm. Each iteration writes only
# `dst.u[i]`.
@batch for i in 1:length(dst.u)
19+
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
20+
# Rebuild an element of the same static type as the current `dst.u[i]`
# from a generator over the entries of `unpacked`.
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
21+
unpacked[j]
22+
for j in eachindex(unpacked)
23+
)
24+
end
25+
else
26+
# Style `S` is not eligible for the fast path: use Base broadcasting.
Broadcast.materialize!(dst, bc)
27+
end
28+
return dst
29+
end
30+
31+
# `Threaded` entry point with the broadcast style captured as `S`; forwards
# to `_polyester_fast_materialize!`.
@inline function FastBroadcast.fast_materialize!(
32+
::Threaded, dst::AbstractVectorOfSArray,
33+
bc::Broadcast.Broadcasted{S}
34+
) where {S}
35+
return _polyester_fast_materialize!(dst, bc)
36+
end
37+
38+
# Disambiguation: this method is more specific than both the base ext's
39+
# (::Threaded, ::AbstractVectorOfArray, ::Broadcasted) fallback and
40+
# the above (::Threaded, ::AbstractVectorOfSArray, ::Broadcasted{S}).
# NOTE(review): `Broadcast.Broadcasted{S} where {S}` and the bare
# `Broadcast.Broadcasted` may denote the same type, in which case this
# definition would replace the method above instead of sitting beside it —
# confirm with `methods(FastBroadcast.fast_materialize!)`.
41+
@inline function FastBroadcast.fast_materialize!(
42+
::Threaded, dst::AbstractVectorOfSArray,
43+
bc::Broadcast.Broadcasted
44+
)
45+
return _polyester_fast_materialize!(dst, bc)
46+
end
47+
48+
end # module

test/interface_tests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using RecursiveArrayTools, StaticArrays, Test
22
using RecursiveArrayToolsShorthandConstructors
33
using FastBroadcast
4+
using Polyester
45
using SymbolicIndexingInterface: SymbolCache
56

67
t = 1:3
@@ -302,6 +303,27 @@ f3!(z, zz)
302303
@test all(x -> x == SVector(3.0, 3.0), v_t.u)
303304
end
304305

306+
# Test Polyester-based threaded FastBroadcast extension (issue #564)
307+
# Checks that the Polyester extension is loaded and that `@.. thread = true`
# on a VectorOfArray of SVectors produces the expected element values.
@testset "Polyester-threaded @.. with VectorOfArray{SArray}" begin
308+
# Verify the Polyester extension is loaded
309+
# `Base.get_extension` returns `nothing` when the extension is not loaded.
@test Base.get_extension(
310+
Base.PkgId(RecursiveArrayTools),
311+
:RecursiveArrayToolsFastBroadcastPolyesterExt
312+
) !== nothing
313+
314+
# Test basic threaded broadcast with Polyester (Vector-of-SVector storage)
315+
u_p = VectorOfArray([SVector(2.0, 3.0) for _ in 1:9])
316+
# `v_p` starts as a copy of `u_p`, so `v_p + u_p` doubles every entry:
# (2.0, 3.0) -> (4.0, 6.0).
v_p = copy(u_p)
317+
@.. thread = true v_p = v_p + u_p
318+
@test all(x -> x == SVector(4.0, 6.0), v_p.u)
319+
320+
# Test with larger array to exercise Polyester batching
321+
u_large = VectorOfArray([SVector(1.0, 1.0, 1.0) for _ in 1:100])
322+
# Destination starts at zero so the check below can only pass if every
# one of the 100 elements was overwritten.
v_large = VectorOfArray([SVector(0.0, 0.0, 0.0) for _ in 1:100])
323+
@.. thread = true v_large = u_large * 2.0
324+
@test all(x -> x == SVector(2.0, 2.0, 2.0), v_large.u)
325+
end
326+
305327
# Immutable `AbstractVectorOfArray` subtype defined in the test file; its only
# field `u` holds the underlying storage of type `A`.
struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
306328
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
307329
end

0 commit comments

Comments
 (0)