From eb4cccc7ff8d6402dfd587a06c2c2450022636dd Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 20 Aug 2024 21:28:12 +0200 Subject: [PATCH 1/3] Add draft of DataInterpolations extension --- ...SparseConnectivityDataInterpolationsExt.jl | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 ext/SparseConnectivityDataInterpolationsExt.jl diff --git a/ext/SparseConnectivityDataInterpolationsExt.jl b/ext/SparseConnectivityDataInterpolationsExt.jl new file mode 100644 index 00000000..40a644da --- /dev/null +++ b/ext/SparseConnectivityDataInterpolationsExt.jl @@ -0,0 +1,24 @@ +module SparseConnectivityTracerDataInterpolationsExt + +if isdefined(Base, :get_extension) + import SparseConnectivityTracer as SCT + using DataInterpolations +else + import ..SparseConnectivitytracer as SCT + using ..DataInterpolations +end + +# In general the first and second derivatives are non-zero +SCT.is_der1_zero_global(::DataInterpolations.AbstractInterpolation) = false +SCT.is_der2_zero_global(::DataInterpolations.AbstractInterpolation) = false + +# Special cases +SCT.is_der1_zero_global(::ConstantInterpolation) = true +SCT.is_der2_zero_global(::ConstantInterpolation) = true +SCT.is_der2_zero_global(::LinearInterpolation) = true + +# To do: derivative, integral + +eval(SCT.overload_gradient_1_to_1(:DataInterpolations, AbstractInterpolation)) + +end # module SparseConnectivityTracerDataInterpolationsExt \ No newline at end of file From 22f9a2af63ca47b57ae050daa383821c0e67a4be Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Thu, 10 Jul 2025 16:45:07 +0200 Subject: [PATCH 2/3] add SmoothedLinearInterpolation support --- ...ConnectivityTracerDataInterpolationsExt.jl | 94 ++++++++++--------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/ext/SparseConnectivityTracerDataInterpolationsExt.jl b/ext/SparseConnectivityTracerDataInterpolationsExt.jl index 2f668d1b..04fbbbfe 100644 --- a/ext/SparseConnectivityTracerDataInterpolationsExt.jl +++ b/ext/SparseConnectivityTracerDataInterpolationsExt.jl @@ -9,6 +9,7 @@ using FillArrays: Fill # from FillArrays.jl using DataInterpolations: AbstractInterpolation, LinearInterpolation, + SmoothedLinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation, @@ -31,41 +32,41 @@ using DataInterpolations: # https://github.com/adrhill/SparseConnectivityTracer.jl/pull/234#discussion_r2031038566 function _sct_interpolate( - ::AbstractInterpolation, - uType::Type{<:AbstractVector{<:Number}}, - t::GradientTracer, - is_der_1_zero, - is_der_2_zero, - ) + ::AbstractInterpolation, + uType::Type{<:AbstractVector{<:Number}}, + t::GradientTracer, + is_der_1_zero, + is_der_2_zero, +) return gradient_tracer_1_to_1(t, is_der_1_zero) end function _sct_interpolate( - ::AbstractInterpolation, - uType::Type{<:AbstractVector{<:Number}}, - t::HessianTracer, - is_der_1_zero, - is_der_2_zero, - ) + ::AbstractInterpolation, + uType::Type{<:AbstractVector{<:Number}}, + t::HessianTracer, + is_der_1_zero, + is_der_2_zero, +) return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero) end function _sct_interpolate( - interp::AbstractInterpolation, - uType::Type{<:AbstractMatrix{<:Number}}, - t::GradientTracer, - is_der_1_zero, - is_der_2_zero, - ) + interp::AbstractInterpolation, + uType::Type{<:AbstractMatrix{<:Number}}, + t::GradientTracer, + is_der_1_zero, + is_der_2_zero, +) t = gradient_tracer_1_to_1(t, is_der_1_zero) N = only(output_size(interp)) return Fill(t, N) end function _sct_interpolate( - interp::AbstractInterpolation, - uType::Type{<:AbstractMatrix{<:Number}}, - t::HessianTracer, - is_der_1_zero, - is_der_2_zero, - ) + interp::AbstractInterpolation, + uType::Type{<:AbstractMatrix{<:Number}}, + t::HessianTracer, + is_der_1_zero, + is_der_2_zero, +) t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero) N = only(output_size(interp)) return Fill(t, N) @@ -79,40 +80,41 @@ end # all interpolations have a non-zero second derivative at some point in the input domain. for (I, is_der1_zero, is_der2_zero) in ( - (:ConstantInterpolation, true, true), - (:LinearInterpolation, false, true), - (:QuadraticInterpolation, false, false), - (:LagrangeInterpolation, false, false), - (:AkimaInterpolation, false, false), - (:QuadraticSpline, false, false), - (:CubicSpline, false, false), - (:BSplineInterpolation, false, false), - (:BSplineApprox, false, false), - (:CubicHermiteSpline, false, false), - (:QuinticHermiteSpline, false, false), - ) + (:ConstantInterpolation, true, true), + (:LinearInterpolation, false, true), + (:SmoothedLinearInterpolation, false, false), + (:QuadraticInterpolation, false, false), + (:LagrangeInterpolation, false, false), + (:AkimaInterpolation, false, false), + (:QuadraticSpline, false, false), + (:CubicSpline, false, false), + (:BSplineInterpolation, false, false), + (:BSplineApprox, false, false), + (:CubicHermiteSpline, false, false), + (:QuinticHermiteSpline, false, false), +) @eval function (interp::$(I){uType})( - t::AbstractTracer - ) where {uType <: AbstractArray{<:Number}} + t::AbstractTracer + ) where {uType<:AbstractArray{<:Number}} return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero) end end # Some Interpolations require custom overloads on `Dual` due to mutation of caches. for I in ( - :LagrangeInterpolation, - :BSplineInterpolation, - :BSplineApprox, - :CubicHermiteSpline, - :QuinticHermiteSpline, - ) - @eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector} + :LagrangeInterpolation, + :BSplineInterpolation, + :BSplineApprox, + :CubicHermiteSpline, + :QuinticHermiteSpline, +) + @eval function (interp::$(I){uType})(d::Dual) where {uType<:AbstractVector} p = interp(primal(d)) t = interp(tracer(d)) return Dual(p, t) end - @eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractMatrix} + @eval function (interp::$(I){uType})(d::Dual) where {uType<:AbstractMatrix} p = interp(primal(d)) t = interp(tracer(d)) return Dual.(p, t) From a7618fce7c82cada89dcfa64e243a0c343a34a6e Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Thu, 10 Jul 2025 16:47:46 +0200 Subject: [PATCH 3/3] fix --- ...SparseConnectivityDataInterpolationsExt.jl | 24 ----- ...ConnectivityTracerDataInterpolationsExt.jl | 94 +++++++++---------- 2 files changed, 47 insertions(+), 71 deletions(-) delete mode 100644 ext/SparseConnectivityDataInterpolationsExt.jl diff --git a/ext/SparseConnectivityDataInterpolationsExt.jl b/ext/SparseConnectivityDataInterpolationsExt.jl deleted file mode 100644 index 40a644da..00000000 --- a/ext/SparseConnectivityDataInterpolationsExt.jl +++ /dev/null @@ -1,24 +0,0 @@ -module SparseConnectivityTracerDataInterpolationsExt - -if isdefined(Base, :get_extension) - import SparseConnectivityTracer as SCT - using DataInterpolations -else - import ..SparseConnectivitytracer as SCT - using ..DataInterpolations -end - -# In general the first and second derivatives are non-zero -SCT.is_der1_zero_global(::DataInterpolations.AbstractInterpolation) = false -SCT.is_der2_zero_global(::DataInterpolations.AbstractInterpolation) = false - -# Special cases -SCT.is_der1_zero_global(::ConstantInterpolation) = true -SCT.is_der2_zero_global(::ConstantInterpolation) = true -SCT.is_der2_zero_global(::LinearInterpolation) = true - -# To do: derivative, integral - -eval(SCT.overload_gradient_1_to_1(:DataInterpolations, AbstractInterpolation)) - -end # module SparseConnectivityTracerDataInterpolationsExt \ No newline at end of file diff --git a/ext/SparseConnectivityTracerDataInterpolationsExt.jl b/ext/SparseConnectivityTracerDataInterpolationsExt.jl index 04fbbbfe..f6499f66 100644 --- a/ext/SparseConnectivityTracerDataInterpolationsExt.jl +++ b/ext/SparseConnectivityTracerDataInterpolationsExt.jl @@ -32,41 +32,41 @@ using DataInterpolations: # https://github.com/adrhill/SparseConnectivityTracer.jl/pull/234#discussion_r2031038566 function _sct_interpolate( - ::AbstractInterpolation, - uType::Type{<:AbstractVector{<:Number}}, - t::GradientTracer, - is_der_1_zero, - is_der_2_zero, -) + ::AbstractInterpolation, + uType::Type{<:AbstractVector{<:Number}}, + t::GradientTracer, + is_der_1_zero, + is_der_2_zero, + ) return gradient_tracer_1_to_1(t, is_der_1_zero) end function _sct_interpolate( - ::AbstractInterpolation, - uType::Type{<:AbstractVector{<:Number}}, - t::HessianTracer, - is_der_1_zero, - is_der_2_zero, -) + ::AbstractInterpolation, + uType::Type{<:AbstractVector{<:Number}}, + t::HessianTracer, + is_der_1_zero, + is_der_2_zero, + ) return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero) end function _sct_interpolate( - interp::AbstractInterpolation, - uType::Type{<:AbstractMatrix{<:Number}}, - t::GradientTracer, - is_der_1_zero, - is_der_2_zero, -) + interp::AbstractInterpolation, + uType::Type{<:AbstractMatrix{<:Number}}, + t::GradientTracer, + is_der_1_zero, + is_der_2_zero, + ) t = gradient_tracer_1_to_1(t, is_der_1_zero) N = only(output_size(interp)) return Fill(t, N) end function _sct_interpolate( - interp::AbstractInterpolation, - uType::Type{<:AbstractMatrix{<:Number}}, - t::HessianTracer, - is_der_1_zero, - is_der_2_zero, -) + interp::AbstractInterpolation, + uType::Type{<:AbstractMatrix{<:Number}}, + t::HessianTracer, + is_der_1_zero, + is_der_2_zero, + ) t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero) N = only(output_size(interp)) return Fill(t, N) @@ -80,41 +80,41 @@ end # all interpolations have a non-zero second derivative at some point in the input domain. for (I, is_der1_zero, is_der2_zero) in ( - (:ConstantInterpolation, true, true), - (:LinearInterpolation, false, true), - (:SmoothedLinearInterpolation, false, false), - (:QuadraticInterpolation, false, false), - (:LagrangeInterpolation, false, false), - (:AkimaInterpolation, false, false), - (:QuadraticSpline, false, false), - (:CubicSpline, false, false), - (:BSplineInterpolation, false, false), - (:BSplineApprox, false, false), - (:CubicHermiteSpline, false, false), - (:QuinticHermiteSpline, false, false), -) + (:ConstantInterpolation, true, true), + (:LinearInterpolation, false, true), + (:SmoothedLinearInterpolation, false, false), + (:QuadraticInterpolation, false, false), + (:LagrangeInterpolation, false, false), + (:AkimaInterpolation, false, false), + (:QuadraticSpline, false, false), + (:CubicSpline, false, false), + (:BSplineInterpolation, false, false), + (:BSplineApprox, false, false), + (:CubicHermiteSpline, false, false), + (:QuinticHermiteSpline, false, false), + ) @eval function (interp::$(I){uType})( - t::AbstractTracer - ) where {uType<:AbstractArray{<:Number}} + t::AbstractTracer + ) where {uType <: AbstractArray{<:Number}} return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero) end end # Some Interpolations require custom overloads on `Dual` due to mutation of caches. for I in ( - :LagrangeInterpolation, - :BSplineInterpolation, - :BSplineApprox, - :CubicHermiteSpline, - :QuinticHermiteSpline, -) - @eval function (interp::$(I){uType})(d::Dual) where {uType<:AbstractVector} + :LagrangeInterpolation, + :BSplineInterpolation, + :BSplineApprox, + :CubicHermiteSpline, + :QuinticHermiteSpline, + ) + @eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector} p = interp(primal(d)) t = interp(tracer(d)) return Dual(p, t) end - @eval function (interp::$(I){uType})(d::Dual) where {uType<:AbstractMatrix} + @eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractMatrix} p = interp(primal(d)) t = interp(tracer(d)) return Dual.(p, t)