Skip to content

Commit 3e13835

Browse files
Support SciMLBase v3 and DiffEqBase v7
- Widen compat for SciMLBase to "2, 3" and DiffEqBase to "6.62, 7". - Bump version to 1.2.0. - Accept `DEVerbosity` objects for the `verbose` kwarg by treating any non-`false` value as "show warnings". DiffEqBase v7 now passes a structured verbosity object instead of a plain `Bool`. - Unwrap `SciMLBase.unwrapped_f(prob.f.f)` (and `prob.f.f1.f` for second-order problems) before invoking the user function. SciMLBase v3 changed the default specialization of `ODEFunction{iip}` to `AutoSpecialize`, which wraps the function in a FunctionWrappersWrapper that only accepts the exact argument types captured at problem construction (e.g. `Matrix{Float64}`). Passing a `Base.ReshapedArray` through the wrapper fails; using the unwrapped function avoids the type-specialization trap. Supersedes #60 and #61. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 89e45c1 commit 3e13835

2 files changed

Lines changed: 22 additions & 9 deletions

File tree

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GeometricIntegratorsDiffEq"
22
uuid = "5a33fad7-5ce4-5983-9f5d-5f26ceab5c96"
3-
version = "1.1.0"
3+
version = "1.2.0"
44
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
55

66
[deps]
@@ -10,11 +10,11 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1010
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1111

1212
[compat]
13-
DiffEqBase = "6.62"
13+
DiffEqBase = "6.62, 7"
1414
ExplicitImports = "1.14.0"
1515
GeometricIntegrators = "0.15, 0.16"
1616
Reexport = "0.2, 1"
17-
SciMLBase = "2"
17+
SciMLBase = "2, 3"
1818
julia = "1.10"
1919

2020
[extras]

src/solve.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@ function DiffEqBase.__solve(
1515
error("dt required for fixed timestep methods.")
1616
end
1717

18+
# DiffEqBase v7 passes `verbose` as a `DEVerbosity` object instead of a Bool.
19+
# Treat anything that is not literally `false` as "show warnings".
20+
verbose_bool = verbose !== false
21+
1822
isstiff = !(
1923
alg isa Union{
2024
GIImplicitEuler, GIImplicitMidpoint,
2125
GISRK3, GIGLRK, GIRadauIA, GIRadauIIA,
2226
}
2327
)
2428

25-
if verbose
29+
if verbose_bool
2630
warned = !isempty(kwargs) && check_keywords(alg, kwargs, warnlist)
2731
if !(prob.f isa DiffEqBase.AbstractParameterizedFunction) && isstiff
2832
if DiffEqBase.has_tgrad(prob.f)
@@ -68,15 +72,20 @@ function DiffEqBase.__solve(
6872
# Create function wrapper for GeometricIntegrators API
6973
# GeometricIntegrators expects: v(v, t, q, params)
7074
# DiffEqBase provides: f(du, u, p, t) for inplace or f(u, p, t) for out-of-place
75+
# SciMLBase v3's default AutoSpecialize wraps f in a FunctionWrappersWrapper that
76+
# only accepts the exact argument types captured at problem construction (e.g.
77+
# `Matrix{Float64}`), so passing a `Base.ReshapedArray` to it errors. Unwrap the
78+
# underlying user-defined function before invoking it.
79+
raw_f = SciMLBase.unwrapped_f(prob.f.f)
7180
if !isinplace && u isa AbstractArray
72-
v! = (v, t, q, params) -> (v .= vec(prob.f(reshape(q, sizeu), p, t)); nothing)
81+
v! = (v, t, q, params) -> (v .= vec(raw_f(reshape(q, sizeu), p, t)); nothing)
7382
elseif !(u isa Vector{Float64})
7483
v! = (
7584
v, t, q,
7685
params,
77-
) -> (prob.f(reshape(v, sizeu), reshape(q, sizeu), p, t); nothing)
86+
) -> (raw_f(reshape(v, sizeu), reshape(q, sizeu), p, t); nothing)
7887
else
79-
v! = (v, t, q, params) -> prob.f(v, q, p, t)
88+
v! = (v, t, q, params) -> raw_f(v, q, p, t)
8089
end
8190

8291
ode = GeometricIntegrators.ODEProblem(v!, prob.tspan, dt, vec(prob.u0))
@@ -111,17 +120,21 @@ function DiffEqBase.__solve(
111120

112121
v! = (v, t, q, p_state, params) -> (v .= p_state) # dq/dt = p
113122

123+
# Unwrap past SciMLBase v3 AutoSpecialize FunctionWrappers for the acceleration
124+
# function so it accepts argument types other than those captured at construction.
125+
raw_f1 = SciMLBase.unwrapped_f(prob.f.f1.f)
126+
114127
# Handle both inplace and out-of-place problems
115128
if isinplace
116129
f! = (
117130
f_out, t, q, p_state,
118131
params,
119-
) -> (prob.f.f1.f(f_out, p_state, q, p, t); nothing) # dp/dt = f1(p, q)
132+
) -> (raw_f1(f_out, p_state, q, p, t); nothing) # dp/dt = f1(p, q)
120133
else
121134
f! = (
122135
f_out, t, q, p_state,
123136
params,
124-
) -> (f_out .= prob.f.f1.f(p_state, q, p, t); nothing)
137+
) -> (f_out .= raw_f1(p_state, q, p, t); nothing)
125138
end
126139

127140
pode = GeometricIntegrators.PODEProblem(

0 commit comments

Comments
 (0)