Skip to content

Commit 9cb488b

Browse files
committed
Add modified predict function from SmoothingSpline for ReverseDiff.TrackedReal used by adjoint
1 parent c7290c9 commit 9cb488b

1 file changed

Lines changed: 35 additions & 0 deletions

File tree

src/Tools.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,38 @@ function getspline(xs,vals;s=1e-10)
4848
F(x::T) where {T} = _predict(smspl,x)
4949
return F
5050
end
51+
52+
# obtained from SmoothingSpline.jl and modified the input typing for ReverseDiff TrackedReal
53+
function _predict(spl::SmoothingSpline{T}, x::T1) where {T<:Union{Float32, Float64},T1}
54+
n = length(spl.Xdesign)
55+
idxl = searchsortedlast(spl.Xdesign, x)
56+
idxr = idxl + 1
57+
if idxl == 0 # linear extrapolation to the left
58+
gl = spl.g[1]
59+
gr = spl.g[2]
60+
γ = spl.γ[1]
61+
xl = spl.Xdesign[1]
62+
xr = spl.Xdesign[2]
63+
gprime = (gr-gl)/(xr-xl) - 1/6*(xr-xl)*γ
64+
val = gl - (xl-x)*gprime
65+
elseif idxl == n # linear extrapolation to the right
66+
gl = spl.g[n-1]
67+
gr = spl.g[n]
68+
γ = spl.γ[n-2]
69+
xl = spl.Xdesign[n-1]
70+
xr = spl.Xdesign[n]
71+
gprime = (gr-gl)/(xr-xl) +1/6*(xr-xl)*γ
72+
val = gr + (x - xr)*gprime
73+
else # cubic interpolation
74+
xl = spl.Xdesign[idxl]
75+
xr = spl.Xdesign[idxr]
76+
γl = idxl == 1 ? zero(T) : spl.γ[idxl-1]
77+
γr = idxl == n-1 ? zero(T) : spl.γ[idxr-1]
78+
gl = spl.g[idxl]
79+
gr = spl.g[idxr]
80+
h = xr-xl
81+
val = ((x-xl)*gr + (xr-x)*gl)/h
82+
val -= 1/6*(x-xl)*(xr-x)*((1 + (x-xl)/h)*γr + (1+ (xr-x)/h)*γl)
83+
end
84+
val
85+
end

0 commit comments

Comments
 (0)