Skip to content

Commit 9f1f85c

Browse files
authored
Merge pull request #440 from SciML/os/support-ForwardDiff@1.0
support ForwardDiff@1.0
2 parents 2834eac + 0c81e8c commit 9f1f85c

3 files changed

Lines changed: 12 additions & 14 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ ChainRulesCore = "1.24"
3131
EnumX = "1.0.4"
3232
FindFirstFunctions = "1.3"
3333
FiniteDifferences = "0.12.31"
34-
ForwardDiff = "0.10.36"
34+
ForwardDiff = "0.10.36, 1"
3535
LinearAlgebra = "1.10"
3636
Optim = "1.6"
3737
PrettyTables = "2"

src/derivatives.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ function derivative(A, t, order = 1)
66
_extrapolate_derivative_right(A, t, order)
77
else
88
iguess = A.iguesser
9-
(order == 1) ? _derivative(A, t, iguess) :
10-
ForwardDiff.derivative(t -> begin
11-
_derivative(A, t, iguess)
12-
end, t)
9+
if order == 1
10+
return _derivative(A, t, iguess)
11+
end
12+
return ForwardDiff.derivative(t -> begin
13+
-_derivative(A, -t, iguess)
14+
end, -t) # take derivative backwards in t to make it a left rather than right derivative
1315
end
1416
end
1517

@@ -313,9 +315,8 @@ function _derivative(
313315
ducum = (A.c[ax_u..., 2] - A.c[ax_u..., 1]) / (A.k[A.d + 2])
314316
else
315317
for i in 1:(A.h - 1)
316-
ducum = ducum +
317-
sc[i + 1] * (A.c[ax_u..., i + 1] - A.c[ax_u..., i]) /
318-
(A.k[i + A.d + 1] - A.k[i + 1])
318+
ducum += sc[i + 1] * (A.c[ax_u..., i + 1] - A.c[ax_u..., i]) /
319+
(A.k[i + A.d + 1] - A.k[i + 1])
319320
end
320321
end
321322
ducum * A.d * scale

test/derivative_tests.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Symbolics
66
using StableRNGs
77
using RegularizationTools
88
using Optim
9-
using ForwardDiff
9+
import ForwardDiff
1010
using LinearAlgebra
1111

1212
function test_derivatives(method; args = [], kwargs = [], name::String)
@@ -35,11 +35,8 @@ function test_derivatives(method; args = [], kwargs = [], name::String)
3535

3636
# Interpolation transition points
3737
for _t in t[2:(end - 1)]
38-
if func isa Union{BSplineInterpolation, BSplineApprox,
39-
CubicHermiteSpline}
40-
fdiff = forward_fdm(5, 1; geom = true)(func, _t)
41-
fdiff2 = forward_fdm(5, 1; geom = true)(t -> derivative(func, t), _t)
42-
elseif func isa SmoothedConstantInterpolation
38+
if func isa Union{SmoothedConstantInterpolation, BSplineInterpolation, BSplineApprox}
39+
# TODO fix interpolations
4340
continue
4441
else
4542
fdiff = backward_fdm(5, 1; geom = true)(func, _t)

0 commit comments

Comments
 (0)