Skip to content

Commit 433e30b

Browse files
oscardssmithChrisRackauckas
authored andcommitted
more fixes
1 parent 72f6ef8 commit 433e30b

10 files changed

Lines changed: 53 additions & 82 deletions

File tree

lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
1-
# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
2-
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
1+
function alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
32
error("This algorithm does not have an autodifferentiation option defined.")
43
end
5-
_alg_autodiff(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm) = alg.autodiff
6-
_alg_autodiff(alg::DAEAlgorithm) = alg.autodiff
7-
_alg_autodiff(alg::OrdinaryDiffEqImplicitAlgorithm) = alg.autodiff
8-
_alg_autodiff(alg::CompositeAlgorithm) = _alg_autodiff(alg.algs[end])
9-
_alg_autodiff(alg::Union{
4+
alg_autodiff(alg::OrdinaryDiffEqAdaptiveImplicitAlgorithm) = alg.autodiff
5+
alg_autodiff(alg::DAEAlgorithm) = alg.autodiff
6+
alg_autodiff(alg::OrdinaryDiffEqImplicitAlgorithm) = alg.autodiff
7+
alg_autodiff(alg::CompositeAlgorithm) = alg_autodiff(alg.algs[end])
8+
alg_autodiff(alg::Union{
109
OrdinaryDiffEqExponentialAlgorithm,
1110
OrdinaryDiffEqAdaptiveExponentialAlgorithm,
1211
}) = alg.autodiff
1312

14-
function alg_autodiff(alg)
15-
return _alg_autodiff(alg)
16-
end
17-
1813
Base.@pure function determine_chunksize(u, alg::SciMLBase.DEAlgorithm)
1914
determine_chunksize(u, get_chunksize(alg))
2015
end

lib/OrdinaryDiffEqLinear/src/algorithms.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,7 @@ struct CayleyEuler <: OrdinaryDiffEqAlgorithm end
191191
iop = 0,
192192
"""
193193
)
194-
struct LinearExponential <:
195-
OrdinaryDiffEqExponentialAlgorithm{1, false, Val{:forward}, Val{true}, nothing}
194+
struct LinearExponential <: OrdinaryDiffEqLinearExponentialAlgorithm
196195
krylov::Symbol
197196
m::Int
198197
iop::Int

lib/OrdinaryDiffEqLowOrderRK/src/algorithms.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ struct Euler <: OrdinaryDiffEqAlgorithm end
1414
"Split Method.",
1515
"", "", ""
1616
)
17-
struct SplitEuler <:
18-
OrdinaryDiffEqExponentialAlgorithm{0, false, Val{:forward}, Val{true}, nothing} end
17+
struct SplitEuler <: OrdinaryDiffEqExponentialAlgorithm end
1918

2019
@doc explicit_rk_docstring(
2120
"The second order Heun's method. Uses embedded Euler method for adaptivity.",

lib/StochasticDiffEq/Project.toml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "6.102.0"
44
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
89
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -25,6 +26,8 @@ StochasticDiffEqWeak = "af2a2fcd-1c36-4cbe-a6d0-5afda784a085"
2526

2627
[sources]
2728
DiffEqBase = {path = "../DiffEqBase"}
29+
OrdinaryDiffEqCore = {path = "../OrdinaryDiffEqCore"}
30+
OrdinaryDiffEqNonlinearSolve = {path = "../OrdinaryDiffEqNonlinearSolve"}
2831
StochasticDiffEqCore = {path = "../StochasticDiffEqCore"}
2932
StochasticDiffEqHighOrder = {path = "../StochasticDiffEqHighOrder"}
3033
StochasticDiffEqIIF = {path = "../StochasticDiffEqIIF"}
@@ -36,10 +39,9 @@ StochasticDiffEqMilstein = {path = "../StochasticDiffEqMilstein"}
3639
StochasticDiffEqROCK = {path = "../StochasticDiffEqROCK"}
3740
StochasticDiffEqRODE = {path = "../StochasticDiffEqRODE"}
3841
StochasticDiffEqWeak = {path = "../StochasticDiffEqWeak"}
39-
OrdinaryDiffEqCore = {path = "../OrdinaryDiffEqCore"}
40-
OrdinaryDiffEqNonlinearSolve = {path = "../OrdinaryDiffEqNonlinearSolve"}
4142

4243
[compat]
44+
ADTypes = "1.21.0"
4345
DiffEqBase = "6.187"
4446
DiffEqNoiseProcess = "5.13"
4547
LinearAlgebra = "1.6"
@@ -69,18 +71,18 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6971
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
7072
LevyArea = "2d8b4e74-eb68-11e8-0fb9-d5eb67b50637"
7173
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
74+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
75+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
76+
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
7277
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7378
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7479
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
7580
SDEProblemLibrary = "c72e72a9-a271-4b2b-8966-303ed956772e"
7681
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
77-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
78-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
79-
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
80-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
81-
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
8282
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
83+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
8384
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
85+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
8486
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8587

8688
[targets]

lib/StochasticDiffEqCore/src/alg_utils.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,15 @@ for T in [
6868
end
6969

7070

71-
_alg_autodiff(::StochasticDiffEqNewtonAlgorithm) = alg.autodiff
72-
_alg_autodiff(::StochasticDiffEqNewtonAdaptiveAlgorithm) = alg.autodiff
73-
_alg_autodiff(::StochasticDiffEqJumpNewtonAdaptiveAlgorithm) = alg.autodiff
74-
_alg_autodiff(::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm) = alg.autodiff
75-
_alg_autodiff(alg::StochasticCompositeAlgorithm) = _alg_autodiff(alg.algs[end])
76-
7771
function OrdinaryDiffEqCore.alg_autodiff(
7872
alg::Union{
79-
StochasticDiffEqAlgorithm, StochasticDiffEqRODEAlgorithm,
73+
StochasticDiffEqNewtonAlgorithm,
74+
StochasticDiffEqNewtonAdaptiveAlgorithm,
75+
StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
76+
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm,
8077
}
8178
)
82-
ad = _alg_autodiff(alg)
79+
ad = alg.autodiff
8380
if ad == Val(false)
8481
return ADTypes.AutoFiniteDiff()
8582
elseif ad == Val(true)
@@ -88,6 +85,7 @@ function OrdinaryDiffEqCore.alg_autodiff(
8885
return SciMLBase._unwrap_val(ad)
8986
end
9087
end
88+
OrdinaryDiffEqCore.alg_autodiff(alg::StochasticDiffEqCompositeAlgorithm) = OrdinaryDiffEqCore.alg_autodiff(alg.algs[end])
9189

9290
isadaptive(alg::Union{StochasticDiffEqAlgorithm, StochasticDiffEqRODEAlgorithm}) = false
9391
function isadaptive(
Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1 @@
1-
# _alg_autodiff methods for SDE Newton algorithm abstract types.
2-
# These extract the autodiff field from the algorithm struct so that
3-
# OrdinaryDiffEqDifferentiation can set up the Jacobian computation correctly.
4-
5-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqNewtonAlgorithm) = alg.autodiff
6-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqNewtonAdaptiveAlgorithm) = alg.autodiff
7-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqJumpNewtonAdaptiveAlgorithm) = alg.autodiff
8-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm) = alg.autodiff
1+
# alg_autodiff for SDE Newton algorithm types is defined in StochasticDiffEqCore.

lib/StochasticDiffEqImplicit/src/alg_utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ function SciMLBase.alg_interpretation(
2222
alg::ImplicitRKMil{
2323
AD,
2424
F,
25-
P,
2625
N,
2726
T2,
27+
T3,
2828
interpretation,
29+
CJ,
2930
}
30-
) where {AD, F, P, N, T2, interpretation}
31+
) where {AD, F, N, T2, T3, interpretation, CJ}
3132
return interpretation
3233
end
3334

lib/StochasticDiffEqImplicit/src/algorithms.jl

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
using ADTypes: AutoForwardDiff
22
using OrdinaryDiffEqCore: _fixup_ad, _unwrap_val
33

4-
struct ImplicitEM{AD, F, F2, P, T2, CJ} <:
4+
struct ImplicitEM{AD, F, F2, T2, T3, CJ} <:
55
StochasticDiffEqNewtonAdaptiveAlgorithm
66
linsolve::F
77
nlsolve::F2
8-
precs::P
98
theta::T2
109
extrapolant::Symbol
11-
new_jac_conv_bound::T2
10+
new_jac_conv_bound::T3
1211
symplectic::Bool
1312
autodiff::AD
1413
concrete_jac::CJ
@@ -17,7 +16,6 @@ end
1716
function ImplicitEM(;
1817
autodiff = AutoForwardDiff(),
1918
concrete_jac = nothing,
20-
precs = OrdinaryDiffEqCore.DEFAULT_PRECS,
2119
linsolve = nothing, nlsolve = NLNewton(),
2220
extrapolant = :constant,
2321
theta = 1, symplectic = false,
@@ -26,7 +24,7 @@ function ImplicitEM(;
2624
)
2725
autodiff = _fixup_ad(autodiff)
2826
return ImplicitEM(
29-
linsolve, nlsolve, precs,
27+
linsolve, nlsolve,
3028
symplectic ? 1 / 2 : theta,
3129
extrapolant, new_jac_conv_bound, symplectic,
3230
autodiff, _unwrap_val(concrete_jac), controller
@@ -36,14 +34,13 @@ end
3634
STrapezoid(; kwargs...) = ImplicitEM(; theta = 1 / 2, kwargs...)
3735
SImplicitMidpoint(; kwargs...) = ImplicitEM(; theta = 1 / 2, symplectic = true, kwargs...)
3836

39-
struct ImplicitEulerHeun{AD, F, P, N, T2, CJ} <:
37+
struct ImplicitEulerHeun{AD, F, N, T2, T3, CJ} <:
4038
StochasticDiffEqNewtonAdaptiveAlgorithm
4139
linsolve::F
4240
nlsolve::N
43-
precs::P
4441
theta::T2
4542
extrapolant::Symbol
46-
new_jac_conv_bound::T2
43+
new_jac_conv_bound::T3
4744
symplectic::Bool
4845
autodiff::AD
4946
concrete_jac::CJ
@@ -52,7 +49,6 @@ end
5249
function ImplicitEulerHeun(;
5350
autodiff = AutoForwardDiff(),
5451
concrete_jac = nothing,
55-
precs = OrdinaryDiffEqCore.DEFAULT_PRECS,
5652
linsolve = nothing, nlsolve = NLNewton(),
5753
extrapolant = :constant,
5854
theta = 1, symplectic = false,
@@ -61,22 +57,21 @@ function ImplicitEulerHeun(;
6157
)
6258
autodiff = _fixup_ad(autodiff)
6359
return ImplicitEulerHeun(
64-
linsolve, nlsolve, precs,
60+
linsolve, nlsolve,
6561
symplectic ? 1 / 2 : theta,
6662
extrapolant,
6763
new_jac_conv_bound, symplectic,
6864
autodiff, _unwrap_val(concrete_jac), controller
6965
)
7066
end
7167

72-
struct ImplicitRKMil{AD, F, P, N, T2, interpretation, CJ} <:
68+
struct ImplicitRKMil{AD, F, N, T2, T3, interpretation, CJ} <:
7369
StochasticDiffEqNewtonAdaptiveAlgorithm
7470
linsolve::F
7571
nlsolve::N
76-
precs::P
7772
theta::T2
7873
extrapolant::Symbol
79-
new_jac_conv_bound::T2
74+
new_jac_conv_bound::T3
8075
symplectic::Bool
8176
autodiff::AD
8277
concrete_jac::CJ
@@ -85,7 +80,6 @@ end
8580
function ImplicitRKMil(;
8681
autodiff = AutoForwardDiff(),
8782
concrete_jac = nothing,
88-
precs = OrdinaryDiffEqCore.DEFAULT_PRECS,
8983
linsolve = nothing, nlsolve = NLNewton(),
9084
extrapolant = :constant,
9185
theta = 1, symplectic = false,
@@ -94,25 +88,24 @@ function ImplicitRKMil(;
9488
)
9589
autodiff = _fixup_ad(autodiff)
9690
return ImplicitRKMil{
97-
typeof(autodiff), typeof(linsolve), typeof(precs), typeof(nlsolve),
98-
typeof(symplectic ? 1 / 2 : theta), typeof(interpretation),
91+
typeof(autodiff), typeof(linsolve), typeof(nlsolve),
92+
typeof(symplectic ? 1 / 2 : theta), typeof(new_jac_conv_bound),
93+
typeof(interpretation), typeof(_unwrap_val(concrete_jac)),
9994
}(
100-
linsolve, nlsolve, precs,
101-
symplectic ? 1 / 2 : theta,
95+
linsolve, nlsolve, symplectic ? 1 / 2 : theta,
10296
extrapolant,
10397
new_jac_conv_bound, symplectic,
10498
autodiff, _unwrap_val(concrete_jac), controller
10599
)
106100
end
107101

108-
struct ISSEM{AD, F, P, N, T2, CJ} <:
102+
struct ISSEM{AD, F, N, T2, T3, CJ} <:
109103
StochasticDiffEqNewtonAdaptiveAlgorithm
110104
linsolve::F
111105
nlsolve::N
112-
precs::P
113106
theta::T2
114107
extrapolant::Symbol
115-
new_jac_conv_bound::T2
108+
new_jac_conv_bound::T3
116109
symplectic::Bool
117110
autodiff::AD
118111
concrete_jac::CJ
@@ -121,7 +114,6 @@ end
121114
function ISSEM(;
122115
autodiff = AutoForwardDiff(),
123116
concrete_jac = nothing,
124-
precs = OrdinaryDiffEqCore.DEFAULT_PRECS,
125117
linsolve = nothing, nlsolve = NLNewton(),
126118
extrapolant = :constant,
127119
theta = 1, symplectic = false,
@@ -130,22 +122,21 @@ function ISSEM(;
130122
)
131123
autodiff = _fixup_ad(autodiff)
132124
return ISSEM(
133-
linsolve, nlsolve, precs,
125+
linsolve, nlsolve,
134126
symplectic ? 1 / 2 : theta,
135127
extrapolant,
136128
new_jac_conv_bound, symplectic,
137129
autodiff, _unwrap_val(concrete_jac), controller
138130
)
139131
end
140132

141-
struct ISSEulerHeun{AD, F, P, N, T2, CJ} <:
133+
struct ISSEulerHeun{AD, F, N, T2, T3, CJ} <:
142134
StochasticDiffEqNewtonAdaptiveAlgorithm
143135
linsolve::F
144136
nlsolve::N
145-
precs::P
146137
theta::T2
147138
extrapolant::Symbol
148-
new_jac_conv_bound::T2
139+
new_jac_conv_bound::T3
149140
symplectic::Bool
150141
autodiff::AD
151142
concrete_jac::CJ
@@ -154,7 +145,6 @@ end
154145
function ISSEulerHeun(;
155146
autodiff = AutoForwardDiff(),
156147
concrete_jac = nothing,
157-
precs = OrdinaryDiffEqCore.DEFAULT_PRECS,
158148
linsolve = nothing, nlsolve = NLNewton(),
159149
extrapolant = :constant,
160150
theta = 1, symplectic = false,
@@ -163,19 +153,18 @@ function ISSEulerHeun(;
163153
)
164154
autodiff = _fixup_ad(autodiff)
165155
return ISSEulerHeun(
166-
linsolve, nlsolve, precs,
156+
linsolve, nlsolve,
167157
symplectic ? 1 / 2 : theta,
168158
extrapolant,
169159
new_jac_conv_bound, symplectic,
170160
autodiff, _unwrap_val(concrete_jac), controller
171161
)
172162
end
173163

174-
struct SKenCarp{AD, F, P, N, T2, CJ} <:
164+
struct SKenCarp{AD, F, N, T2, CJ} <:
175165
StochasticDiffEqNewtonAdaptiveAlgorithm
176166
linsolve::F
177167
nlsolve::N
178-
precs::P
179168
smooth_est::Bool
180169
extrapolant::Symbol
181170
new_jac_conv_bound::T2
@@ -188,15 +177,14 @@ end
188177
function SKenCarp(;
189178
autodiff = AutoForwardDiff(),
190179
concrete_jac = nothing,
191-
precs = OrdinaryDiffEqCore.DEFAULT_PRECS,
192180
linsolve = nothing, nlsolve = NLNewton(),
193181
smooth_est = true, extrapolant = :min_correct,
194182
new_jac_conv_bound = 1.0e-3, controller = :Predictive,
195183
ode_error_est = true
196184
)
197185
autodiff = _fixup_ad(autodiff)
198186
return SKenCarp(
199-
linsolve, nlsolve, precs, smooth_est, extrapolant, new_jac_conv_bound,
187+
linsolve, nlsolve, smooth_est, extrapolant, new_jac_conv_bound,
200188
ode_error_est,
201189
autodiff, _unwrap_val(concrete_jac), controller
202190
)
Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1 @@
1-
# _alg_autodiff methods for SDE Newton algorithm abstract types.
2-
# These extract the autodiff field from the algorithm struct so that
3-
# OrdinaryDiffEqDifferentiation can set up the Jacobian computation correctly.
4-
5-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqNewtonAlgorithm) = alg.autodiff
6-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqNewtonAdaptiveAlgorithm) = alg.autodiff
7-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqJumpNewtonAdaptiveAlgorithm) = alg.autodiff
8-
OrdinaryDiffEqDifferentiation._alg_autodiff(alg::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm) = alg.autodiff
1+
# alg_autodiff for SDE Newton algorithm types is defined in StochasticDiffEqCore.

lib/StochasticDiffEqWeak/src/algorithms.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,13 +704,16 @@ Alternative stochastic generalization of the modified Euler method.
704704
"""
705705
struct SMEB <: StochasticDiffEqAlgorithm end
706706

707-
struct IRI1{AD, F, F2, T2, CJ} <:
707+
708+
using ADTypes: AutoForwardDiff
709+
710+
struct IRI1{AD, F, F2, T2, T3, CJ} <:
708711
StochasticDiffEqNewtonAdaptiveAlgorithm
709712
linsolve::F
710713
nlsolve::F2
711714
theta::T2
712715
extrapolant::Symbol
713-
new_jac_conv_bound::T2
716+
new_jac_conv_bound::T3
714717
autodiff::AD
715718
concrete_jac::CJ
716719
controller::Symbol

0 commit comments

Comments
 (0)