Skip to content

Commit e23fc98

Browse files
author
Jordan Benjamin
committed
Optimize performance: thread-local reduction, zero-allocation kernels, and loop symmetry (Phase J8)
1 parent 1f12924 commit e23fc98

3 files changed

Lines changed: 98 additions & 82 deletions

File tree

src/Calculations.jl

Lines changed: 74 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -213,26 +213,34 @@ function _calculate_structure_function_core(
213213
end
214214
distance_bins_vec[end] = distance_bins[end][2]
215215

216+
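# Per-thread accumulation buffers (indexed by Threads.threadid() in the loop below): avoids lock contention and per-iteration allocations.
# NOTE(review): threadid()-indexed buffers are only safe if a task never migrates between threads — confirm the @threads schedule (e.g. :static) or switch to per-task buffers.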
217+
n_threads = Threads.nthreads()
218+
local_outputs = [zeros(OT, N3) for _ in 1:n_threads]
219+
local_counts = [zeros(OT, N3) for _ in 1:n_threads]
220+
216221
if verbose
217-
@info("calculating structure function")
222+
@info("calculating structure function (parallel reduction)")
218223
end
219224

220-
iter_inds = eachindex(x_vecs[1]) # these should all match..., idk if doing 1:N2 is faster but the indexing could be shifted...
221-
# PM.@showprogress enabled = show_progress for i in iter_inds # is this the fast order?
222-
lock = Threads.ReentrantLock()
223-
Threads.@threads for i::Int64 in iter_inds
224-
_output, _counts = calculate_structure_function_i(
225+
iter_inds = eachindex(x_vecs[1])
226+
Threads.@threads for i in iter_inds
227+
tid = Threads.threadid()
228+
calculate_structure_function_i!(
229+
local_outputs[tid],
230+
local_counts[tid],
225231
structure_function_type,
226232
i,
227233
x_vecs,
228234
u_vecs,
229235
distance_bins_vec;
230236
distance_metric = distance_metric,
231237
)
232-
Threads.lock(lock) do
233-
output .+= _output
234-
counts .+= _counts
235-
end
238+
end
239+
240+
# Global reduction from thread-local buffers
241+
for tid in 1:n_threads
242+
output .+= local_outputs[tid]
243+
counts .+= local_counts[tid]
236244
end
237245

238246
if RSAC # just return the sums and the counts, don't take the mean in each bin...
@@ -247,43 +255,28 @@ end
247255

248256

249257

250-
function calculate_structure_function_i(
258+
function calculate_structure_function_i!(
259+
output::AbstractVector{OT},
260+
counts::AbstractVector{OT},
251261
structure_function_type::SFT.AbstractStructureFunctionType,
252262
i::Int,
253263
x_vecs::Tuple{T1, Vararg{T1}},
254264
u_vecs::Tuple{T2, Vararg{T2}},
255265
distance_bins_vec::AbstractVector{FT3};
256266
distance_metric::DI.PreMetric = DI.Euclidean(),
257-
) where {T1, T2, FT3}
267+
) where {OT, T1, T2, FT3}
258268
N = length(x_vecs)
259269
FT1 = eltype(T1)
260270
FT2 = eltype(T2)
261271
N3 = length(distance_bins_vec)
262272

263-
264-
265-
# Commit A2: Fixed "double calculating" by iterating only over unique pairs (j > i).
266-
# Removed BadImplementationError as this path is now considered correct.
267-
268-
# preallocate output as vector of length of distance_bins
269-
OT = promote_type(float(FT1), float(FT2))
270-
output = zeros(OT, N3 - 1)
271-
counts = zeros(OT, N3 - 1)
272-
273-
iter_inds = eachindex(x_vecs[1])
274-
275-
# PERFORMANCE NOTE: Converting to SVector at the loop boundary (hoisting)
276-
# is critical for zero heap-allocations in the inner loop.
277-
# v2 - v1 on SubArrays/Vectors triggers heap temporaries via broadcasting.
278-
# SVector arithmetic is handled entirely on the stack/CPU registers.
279-
# We load U1/X1 once here to avoid redundant memory access in the j-loop.
280-
X1 = SA.SVector{N, FT1}(ntuple(k -> x_vecs[k][i], Val{N}()))
281-
U1 = SA.SVector{N, FT2}(ntuple(k -> u_vecs[k][i], Val{N}()))
273+
X1 = SA.SVector{N, FT1}(ntuple(k -> x_vecs[k][i], Val(N)))
274+
U1 = SA.SVector{N, FT2}(ntuple(k -> u_vecs[k][i], Val(N)))
282275

283-
# Iterate only over unique pairs where j > i to avoid double calculation
276+
iter_inds = eachindex(x_vecs[1])
284277
for j in (i+1):last(iter_inds)
285-
X2 = SA.SVector{N, FT1}(ntuple(k -> x_vecs[k][j], Val{N}()))
286-
U2 = SA.SVector{N, FT2}(ntuple(k -> u_vecs[k][j], Val{N}()))
278+
X2 = SA.SVector{N, FT1}(ntuple(k -> x_vecs[k][j], Val(N)))
279+
U2 = SA.SVector{N, FT2}(ntuple(k -> u_vecs[k][j], Val(N)))
287280

288281
distance = distance_metric(X1, X2)
289282
bin = SFH.digitize(distance, distance_bins_vec)
@@ -292,7 +285,7 @@ function calculate_structure_function_i(
292285
@inbounds counts[bin] += 1
293286
end
294287
end
295-
return output, counts
288+
return nothing
296289
end
297290

298291

@@ -460,26 +453,38 @@ function _calculate_structure_function_core(
460453
end
461454
distance_bins_vec[end] = distance_bins[end][2]
462455

456+
# Create thread-local buffers
457+
n_threads = Threads.nthreads()
458+
local_outputs = [zeros(n_threads) for _ in 1:n_threads] # NOTE(review): dead allocation — immediately overwritten by the N3-length buffers assigned next; candidate for removal
459+
local_outputs = [zeros(N3) for _ in 1:n_threads]
460+
local_counts = [zeros(N3) for _ in 1:n_threads]
461+
463462
if verbose
464-
@info("calculating structure function")
463+
@info("calculating structure function (parallel reduction)")
465464
end
466465

467-
iter_inds = axes(x_arr,2) # these should all match..., idk if doing 1:N2 is faster but the indexing could be shifted...
468-
# iter_inds = axes(x_arr,1) # these should all match..., idk if doing 1:N2 is faster but the indexing could be shifted...
466+
iter_inds = axes(x_arr, 2)
469467
N = size(x_arr, 1)
470-
vN = if N == 1 Val{1}() elseif N == 2 Val{2}() elseif N == 3 Val{3}() else Val(N) end
471-
PM.@showprogress enabled = show_progress for i in iter_inds
472-
_output, _counts = calculate_structure_function_i(
473-
vN,
474-
structure_function_type,
475-
i,
476-
x_arr,
477-
u_arr,
478-
distance_bins_vec;
479-
distance_metric = distance_metric,
480-
)
481-
output .+= _output
482-
counts .+= _counts
468+
vN = if N == 1 Val(1) elseif N == 2 Val(2) elseif N == 3 Val(3) else Val(N) end
469+
470+
Threads.@threads for i in iter_inds
471+
tid = Threads.threadid()
472+
calculate_structure_function_i!(
473+
local_outputs[tid],
474+
local_counts[tid],
475+
vN,
476+
structure_function_type,
477+
i,
478+
x_arr,
479+
u_arr,
480+
distance_bins_vec;
481+
distance_metric = distance_metric,
482+
)
483+
end
484+
485+
for tid in 1:n_threads
486+
output .+= local_outputs[tid]
487+
counts .+= local_counts[tid]
483488
end
484489

485490
if RSAC # just return the sums and the counts, don't take the mean in each bin...
@@ -492,45 +497,36 @@ function _calculate_structure_function_core(
492497
end
493498
end
494499

495-
function calculate_structure_function_i(
500+
function calculate_structure_function_i!(
501+
output::AbstractVector{FT},
502+
counts::AbstractVector{FT},
496503
::Val{N},
497504
structure_function_type::SFT.AbstractStructureFunctionType,
498505
i::Int,
499506
x_arr::AbstractArray{FT1},
500507
u_arr::AbstractArray{FT2},
501508
distance_bins_vec::AbstractVector{FT3};
502509
distance_metric::DI.PreMetric = DI.Euclidean(),
503-
) where {N, FT1 <: Number, FT2 <: Number, FT3 <: Number}
510+
) where {FT, N, FT1 <: Number, FT2 <: Number, FT3 <: Number}
504511
N3 = length(distance_bins_vec)
505-
506-
# preallocate output as vector of length of distance_bins (vector so it's mutable)
507-
output = zeros(N3 - 1)
508-
counts = zeros(N3 - 1)
509-
510-
iter_inds = axes(x_arr,2)
512+
iter_inds = axes(x_arr, 2)
511513

512-
# PERFORMANCE NOTE: Converting to SVector at the loop boundary (hoisting)
513-
# is critical for zero heap-allocations in the inner loop.
514-
# v2 - v1 on SubArrays/Vectors triggers heap temporaries via broadcasting.
515-
# SVector arithmetic is handled entirely on the stack/CPU registers.
516-
# We load U1/X1 once here to avoid redundant memory access in the j-loop.
517-
X1 = SA.SVector{N, FT1}(ntuple(k -> x_arr[k, i], Val{N}()))
518-
U1 = SA.SVector{N, FT2}(ntuple(k -> u_arr[k, i], Val{N}()))
519-
520-
for j in iter_inds
521-
if i != j
522-
X2 = SA.SVector{N, FT1}(ntuple(k -> x_arr[k, j], Val{N}()))
523-
U2 = SA.SVector{N, FT2}(ntuple(k -> u_arr[k, j], Val{N}()))
514+
X1 = SA.SVector{N, FT1}(ntuple(k -> x_arr[k, i], Val(N)))
515+
U1 = SA.SVector{N, FT2}(ntuple(k -> u_arr[k, i], Val(N)))
524516

525-
distance = distance_metric(X1, X2)
526-
bin = SFH.digitize(distance, distance_bins_vec)
527-
if 1 <= bin < N3
528-
@inbounds output[bin] += structure_function_type(U2 - U1, SFH.(X1, X2))
529-
@inbounds counts[bin] += 1
530-
end
517+
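# Visit each unordered pair once (j > i): roughly N^2/2 pair evaluations instead of N^2, so each pair contributes to output/counts once rather than twice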
518+
for j in (i+1):last(iter_inds)
519+
X2 = SA.SVector{N, FT1}(ntuple(k -> x_arr[k, j], Val(N)))
520+
U2 = SA.SVector{N, FT2}(ntuple(k -> u_arr[k, j], Val(N)))
521+
522+
distance = distance_metric(X1, X2)
523+
bin = SFH.digitize(distance, distance_bins_vec)
524+
if 1 <= bin < N3
525+
@inbounds output[bin] += structure_function_type(U2 - U1, SFH.(X1, X2))
526+
@inbounds counts[bin] += 1
531527
end
532528
end
533-
return output, counts
529+
return nothing
534530
end
535531

536532

src/HelperFunctions.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,7 @@ end
6262
6363
Convert a collection of bin edges (tuples) to their midpoints. Internal helper.
6464
"""
65-
function midpoints(bins::AbstractVector{Tuple{T, T}}) where {T <: Number}
66-
return [(b[1] + b[2]) / 2 for b in bins]
67-
end
68-
65+
function midpoints(bins)
    # Each element is indexed at positions 1 and 2 (lower/upper edge); `map` keeps
    # the container flavour of `bins` (e.g. a tuple in gives a tuple out).
    return map(bins) do edges
        lower = edges[1]
        upper = edges[2]
        (lower + upper) / 2
    end
end
6966
# Numeric vectors are handed back unchanged (same object, no copy) —
# presumably they already hold bin centres; confirm against callers.
function midpoints(v::AbstractVector{T}) where {T <: Number}
    return v
end
7067

7168
@inline function digitize(x::AbstractVector, bins::AbstractVector)

test/benchmark_j8_baseline.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using BenchmarkTools
2+
using StructureFunctions: StructureFunctions as SF
3+
using StaticArrays: StaticArrays as SA
4+
5+
function run_benchmark()
6+
N = 2000
7+
FT = Float64
8+
x = ([rand(FT) for _ in 1:N], [rand(FT) for _ in 1:N])
9+
u = ([rand(FT) for _ in 1:N], [rand(FT) for _ in 1:N])
10+
bins = SA.SVector{10}([(i*0.1, (i+1)*0.1) for i in 0:9])
11+
12+
println("--- Benchmark: Tuple Variant (N=$N, Bins=10) ---")
13+
b_tuple = @benchmark SF.calculate_structure_function(SF.L2SF, $x, $u, $bins; verbose=false, show_progress=false)
14+
display(b_tuple)
15+
16+
x_arr = rand(FT, 2, N)
17+
u_arr = rand(FT, 2, N)
18+
println("\n--- Benchmark: Array Variant (N=$N, Bins=10) ---")
19+
b_array = @benchmark SF.calculate_structure_function(SF.L2SF, $x_arr, $u_arr, $bins; verbose=false, show_progress=false)
20+
display(b_array)
21+
end
22+
23+
run_benchmark()

0 commit comments

Comments
 (0)