diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 51803892..cdacdfbe 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -130,6 +130,108 @@ function _interpolate(A::LinearInterpolation{<:AbstractArray}, t::Number, iguess return A.u[ax..., idx] + slope * Δt end +# Sorted-batch fast path for LinearInterpolation. +function (A::LinearInterpolation{<:AbstractVector{<:Number}})( + out::AbstractVector, tt::AbstractVector + ) + if length(out) != length(tt) + throw( + DimensionMismatch( + "number of evaluation points and length of the result vector must be equal" + ) + ) + end + if _linear_eval_fast_applicable(A) && !any(isnan, A.u) && issorted(tt) + _linear_eval_sorted!(out, A, tt) + else + map!(A, out, tt) + end + return out +end + +@inline function _linear_eval_fast_applicable(A::LinearInterpolation) + el = A.extrapolation_left + er = A.extrapolation_right + el_ok = el == ExtrapolationType.None || + el == ExtrapolationType.Constant || + el == ExtrapolationType.Linear || + el == ExtrapolationType.Extension + er_ok = er == ExtrapolationType.None || + er == ExtrapolationType.Constant || + er == ExtrapolationType.Linear || + er == ExtrapolationType.Extension + return el_ok && er_ok +end + +function _linear_eval_sorted!( + out::AbstractVector, A::LinearInterpolation{<:AbstractVector{<:Number}}, tt::AbstractVector + ) + u = A.u + t = A.t + el = A.extrapolation_left + er = A.extrapolation_right + n = length(t) + m = length(tt) + t1 = @inbounds t[1] + tn = @inbounds t[n] + + i = 1 + + # Left extrapolation + if el == ExtrapolationType.None + @inbounds if i <= m && tt[i] < t1 + throw(LeftExtrapolationError()) + end + elseif el == ExtrapolationType.Constant + u1 = @inbounds u[1] + @inbounds while i <= m && tt[i] < t1 + out[i] = u1 + i += 1 + end + else # Linear or Extension — both reduce to the first-segment line + u1 = @inbounds u[1] + slope1 = get_parameters(A, 1) + @inbounds while i <= m && tt[i] < t1 + out[i] = u1 + slope1 * (tt[i] - t1) + i += 1 + end + end + + # Interior: per-segment outer loop, hoist coefficients + @inbounds for idx in 1:(n - 1) + ti = t[idx] + tip1 = t[idx + 1] + ui = u[idx] + slope = get_parameters(A, idx) + while i <= m && tt[i] <= tip1 + out[i] = ui + slope * (tt[i] - ti) + i += 1 + end + end + + # Right extrapolation + if er == ExtrapolationType.None + @inbounds if i <= m + throw(RightExtrapolationError()) + end + elseif er == ExtrapolationType.Constant + un = @inbounds u[n] + @inbounds while i <= m + out[i] = un + i += 1 + end + else # Linear or Extension — both reduce to the last-segment line + un = @inbounds u[n] + slope_n = get_parameters(A, n - 1) + @inbounds while i <= m + out[i] = un + slope_n * (tt[i] - tn) + i += 1 + end + end + + return nothing +end + # Quadratic Interpolation function _interpolate(A::QuadraticInterpolation, t::Number, iguess) idx = get_idx(A, t, iguess) @@ -405,6 +507,149 @@ function _interpolate(A::CubicSpline{<:AbstractArray}, t::Number, iguess) return I + C + D end +# Sorted-batch fast path for CubicSpline. +function (A::CubicSpline{<:AbstractVector{<:Number}})( + out::AbstractVector, tt::AbstractVector + ) + if length(out) != length(tt) + throw( + DimensionMismatch( + "number of evaluation points and length of the result vector must be equal" + ) + ) + end + if _cubicspline_eval_fast_applicable(A) && issorted(tt) + _cubicspline_eval_sorted!(out, A, tt) + else + map!(A, out, tt) + end + return out +end + +@inline function _cubicspline_eval_fast_applicable(A::CubicSpline) + el = A.extrapolation_left + er = A.extrapolation_right + el_ok = el == ExtrapolationType.None || + el == ExtrapolationType.Constant || + el == ExtrapolationType.Linear || + el == ExtrapolationType.Extension + er_ok = er == ExtrapolationType.None || + er == ExtrapolationType.Constant || + er == ExtrapolationType.Linear || + er == ExtrapolationType.Extension + return el_ok && er_ok +end + +@inline function _cubicspline_segment_eval( + ttt, ti, tip1, zi, zip1, hinv6, c1, c2 + ) + dt1 = ttt - ti + dt2 = tip1 - ttt + return (zi * dt2 * dt2 * dt2 + zip1 * dt1 * dt1 * dt1) * hinv6 + + c1 * dt1 + c2 * dt2 +end + +function _cubicspline_eval_sorted!( + out::AbstractVector, A::CubicSpline{<:AbstractVector{<:Number}}, tt::AbstractVector + ) + u = A.u + t = A.t + z = A.z + h = A.h + el = A.extrapolation_left + er = A.extrapolation_right + n = length(t) + m = length(tt) + t1 = @inbounds t[1] + tn = @inbounds t[n] + + i = 1 + + # Left extrapolation + if el == ExtrapolationType.None + @inbounds if i <= m && tt[i] < t1 + throw(LeftExtrapolationError()) + end + elseif el == ExtrapolationType.Constant + u1 = @inbounds u[1] + @inbounds while i <= m && tt[i] < t1 + out[i] = u1 + i += 1 + end + elseif el == ExtrapolationType.Linear + u1 = @inbounds u[1] + slope1 = _derivative(A, t1, 1) + @inbounds while i <= m && tt[i] < t1 + out[i] = u1 + slope1 * (tt[i] - t1) + i += 1 + end + else # Extension — continue the first-segment cubic + ti = t1 + tip1 = @inbounds t[2] + zi = @inbounds z[1] + zip1 = @inbounds z[2] + hinv6 = inv(6 * @inbounds(h[2])) + c1, c2 = get_parameters(A, 1) + @inbounds while i <= m && tt[i] < t1 + out[i] = _cubicspline_segment_eval( + tt[i], ti, tip1, zi, zip1, hinv6, c1, c2 + ) + i += 1 + end + end + + # Interior: per-segment outer loop with hoisted coefficients + @inbounds for idx in 1:(n - 1) + ti = t[idx] + tip1 = t[idx + 1] + zi = z[idx] + zip1 = z[idx + 1] + hinv6 = inv(6 * h[idx + 1]) + c1, c2 = get_parameters(A, idx) + while i <= m && tt[i] <= tip1 + out[i] = _cubicspline_segment_eval( + tt[i], ti, tip1, zi, zip1, hinv6, c1, c2 + ) + i += 1 + end + end + + # Right extrapolation + if er == ExtrapolationType.None + @inbounds if i <= m + throw(RightExtrapolationError()) + end + elseif er == ExtrapolationType.Constant + un = @inbounds u[n] + @inbounds while i <= m + out[i] = un + i += 1 + end + elseif er == ExtrapolationType.Linear + un = @inbounds u[n] + slope_n = _derivative(A, tn, n) + @inbounds while i <= m + out[i] = un + slope_n * (tt[i] - tn) + i += 1 + end + else # Extension — continue the last-segment cubic + ti = @inbounds t[n - 1] + tip1 = tn + zi = @inbounds z[n - 1] + zip1 = @inbounds z[n] + hinv6 = inv(6 * @inbounds(h[n])) + c1, c2 = get_parameters(A, n - 1) + @inbounds while i <= m + out[i] = _cubicspline_segment_eval( + tt[i], ti, tip1, zi, zip1, hinv6, c1, c2 + ) + i += 1 + end + end + + return nothing +end + # BSpline Curve Interpolation function _interpolate( A::BSplineInterpolation{<:AbstractVector{<:Number}}, diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 77416842..c1409cc0 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -274,6 +274,88 @@ end test_uvals = [1.0, 2.0, 4.0, 8.0] f(tvals) = LinearInterpolation(test_uvals, tvals)(3.5) @test_nowarn ForwardDiff.gradient(f, test_tvals) + + @testset "Sorted-batch evaluator" begin + u_b = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t_b = collect(0.0:10.0) + + # Sorted query matches per-point path on the default (Constant) and + # each fast-path extrapolation mode paired left × right + for el in ( + ExtrapolationType.Constant, + ExtrapolationType.Linear, + ExtrapolationType.Extension, + ), + er in ( + ExtrapolationType.Constant, + ExtrapolationType.Linear, + ExtrapolationType.Extension, + ) + + A_b = LinearInterpolation( + u_b, t_b; extrapolation_left = el, extrapolation_right = er + ) + tt = collect(-2.0:0.4:12.0) + out = similar(tt) + A_b(out, tt) + for k in eachindex(tt) + @test out[k] ≈ A_b(tt[k]) + end + # Knot pass-through + outk = similar(t_b) + A_b(outk, t_b) + for k in eachindex(t_b) + @test outk[k] ≈ u_b[k] + end + end + + # Periodic/Reflective extrapolation falls back to map! + for ext in (ExtrapolationType.Periodic, ExtrapolationType.Reflective) + A_b = LinearInterpolation(u_b, t_b; extrapolation = ext) + tt = collect(-3.0:0.5:13.0) + out = similar(tt) + A_b(out, tt) + for k in eachindex(tt) + @test out[k] ≈ A_b(tt[k]) + end + end + + # ExtrapolationType.None throws on out-of-range sorted queries + A_none = LinearInterpolation(u_b, t_b) + @test_throws DataInterpolations.LeftExtrapolationError A_none( + similar([-1.0, 5.0]), [-1.0, 5.0] + ) + @test_throws DataInterpolations.RightExtrapolationError A_none( + similar([5.0, 11.0]), [5.0, 11.0] + ) + + # Unsorted query falls back to per-point path + A_b = LinearInterpolation(u_b, t_b; extrapolation = ExtrapolationType.Constant) + tt_u = [3.1, 7.7, 0.2, 5.5, 9.9] + out_u = similar(tt_u) + A_b(out_u, tt_u) + for k in eachindex(tt_u) + @test out_u[k] ≈ A_b(tt_u[k]) + end + + # NaN in u: fast path is skipped (since `any(isnan, u)` is true), + # preserving the per-point NaN semantics + u_nan = [0.0, NaN, 2.0, 3.0] + t_nan = [1.0, 2.0, 3.0, 4.0] + A_nan = LinearInterpolation(u_nan, t_nan; extrapolation = ExtrapolationType.Extension) + tt_nan = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0] + out_nan = similar(tt_nan) + A_nan(out_nan, tt_nan) + for k in eachindex(tt_nan) + ref = A_nan(tt_nan[k]) + @test (isnan(out_nan[k]) && isnan(ref)) || out_nan[k] ≈ ref + end + + # DimensionMismatch + @test_throws DimensionMismatch LinearInterpolation(u_b, t_b)( + zeros(3), [1.0, 2.0] + ) + end end @testset "Quadratic Interpolation" begin @@ -1056,6 +1138,83 @@ end f_test = reduce(hcat, f3d.(t_test)) @test isapprox(u_test, f_test, atol = 1.0e-2) end + + @testset "Sorted-batch evaluator" begin + u_b = [0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0] + t_b = collect(0.0:10.0) + + # Sorted query matches per-point path on each fast-path extrapolation mode + for el in ( + ExtrapolationType.Constant, + ExtrapolationType.Linear, + ExtrapolationType.Extension, + ), + er in ( + ExtrapolationType.Constant, + ExtrapolationType.Linear, + ExtrapolationType.Extension, + ) + + A_b = CubicSpline( + u_b, t_b; extrapolation_left = el, extrapolation_right = er + ) + tt = collect(-2.0:0.4:12.0) + out = similar(tt) + A_b(out, tt) + for k in eachindex(tt) + @test out[k] ≈ A_b(tt[k]) + end + # Knot pass-through + outk = similar(t_b) + A_b(outk, t_b) + for k in eachindex(t_b) + @test outk[k] ≈ u_b[k] + end + end + + # Periodic/Reflective extrapolation falls back to map! + for ext in (ExtrapolationType.Periodic, ExtrapolationType.Reflective) + A_b = CubicSpline(u_b, t_b; extrapolation = ext) + tt = collect(-3.0:0.5:13.0) + out = similar(tt) + A_b(out, tt) + for k in eachindex(tt) + @test out[k] ≈ A_b(tt[k]) + end + end + + # ExtrapolationType.None throws on out-of-range sorted queries + A_none = CubicSpline(u_b, t_b) + @test_throws DataInterpolations.LeftExtrapolationError A_none( + similar([-1.0, 5.0]), [-1.0, 5.0] + ) + @test_throws DataInterpolations.RightExtrapolationError A_none( + similar([5.0, 11.0]), [5.0, 11.0] + ) + + # Unsorted query falls back to per-point path + A_b = CubicSpline(u_b, t_b; extrapolation = ExtrapolationType.Constant) + tt_u = [3.1, 7.7, 0.2, 5.5, 9.9] + out_u = similar(tt_u) + A_b(out_u, tt_u) + for k in eachindex(tt_u) + @test out_u[k] ≈ A_b(tt_u[k]) + end + + # cache_parameters = true should also work (different get_parameters branch) + A_cache = CubicSpline( + u_b, t_b; extrapolation = ExtrapolationType.Extension, cache_parameters = true + ) + tt = collect(-2.0:0.4:12.0) + out = similar(tt) + A_cache(out, tt) + for k in eachindex(tt) + @test out[k] ≈ A_cache(tt[k]) + end + + # DimensionMismatch + @test_throws DimensionMismatch CubicSpline(u_b, t_b)(zeros(3), [1.0, 2.0]) + end end @testset "BSplines" begin