Skip to content

Commit 6f170af

Browse files
Merge pull request #138 from PumasAI/dw/prefix_ext
Add prefix to extension + add CRC extension
2 parents a988884 + f4a96d9 commit 6f170af

4 files changed

Lines changed: 33 additions & 16 deletions

File tree

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DataInterpolations"
22
uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
3-
version = "3.11"
3+
version = "3.12"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -13,10 +13,12 @@ RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
1313
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1414

1515
[weakdeps]
16+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1617
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1718

1819
[extensions]
19-
SymbolicsExt = "Symbolics"
20+
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
21+
DataInterpolationsSymbolicsExt = "Symbolics"
2022

2123
[compat]
2224
ChainRulesCore = "0.9.44, 0.10, 1"
@@ -29,6 +31,7 @@ Symbolics = "4"
2931
julia = "1.6"
3032

3133
[extras]
34+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3235
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
3336
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3437
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
@@ -37,4 +40,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3740
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3841

3942
[targets]
40-
test = ["Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"]
43+
test = ["ChainRulesCore", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics"]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module DataInterpolationsChainRulesCoreExt
2+
3+
using DataInterpolations: _interpolate, derivative, AbstractInterpolation,
4+
LagrangeInterpolation, AkimaInterpolation, BSplineInterpolation, BSplineApprox
5+
using ChainRulesCore
6+
7+
function ChainRulesCore.rrule(
8+
::typeof(_interpolate),
9+
A::Union{LagrangeInterpolation,AkimaInterpolation,BSplineInterpolation,BSplineApprox},
10+
t::Number,
11+
)
12+
deriv = derivative(A, t)
13+
interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ)
14+
return _interpolate(A, t), interpolate_pullback
15+
end
16+
17+
function ChainRulesCore.frule(
18+
(_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, t::Number
19+
)
20+
return _interpolate(A, t), derivative(A, t) * Δt
21+
end
22+
23+
end # module
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module SymbolicsExt
1+
module DataInterpolationsSymbolicsExt
22

33
using DataInterpolations: AbstractInterpolation
44
using Symbolics: Num, unwrap, SymbolicUtils

src/DataInterpolations.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Base.setindex!(A::AbstractInterpolation,x,i) = A.u[i] = x
1313
Base.setindex!(A::AbstractInterpolation{true},x,i) =
1414
i <= length(A.u) ? (A.u[i] = x) : (A.t[i-length(A.u)] = x)
1515

16-
using ChainRulesCore, LinearAlgebra, RecursiveArrayTools, RecipesBase, Reexport
16+
using LinearAlgebra, RecursiveArrayTools, RecipesBase, Reexport
1717
@reexport using Optim
1818

1919
include("interpolation_caches.jl")
@@ -24,20 +24,11 @@ include("derivatives.jl")
2424
include("integrals.jl")
2525
include("online.jl")
2626

27-
function ChainRulesCore.rrule(::typeof(_interpolate),
28-
A::Union{LagrangeInterpolation,AkimaInterpolation,
29-
BSplineInterpolation,BSplineApprox}, t::Number)
30-
interpolate_pullback(Δ) = (NoTangent(), NoTangent(), derivative(A, t) * Δ)
31-
return _interpolate(A, t), interpolate_pullback
32-
end
33-
34-
ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation,
35-
t::Number) = _interpolate(A, t), derivative(A, t) * Δt
36-
3727
(interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t)
3828

3929
if !isdefined(Base, :get_extension)
40-
include("../ext/SymbolicsExt.jl")
30+
include("../ext/DataInterpolationsChainRulesCoreExt.jl")
31+
include("../ext/DataInterpolationsSymbolicsExt.jl")
4132
end
4233

4334
export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,

0 commit comments

Comments
 (0)