Skip to content

Commit 5798d0f

Browse files
Merge pull request #568 from ChrisRackauckas-Claude/polyester-threading-ext-v3
[v3 backport] Add Polyester threading extension for FastBroadcast VectorOfArray
2 parents cdd0769 + 973db56 commit 5798d0f

File tree

5 files changed

+90
-22
lines changed

5 files changed

+90
-22
lines changed

.github/workflows/Tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ on:
44
pull_request:
55
branches:
66
- master
7+
- v3-backport
78
paths-ignore:
89
- 'docs/**'
910
push:
1011
branches:
1112
- master
13+
- v3-backport
1214
paths-ignore:
1315
- 'docs/**'
1416

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "3.53.0"
4+
version = "3.54.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -25,11 +25,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2525
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2626
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2727
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
28+
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
2829
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2930
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3031

3132
[extensions]
3233
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
34+
RecursiveArrayToolsFastBroadcastPolyesterExt = ["FastBroadcast", "Polyester"]
3335
RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
3436
RecursiveArrayToolsKernelAbstractionsExt = "KernelAbstractions"
3537
RecursiveArrayToolsMeasurementsExt = "Measurements"
@@ -56,6 +58,7 @@ Measurements = "2.11"
5658
MonteCarloMeasurements = "1.2"
5759
NLsolve = "4.5"
5860
Pkg = "1"
61+
Polyester = "0.7.16"
5962
PrecompileTools = "1.2.1"
6063
Random = "1"
6164
RecipesBase = "1.3.4"
@@ -84,6 +87,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
8487
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
8588
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
8689
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
90+
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
8791
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8892
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
8993
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -97,4 +101,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
97101
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
98102

99103
[targets]
100-
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
104+
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Polyester", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]

ext/RecursiveArrayToolsFastBroadcastExt.jl

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,23 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
2727
return dst
2828
end
2929

30-
@inline function FastBroadcast.fast_materialize!(
31-
::Threaded, dst::AbstractVectorOfSArray,
32-
bc::Broadcast.Broadcasted{S}
33-
) where {S}
34-
if FastBroadcast.use_fast_broadcast(S)
35-
Threads.@threads for i in 1:length(dst.u)
36-
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
37-
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
38-
unpacked[j]
39-
for j in eachindex(unpacked)
40-
)
41-
end
42-
else
43-
Broadcast.materialize!(dst, bc)
44-
end
45-
return dst
46-
end
47-
48-
# Fallback for non-SArray VectorOfArray: the generic threaded path splits
49-
# along the last axis via views, which does not correctly partition work for
30+
# Fallback for non-SArray VectorOfArray: the generic threaded path splits along
31+
# the last axis via views, which does not correctly partition work for
5032
# VectorOfArray. Fall back to serial broadcasting.
33+
# For SArray VectorOfArray, throw an informative error telling the user to
34+
# load Polyester.jl for threaded broadcasting.
5135
@inline function FastBroadcast.fast_materialize!(
5236
::Threaded, dst::AbstractVectorOfArray,
5337
bc::Broadcast.Broadcasted
5438
)
39+
# When Polyester is loaded, RecursiveArrayToolsFastBroadcastPolyesterExt
40+
# defines more-specific methods for AbstractVectorOfSArray, so reaching
41+
# this method with an SArray VoA means Polyester is not loaded.
42+
if dst isa AbstractVectorOfSArray
43+
error("Threaded FastBroadcast on VectorOfArray{SArray} requires Polyester.jl. " *
44+
"Add `using Polyester` to enable threaded broadcasting, or use " *
45+
"`@.. thread=false` for serial broadcasting.")
46+
end
5547
return FastBroadcast.fast_materialize!(Serial(), dst, bc)
5648
end
5749

ext/RecursiveArrayToolsFastBroadcastPolyesterExt.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
module RecursiveArrayToolsFastBroadcastPolyesterExt
2+
3+
using RecursiveArrayTools
4+
using FastBroadcast
5+
using FastBroadcast: Serial, Threaded
6+
using Polyester
7+
using StaticArraysCore
8+
9+
const AbstractVectorOfSArray = AbstractVectorOfArray{
10+
T, N, <:AbstractVector{<:StaticArraysCore.SArray},
11+
} where {T, N}
12+
13+
@inline function _polyester_fast_materialize!(
14+
dst::AbstractVectorOfSArray,
15+
bc::Broadcast.Broadcasted{S}
16+
) where {S}
17+
if FastBroadcast.use_fast_broadcast(S)
18+
@batch for i in 1:length(dst.u)
19+
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
20+
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(
21+
unpacked[j]
22+
for j in eachindex(unpacked)
23+
)
24+
end
25+
else
26+
Broadcast.materialize!(dst, bc)
27+
end
28+
return dst
29+
end
30+
31+
@inline function FastBroadcast.fast_materialize!(
32+
::Threaded, dst::AbstractVectorOfSArray,
33+
bc::Broadcast.Broadcasted{S}
34+
) where {S}
35+
return _polyester_fast_materialize!(dst, bc)
36+
end
37+
38+
# Disambiguation: this method is more specific than both the base ext's
39+
# (::Threaded, ::AbstractVectorOfArray, ::Broadcasted) fallback and
40+
# the above (::Threaded, ::AbstractVectorOfSArray, ::Broadcasted{S}).
41+
@inline function FastBroadcast.fast_materialize!(
42+
::Threaded, dst::AbstractVectorOfSArray,
43+
bc::Broadcast.Broadcasted
44+
)
45+
return _polyester_fast_materialize!(dst, bc)
46+
end
47+
48+
end # module

test/interface_tests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using RecursiveArrayTools, StaticArrays, Test
22
using FastBroadcast
3+
using Polyester
34
using SymbolicIndexingInterface: SymbolCache
45

56
t = 1:3
@@ -296,6 +297,27 @@ f3!(z, zz)
296297
@test all(x -> x == SVector(3.0, 3.0), v_t.u)
297298
end
298299

300+
# Test Polyester-based threaded FastBroadcast extension (issue #564)
301+
@testset "Polyester-threaded @.. with VectorOfArray{SArray}" begin
302+
# Verify the Polyester extension is loaded
303+
@test Base.get_extension(
304+
Base.PkgId(RecursiveArrayTools),
305+
:RecursiveArrayToolsFastBroadcastPolyesterExt
306+
) !== nothing
307+
308+
# Test basic threaded broadcast with Polyester (Vector-of-SVector storage)
309+
u_p = VectorOfArray([SVector(2.0, 3.0) for _ in 1:9])
310+
v_p = copy(u_p)
311+
@.. thread = true v_p = v_p + u_p
312+
@test all(x -> x == SVector(4.0, 6.0), v_p.u)
313+
314+
# Test with larger array to exercise Polyester batching
315+
u_large = VectorOfArray([SVector(1.0, 1.0, 1.0) for _ in 1:100])
316+
v_large = VectorOfArray([SVector(0.0, 0.0, 0.0) for _ in 1:100])
317+
@.. thread = true v_large = u_large * 2.0
318+
@test all(x -> x == SVector(2.0, 2.0, 2.0), v_large.u)
319+
end
320+
299321
struct ImmutableVectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
300322
u::A # A <: AbstractArray{<: AbstractArray{T, N - 1}}
301323
end

0 commit comments

Comments
 (0)