Skip to content

Commit 96ab7c8

Browse files
Merge pull request #526 from ChrisRackauckas-Claude/akima-perf-opts
Optimize AkimaInterpolation constructor (3.6-4.1x) and add sorted-batch evaluator (8-9x)
2 parents 3ac4780 + b92c4d4 commit 96ab7c8

3 files changed

Lines changed: 269 additions & 34 deletions

File tree

src/interpolation_caches.jl

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,66 @@ struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T} <:
288288
end
289289
end
290290

291+
# Weights of the two neighbouring slopes for knot `i`, read from the padded
# divided-difference buffer `m` (knot `i` sits between `m[i + 1]` and `m[i + 2]`).
# Returns `(w1, w2)`: `w1` multiplies `m[i + 1]`, `w2` multiplies `m[i + 2]`.
# With `modified == true` the makima term |m_{k+1} + m_k| / 2 is added to each
# weight. The caller must guarantee `1 <= i` and `i + 3 <= length(m)`.
@inline function _akima_weights(
        m::AbstractVector, i::Integer, ::Val{modified}
    ) where {modified}
    @inbounds begin
        w1 = abs(m[i + 3] - m[i + 2])
        w2 = abs(m[i + 1] - m[i])
        if modified
            w1 += abs(m[i + 3] + m[i + 2]) / 2
            w2 += abs(m[i + 1] + m[i]) / 2
        end
    end
    return w1, w2
end

# In-place scalar kernel for computing the Akima / makima coefficients.
#
# Fills `b` (length n, slope at every knot) and `c`, `d` (length n - 1, quadratic
# and cubic coefficient of every interval) from the data `u` sampled at `t`, where
# n = length(u). `Val(true)` selects the modified (makima) weights, `Val(false)`
# the classic Akima weights. Returns `nothing`.
#
# Allocates a single length-(n + 3) buffer for the padded divided differences;
# everything else is computed in scalar passes.
#
# Throws `ArgumentError` for fewer than two data points (or non-1-based arrays)
# and `DimensionMismatch` when `t`, `b`, `c` or `d` do not have the lengths listed
# above. These checks are what make the `@inbounds` region below safe.
function _akima_init!(
        b::AbstractVector{T}, c::AbstractVector{T}, d::AbstractVector{T},
        u::AbstractVector, t::AbstractVector, kind::Val{modified}
    ) where {T, modified}
    Base.require_one_based_indexing(b, c, d, u, t)
    n = length(u)
    n >= 2 || throw(
        ArgumentError("Akima interpolation requires at least 2 data points, got $n")
    )
    length(t) == n || throw(
        DimensionMismatch("u has length $n but t has length $(length(t))")
    )
    (length(b) == n && length(c) == n - 1 && length(d) == n - 1) || throw(
        DimensionMismatch(
            "coefficient buffers must have lengths (n, n - 1, n - 1) with n = $n"
        )
    )

    m = Vector{T}(undef, n + 3)
    @inbounds begin
        # Interior divided differences live in m[3:(n + 1)].
        for i in 1:(n - 1)
            m[i + 2] = (u[i + 1] - u[i]) / (t[i + 1] - t[i])
        end
        if n == 2
            # Only one slope exists; seed the slot the extrapolation formula reads
            # so the padding becomes that constant slope instead of uninitialized
            # memory. The result is the straight line through the two points.
            m[4] = m[3]
        end
        # Two linearly extrapolated slopes on each side.
        m[2] = 2 * m[3] - m[4]
        m[1] = 2 * m[2] - m[3]
        m[n + 2] = 2 * m[n + 1] - m[n]
        m[n + 3] = 2 * m[n + 2] - m[n + 1]

        # First pass: maximum weight, used as the small-weight cutoff
        wmax = zero(T)
        for i in 1:n
            w1, w2 = _akima_weights(m, i, kind)
            w12 = w1 + w2
            wmax = ifelse(w12 > wmax, w12, wmax)
        end
        tol = T(1.0e-9) * wmax

        # Second pass: knot slopes. When the weights are negligible the weighted
        # average is ill-defined, so fall back to a plain average of slopes.
        for i in 1:n
            w1, w2 = _akima_weights(m, i, kind)
            w12 = w1 + w2
            bdefault = modified ? (m[i + 1] + m[i + 2]) / 2 : (m[i + 3] + m[i]) / 2
            b[i] = w12 > tol ? (w1 * m[i + 1] + w2 * m[i + 2]) / w12 : bdefault
        end

        # Hermite-to-monomial conversion on every interval.
        for i in 1:(n - 1)
            dt = t[i + 1] - t[i]
            c[i] = (3 * m[i + 2] - 2 * b[i] - b[i + 1]) / dt
            d[i] = (b[i] + b[i + 1] - 2 * m[i + 2]) / (dt * dt)
        end
    end
    return nothing
