Skip to content

Commit 4d3f0d6

Browse files
Widen UJacobianWrapper.p for nested ForwardDiff (#3381)
When a Rosenbrock integrator runs inside an outer ForwardDiff layer, the stored `p` in `UJacobianWrapper` is a `Vector{<:Dual}` at the outer Dual level, but the inner Jacobian computation widens `u` into a deeper nested-Dual type via its prepared `JacobianConfig`. The user `f(du, u, p, t)` body then multiplies `p[i] * u[i]` across two different Dual nesting levels, which dispatches through ForwardDiff's `tagcount`-based tag precedence. That precedence is a `@generated` function whose literal value is baked at first compile, so its ordering depends on which package's precompile cache first instantiated a given tag type — in nested scenarios this can invert nesting and produce a triple-nested `Dual` that crashes `setindex!(du, ...)` with `Float64(::nested_dual)`. Fix: in `jacobian!`, inspect the prepared `ForwardDiff.JacobianConfig` for the inner xdual buffer's element type and lift `uf.p` into that nested-Dual type ahead of time via a fresh `UJacobianWrapper`. The widened `p` carries zero inner partials (correct — `p` does not depend on `u`), and the fast path is a single type-stable dispatch returning `f` unchanged so there is no overhead for non-nested calls. Add a regression test exercising the precise call graph (outer-Dual `p`, inner Rosenbrock Jacobian) that crashes pre-fix with a `FirstAutodiffJacError(MethodError(Float64, (nested_dual,)))` and succeeds post-fix. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dff8c62 commit 4d3f0d6

4 files changed

Lines changed: 106 additions & 3 deletions

File tree

lib/OrdinaryDiffEqDifferentiation/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "OrdinaryDiffEqDifferentiation"
2-
version = "2.8.0"
2+
version = "2.8.1"
33
uuid = "4302a76b-040a-498a-8c04-15b101fed76b"
44
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>", "Yingbo Ma <mayingbo5@gmail.com>"]
55

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,45 @@ function jacobian(f::F, x, integrator) where {F}
185185
return jac
186186
end
187187

188+
# Inner-Dual eltype that the prepared ForwardDiff JacobianConfig will allocate
189+
# for its xdual buffer. Returns `nothing` if `prep` is not a ForwardDiff
190+
# JacobianConfig-backed prep (e.g. AutoFiniteDiff path).
191+
function _jac_prep_inner_dual_eltype(prep)
192+
hasfield(typeof(prep), :config) || return nothing
193+
cfg = getfield(prep, :config)
194+
cfg isa ForwardDiff.JacobianConfig || return nothing
195+
duals = cfg.duals
196+
duals isa Tuple && length(duals) >= 2 || return nothing
197+
return eltype(duals[2])
198+
end
199+
200+
# When the integrator's stored `p` (held in `f::UJacobianWrapper`) is a
201+
# `Vector{<:Dual}` because we are *inside* an outer ForwardDiff Jacobian /
202+
# gradient, the inner Rosenbrock Jacobian widens `u` into a deeper nested-Dual
203+
# type via the prepared `JacobianConfig`. The user `ode(du, u, p, t)` body
204+
# then multiplies `p[i] * u[i]` across two different Dual nesting levels,
205+
# which falls back to ForwardDiff's tag-precedence machinery (`tagcount`-
206+
# based `≺`). That precedence is unstable across precompile boundaries (the
207+
# `@generated tagcount` literal is baked at first compile and depends on
208+
# which package precompiled which tag first), so the result type can come
209+
# out wrong and crash inside `setindex!(du, ...)` with `Float64(::nested_dual)`.
210+
#
211+
# Lift `p` into the inner nested-Dual type so the user body never multiplies
212+
# across tag levels. The widened `p` carries zero inner partials (which is
213+
# what we want — `p` does not depend on `u`).
214+
_widen_uf_p_for_jac(f, prep) = f
215+
function _widen_uf_p_for_jac(f::UJacobianWrapper, prep)
216+
inner_T = _jac_prep_inner_dual_eltype(prep)
217+
inner_T === nothing && return f
218+
p = f.p
219+
p isa AbstractArray || return f
220+
Tp = eltype(p)
221+
Tp <: ForwardDiff.Dual || return f
222+
Tp === inner_T && return f
223+
inner_T <: ForwardDiff.Dual || return f
224+
return UJacobianWrapper{isinplace(f)}(f.f, f.t, convert.(inner_T, p))
225+
end
226+
188227
function jacobian!(
189228
J::AbstractMatrix{<:Number}, f::F, x::AbstractArray{<:Number},
190229
fx::AbstractArray{<:Number}, integrator::SciMLBase.DEIntegrator,
@@ -240,14 +279,16 @@ function jacobian!(
240279
config = jac_config[1]
241280
end
242281

282+
f_eff = _widen_uf_p_for_jac(f, config)
283+
243284
if integrator.iter == 1
244285
try
245-
DI.jacobian!(f, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
286+
DI.jacobian!(f_eff, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
246287
catch e
247288
throw(FirstAutodiffJacError(e))
248289
end
249290
else
250-
DI.jacobian!(f, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
291+
DI.jacobian!(f_eff, fx, J, config, gpu_safe_autodiff(alg_autodiff(alg), x), x)
251292
end
252293

253294
return nothing
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using Test
2+
using OrdinaryDiffEqRosenbrock
3+
using SciMLBase
4+
using ADTypes
5+
using ForwardDiff
6+
7+
# Regression test for nested ForwardDiff over an ODE solve (#3381). When a
8+
# Rosenbrock solver is invoked with a `Vector{<:Dual}` `p` (i.e. we are
9+
# inside some *outer* ForwardDiff layer), the inner Rosenbrock Jacobian
10+
# widens `u` to a deeper nested-Dual type via its JacobianConfig. The user
11+
# `f(du, u, p, t)` body then multiplies `p[i] * u[i]` across two different
12+
# Dual nesting levels. ForwardDiff's cross-tag multiplication uses
13+
# `tagcount`-based precedence, which is a `@generated` function whose
14+
# literal value is baked at first compile. Depending on the order in which
15+
# tag types are first instantiated (which varies with precompile order),
16+
# the resulting ordering can invert nesting and produce a triple-nested
17+
# `Dual` that crashes the eventual `setindex!` with
18+
# `Float64(::nested_dual)` MethodError.
19+
#
20+
# The fix (`_widen_uf_p_for_jac` in derivative_wrappers.jl) lifts
21+
# `uf.p` into the inner nested-Dual type ahead of time so the user
22+
# body never has to cross tag levels.
23+
#
24+
# To reliably reproduce the broken precedence without depending on
25+
# NonlinearSolve's precompile-baked NonlinearSolveTag, we construct the
26+
# outer Dual manually with a throwaway tag that is *not* a concrete
27+
# `Tag(f, V)` call. Because the throwaway tag is only used as a type
28+
# parameter, ForwardDiff does not trigger its `tagcount` until the first
29+
# cross-tag multiplication — which happens *after* `OrdinaryDiffEqTag` has
30+
# already had its nested-V `tagcount` triggered via
31+
# `OrdinaryDiffEqDifferentiation.prepare_ADType`. This matches the
32+
# precedence inversion that the original issue hits via
33+
# `NonlinearSolveTag`'s precompile workload.
34+
35+
struct _Nested3381OuterTag end
36+
37+
function _nested3381_ode!(du, u, p, t)
38+
du[1] = -p[1] * u[1]
39+
du[2] = -u[1] - p[2] * u[2]
40+
return nothing
41+
end
42+
43+
@testset "Nested ForwardDiff through Rosenbrock Jacobian (#3381)" begin
44+
OuterTag = ForwardDiff.Tag{_Nested3381OuterTag, Float64}
45+
DualT = ForwardDiff.Dual{OuterTag, Float64, 2}
46+
pdual = [
47+
DualT(1.5, ForwardDiff.Partials((1.0, 0.0))),
48+
DualT(2.0, ForwardDiff.Partials((0.0, 1.0))),
49+
]
50+
51+
ode_f = ODEFunction{true, SciMLBase.FullSpecialize}(_nested3381_ode!)
52+
u0 = [1.0, 1.0]
53+
tspan = (0.0, 1.0)
54+
prob = ODEProblem(ode_f, u0, tspan, pdual)
55+
56+
sol = solve(prob, Rosenbrock23(autodiff = AutoForwardDiff(chunksize = 2)))
57+
@test SciMLBase.successful_retcode(sol.retcode)
58+
@test eltype(sol.u[end]) === DualT
59+
# Sanity: primal values still reach the target trajectory.
60+
@test all(isfinite, ForwardDiff.value.(sol.u[end]))
61+
end

lib/OrdinaryDiffEqDifferentiation/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ if TEST_GROUP ∉ ("QA", "Sparse", "ModelingToolkit")
2727
@time @safetestset "Differentiation Trait Tests" include("differentiation_traits_tests.jl")
2828
@time @safetestset "Autodiff Error Tests" include("autodiff_error_tests.jl")
2929
@time @safetestset "No Jac Tests" include("nojac_tests.jl")
30+
@time @safetestset "Nested ForwardDiff" include("nested_forwarddiff_tests.jl")
3031
end
3132

3233
# Run sparse tests (separate environment due to ComponentArrays dep conflicts)

0 commit comments

Comments
 (0)