Skip to content

Commit 2cf9e57

Browse files
feat: update to Symbolics@7
1 parent 912d889 commit 2cf9e57

2 files changed

Lines changed: 33 additions & 10 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ SafeTestsets = "0.1"
5151
SparseConnectivityTracer = "1"
5252
StableRNGs = "1"
5353
StaticArrays = "1.9"
54-
Symbolics = "6.46"
54+
Symbolics = "6.46, 7"
5555
Test = "1.10"
5656
Unitful = "1.21.1"
5757
Zygote = "0.6.77, 0.7"

ext/DataInterpolationsSymbolicsExt.jl

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,40 @@ using Symbolics: Num, unwrap, SymbolicUtils
88
@register_symbolic (interp::AbstractInterpolation)(t)
99
Base.nameof(interp::AbstractInterpolation) = :Interpolation
1010

11-
function derivative(interp::AbstractInterpolation, t::Num, order = 1)
12-
Symbolics.wrap(SymbolicUtils.term(derivative, interp, unwrap(t), order))
13-
end
14-
SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real
11+
@static if pkgversion(Symbolics) >= v"7"
12+
@register_symbolic derivative(interp::AbstractInterpolation, t, order::Integer) false
13+
function SymbolicUtils.promote_symtype(::typeof(derivative), Ti::SymbolicUtils.TypeT,
14+
Tt::SymbolicUtils.TypeT,
15+
To::SymbolicUtils.TypeT)
16+
@assert Ti <: AbstractInterpolation
17+
@assert Tt <: Real
18+
@assert To <: Integer
19+
Real
20+
end
21+
function SymbolicUtils.promote_shape(::typeof(derivative),
22+
@nospecialize(shi::SymbolicUtils.ShapeT),
23+
@nospecialize(sht::SymbolicUtils.ShapeT),
24+
@nospecialize(sho::SymbolicUtils.ShapeT))
25+
@assert !SymbolicUtils.is_array_shape(shi)
26+
@assert !SymbolicUtils.is_array_shape(sht)
27+
@assert !SymbolicUtils.is_array_shape(sho)
28+
return SymbolicUtils.ShapeVecT()
29+
end
1530

16-
function Symbolics.derivative(::typeof(derivative), args::NTuple{3, Any}, ::Val{2})
17-
Symbolics.unwrap(derivative(args[1], Symbolics.wrap(args[2]), args[3] + 1))
18-
end
31+
@register_derivative derivative(interp, t, ord) 2 derivative(interp, t, ord + 1)
32+
@register_derivative (interp::AbstractInterpolation)(t) 1 derivative(interp, t, 1)
33+
else
34+
function derivative(interp::AbstractInterpolation, t::Num, order = 1)
35+
Symbolics.wrap(SymbolicUtils.term(derivative, interp, unwrap(t), order))
36+
end
37+
SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real
38+
function Symbolics.derivative(::typeof(derivative), args::NTuple{3, Any}, ::Val{2})
39+
Symbolics.unwrap(derivative(args[1], Symbolics.wrap(args[2]), args[3] + 1))
40+
end
1941

20-
function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any}, ::Val{1})
21-
Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1])))
42+
function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any}, ::Val{1})
43+
Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1])))
44+
end
2245
end
2346

2447
end # module

0 commit comments

Comments
 (0)