end
350+
291351
function AkimaInterpolation(
292352
u, t; modified::Bool = false,
293353
extrapolation::ExtrapolationType.T = ExtrapolationType.None,
@@ -301,40 +361,11 @@ function AkimaInterpolation(
301361
u, t = munge_data(u, t)
302362
linear_lookup = seems_linear(assume_linear_t, t)
303363
n = length(t)
304-
dt = diff(t)
305-
m = Array{eltype(u)}(undef, n + 3)
306-
m[3:(end - 2)] = diff(u) ./ dt
307-
m[2] = 2m[3] - m[4]
308-
m[1] = 2m[2] - m[3]
309-
m[end - 1] = 2m[end - 2] - m[end - 3]
310-
m[end] = 2m[end - 1] - m[end - 2]
311-
312-
if modified
313-
# Modified Akima (makima): adds |m_{i+1} + m_i| / 2 to each weight, which
314-
# reduces overshoot on flat regions. The simple-average fallback still
315-
# guards the case where all four neighboring slopes vanish.
316-
w1 = abs.(m[4:end] .- m[3:(end - 1)]) .+
317-
abs.(m[4:end] .+ m[3:(end - 1)]) ./ 2
318-
w2 = abs.(m[2:(end - 2)] .- m[1:(end - 3)]) .+
319-
abs.(m[2:(end - 2)] .+ m[1:(end - 3)]) ./ 2
320-
w12 = w1 .+ w2
321-
b = (m[2:(end - 2)] .+ m[3:(end - 1)]) ./ 2
322-
ind = findall(w12 .> 1.0e-9 * maximum(w12))
323-
b[ind] = (w1[ind] .* m[ind .+ 1] .+ w2[ind] .* m[ind .+ 2]) ./ w12[ind]
324-
else
325-
b = (m[4:end] .+ m[1:(end - 3)]) ./ 2
326-
dm = abs.(diff(m))
327-
f1 = dm[3:(n + 2)]
328-
f2 = dm[1:n]
329-
f12 = f1 + f2
330-
ind = findall(f12 .> 1.0e-9 * maximum(f12))
331-
b[ind] = (
332-
f1[ind] .* m[ind .+ 1] .+
333-
f2[ind] .* m[ind .+ 2]
334-
) ./ f12[ind]
335-
end
336-
c = (3 .* m[3:(end - 2)] .- 2 .* b[1:(end - 1)] .- b[2:end]) ./ dt
337-
d = (b[1:(end - 1)] .+ b[2:end] .- 2 .* m[3:(end - 2)]) ./ dt .^ 2
364+
T = eltype(u)
365+
b = Vector{T}(undef, n)
366+
c = Vector{T}(undef, n - 1)
367+
d = Vector{T}(undef, n - 1)
368+
_akima_init!(b, c, d, u, t, Val(modified))
338369

339370
A = AkimaInterpolation(
340371
u, t, nothing, b, c, d, extrapolation_left,

src/interpolation_methods.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,134 @@ function _interpolate(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess
205205
return @evalpoly wj A.u[idx] A.b[idx] A.c[idx] A.d[idx]
206206
end
207207

208+
# Batch evaluation into a caller-provided buffer: writes the interpolant at every
# point of `tt` into `out` and returns `out`.
#
# If the query points are sorted and both extrapolation modes are ones the sweep
# kernel handles, knots and queries are traversed together in a single pass;
# otherwise each point is evaluated on its own through `map!`.
# Throws `DimensionMismatch` when `out` and `tt` differ in length.
function (A::AkimaInterpolation{<:AbstractVector})(
        out::AbstractVector, tt::AbstractVector
    )
    length(out) == length(tt) || throw(
        DimensionMismatch(
            "number of evaluation points and length of the result vector must be equal"
        )
    )
    if !(_akima_eval_fast_applicable(A) && issorted(tt))
        # General path: one independent evaluation per query point.
        map!(A, out, tt)
        return out
    end
    _akima_eval_sorted!(out, A, tt)
    return out
end
228+
229+
# Returns `true` when both the left and the right extrapolation mode of `A` are
# one of None, Constant, Linear or Extension, i.e. the modes the sorted sweep
# evaluator has a branch for. Any other mode yields `false`.
@inline function _akima_eval_fast_applicable(A::AkimaInterpolation)
    supported = (
        ExtrapolationType.None,
        ExtrapolationType.Constant,
        ExtrapolationType.Linear,
        ExtrapolationType.Extension,
    )
    left_ok = A.extrapolation_left in supported
    right_ok = A.extrapolation_right in supported
    return left_ok && right_ok
end
242+
243+
# Single-sweep evaluation of `A` at the sorted query points `tt`, writing into
# `out`. Returns `nothing`.
#
# Preconditions (not re-checked here, they guard the `@inbounds` accesses):
#   * `tt` is sorted and `length(out) == length(tt)`;
#   * both extrapolation modes are None, Constant, Linear or Extension. The final
#     `else` of each mode switch handles Extension and would also be taken by any
#     other mode, so callers must filter those out beforehand.
#
# The sweep has three phases sharing the cursor `i`: queries left of the first
# knot, queries inside the knot range, queries right of the last knot. With mode
# None an out-of-range query throws `LeftExtrapolationError` /
# `RightExtrapolationError`.
#
# NaN queries (which `issorted` places at the end) are neither below the first
# knot nor above the last one, so they are evaluated through the polynomial and
# produce NaN instead of being treated as right extrapolation.
function _akima_eval_sorted!(
        out::AbstractVector, A::AkimaInterpolation{<:AbstractVector}, tt::AbstractVector
    )
    u = A.u
    t = A.t
    bv = A.b
    cv = A.c
    dv = A.d
    el = A.extrapolation_left
    er = A.extrapolation_right
    n = length(t)
    m = length(tt)
    t1 = @inbounds t[1]
    tn = @inbounds t[n]

    i = 1

    # Left extrapolation
    if el == ExtrapolationType.None
        # Sorted input: if the first query is in range, none is left of t1.
        @inbounds if i <= m && tt[i] < t1
            throw(LeftExtrapolationError())
        end
    elseif el == ExtrapolationType.Constant
        u1 = @inbounds u[1]
        @inbounds while i <= m && tt[i] < t1
            out[i] = u1
            i += 1
        end
    elseif el == ExtrapolationType.Linear
        u1 = @inbounds u[1]
        b1 = @inbounds bv[1]
        @inbounds while i <= m && tt[i] < t1
            out[i] = u1 + b1 * (tt[i] - t1)
            i += 1
        end
    else # Extension: continue the first interval's cubic
        u1 = @inbounds u[1]
        b1 = @inbounds bv[1]
        c1 = @inbounds cv[1]
        d1 = @inbounds dv[1]
        @inbounds while i <= m && tt[i] < t1
            wj = tt[i] - t1
            out[i] = @evalpoly wj u1 b1 c1 d1
            i += 1
        end
    end

    # Interior: walk knots in lockstep. The guard is written as `!(x > tn)`
    # rather than `x <= tn` so that NaN stays in this phase.
    idx = 1
    @inbounds while i <= m && !(tt[i] > tn)
        ttt = tt[i]
        # Advance to the interval containing ttt; idx never exceeds n - 1, the
        # last interval that has coefficients.
        while idx < n - 1 && ttt > t[idx + 1]
            idx += 1
        end
        wj = ttt - t[idx]
        out[i] = @evalpoly wj u[idx] bv[idx] cv[idx] dv[idx]
        i += 1
    end

    # Right extrapolation: every remaining query is strictly greater than tn.
    if er == ExtrapolationType.None
        @inbounds if i <= m
            throw(RightExtrapolationError())
        end
    elseif er == ExtrapolationType.Constant
        un = @inbounds u[n]
        @inbounds while i <= m
            out[i] = un
            i += 1
        end
    elseif er == ExtrapolationType.Linear
        un = @inbounds u[n]
        bn = @inbounds bv[n]
        @inbounds while i <= m
            out[i] = un + bn * (tt[i] - tn)
            i += 1
        end
    else # Extension: continue the last interval's cubic
        un1 = @inbounds u[n - 1]
        bn1 = @inbounds bv[n - 1]
        cn1 = @inbounds cv[n - 1]
        dn1 = @inbounds dv[n - 1]
        tn1 = @inbounds t[n - 1]
        @inbounds while i <= m
            wj = tt[i] - tn1
            out[i] = @evalpoly wj un1 bn1 cn1 dn1
            i += 1
        end
    end

    return nothing
end
335+
208336
# Constant Interpolation
209337
function _interpolate(A::ConstantInterpolation{<:AbstractVector}, t::Number, iguess)
210338
if A.dir === :left

test/interpolation_tests.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,82 @@ end
565565
@test isfinite(DataInterpolations.derivative(A_makima, 5.0))
566566
@test isfinite(DataInterpolations.integral(A_makima, 0.0, 10.0))
567567
end
568+
569+
# Checks that the in-place batch call `A(out, tt)` agrees with the per-point call
# `A(t)` on sorted and unsorted queries, for every extrapolation mode, and that
# it raises the expected errors.
@testset "Sorted-batch evaluator" begin
    u = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0]
    t = collect(0.0:10.0)

    for modified in (false, true)
        A = AkimaInterpolation(u, t; modified = modified)
        # Sorted query: fast path matches the per-point path
        tt = sort!([0.0, 0.5, 1.0, 2.7, 5.3, 7.9, 10.0])
        out = similar(tt)
        A(out, tt)
        for k in eachindex(tt)
            @test out[k] ≈ A(tt[k])
        end
        # Knot pass-through
        outk = similar(t)
        A(outk, t)
        for k in eachindex(t)
            @test outk[k] ≈ u[k]
        end
        # Unsorted query falls back to per-point and stays consistent
        tt_unsorted = [3.1, 7.7, 0.2, 5.5, 9.9]
        out_u = similar(tt_unsorted)
        A(out_u, tt_unsorted)
        for k in eachindex(tt_unsorted)
            @test out_u[k] ≈ A(tt_unsorted[k])
        end
    end

    # Each fast-path extrapolation mode matches the per-point path
    for el in (
                ExtrapolationType.Constant,
                ExtrapolationType.Linear,
                ExtrapolationType.Extension,
            ),
            er in (
                ExtrapolationType.Constant,
                ExtrapolationType.Linear,
                ExtrapolationType.Extension,
            )

        A = AkimaInterpolation(
            u, t; extrapolation_left = el, extrapolation_right = er
        )
        tt = collect(-2.0:0.4:12.0)
        out = similar(tt)
        A(out, tt)
        for k in eachindex(tt)
            @test out[k] ≈ A(tt[k])
        end
    end

    # Periodic/Reflective fall back to the map! path
    for ext in (ExtrapolationType.Periodic, ExtrapolationType.Reflective)
        A = AkimaInterpolation(u, t; extrapolation = ext)
        tt = collect(-3.0:0.5:13.0)
        out = similar(tt)
        A(out, tt)
        for k in eachindex(tt)
            @test out[k] ≈ A(tt[k])
        end
    end

    # ExtrapolationType.None throws when a sorted query is out of range
    A_none = AkimaInterpolation(u, t)
    @test_throws DataInterpolations.LeftExtrapolationError A_none(
        similar([-1.0, 5.0]), [-1.0, 5.0]
    )
    @test_throws DataInterpolations.RightExtrapolationError A_none(
        similar([5.0, 11.0]), [5.0, 11.0]
    )

    # DimensionMismatch
    A_dim = AkimaInterpolation(u, t)
    @test_throws DimensionMismatch A_dim(zeros(3), [1.0, 2.0])
end
568644
end
569645

570646
@testset "ConstantInterpolation" begin

0 commit comments

Comments
 (0)