Skip to content

Commit 5796679

Browse files
Unwrap AutoSpecializeCallable for Enzyme at __init entry points
Instead of fixing individual call sites (trust_region.jl VecJac/JacVec, jacobian.jl construct_jacobian_cache), create _ad_prob with unwrapped function early in __init for both FirstOrder and QuasiNewton. This ensures ALL downstream AD consumers (Jacobian cache, trust region, linesearch, forcing) receive the unwrapped problem when Enzyme is used. - Add maybe_unwrap_prob_for_enzyme helper in NonlinearSolveBase - FirstOrder: create _ad_prob from alg.autodiff/jvp_autodiff/vjp_autodiff - QuasiNewton: detect Enzyme from kwargs and alg.linesearch/trustregion - Revert trust_region.jl inline fix (now handled upstream in solve.jl) Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 7554ef9 commit 5796679

4 files changed

Lines changed: 50 additions & 17 deletions

File tree

lib/NonlinearSolveBase/src/autospecialize.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,24 @@ _uses_enzyme_ad(::ADTypes.AutoEnzyme) = true
9292
_uses_enzyme_ad(ad::AutoSparse) = _uses_enzyme_ad(ADTypes.dense_ad(ad))
9393
_uses_enzyme_ad(_) = false
9494

95+
"""
96+
maybe_unwrap_prob_for_enzyme(prob, autodiffs...)
97+
98+
If the problem function is wrapped by AutoSpecialize and any of the given AD backends
99+
use Enzyme, return a copy of `prob` with the unwrapped raw function. Otherwise return
100+
`prob` unchanged.
101+
102+
This should be called early in `__init` so that all downstream AD-related constructions
103+
(Jacobian cache, trust region operators, linesearch, forcing) receive the unwrapped problem.
104+
"""
105+
function maybe_unwrap_prob_for_enzyme(prob, autodiffs...)
106+
is_fw_wrapped(prob.f.f) || return prob
107+
for ad in autodiffs
108+
_uses_enzyme_ad(ad) && return @set prob.f.f = get_raw_f(prob.f.f)
109+
end
110+
return prob
111+
end
112+
95113
# Default dispatch assumes no ForwardDiff loaded.
96114
# The ForwardDiff extension overrides these with dual-aware versions.
97115

lib/NonlinearSolveFirstOrder/src/solve.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ function SciMLBase.__init(
176176
verbose = NonlinearVerbosity(verbose)
177177
end
178178

179+
# Enzyme cannot differentiate through FunctionWrappers' llvmcall.
180+
# Create unwrapped prob for all AD-related constructions when using Enzyme.
181+
_ad_prob = NonlinearSolveBase.maybe_unwrap_prob_for_enzyme(
182+
prob, alg.autodiff, alg.jvp_autodiff, alg.vjp_autodiff
183+
)
184+
179185
timer = get_timer_output()
180186
@static_timeit timer "cache construction" begin
181187
u = Utils.maybe_unaliased(prob.u0, alias_u0)
@@ -191,7 +197,7 @@ function SciMLBase.__init(
191197
linsolve_kwargs = merge((; verbose = verbose.linear_verbosity, abstol, reltol), linsolve_kwargs)
192198

193199
jac_cache = NonlinearSolveBase.construct_jacobian_cache(
194-
prob, alg, prob.f, fu, u, prob.p;
200+
_ad_prob, alg, _ad_prob.f, fu, u, _ad_prob.p;
195201
stats, alg.autodiff, linsolve, alg.jvp_autodiff, alg.vjp_autodiff
196202
)
197203
J = reused_jacobian(jac_cache, u)
@@ -219,7 +225,7 @@ function SciMLBase.__init(
219225
NonlinearSolveBase.supports_trust_region(alg.descent) ||
220226
error("Trust Region not supported by $(alg.descent).")
221227
trustregion_cache = InternalAPI.init(
222-
prob, alg.trustregion, prob.f, fu, u, prob.p;
228+
_ad_prob, alg.trustregion, _ad_prob.f, fu, u, _ad_prob.p;
223229
alg.vjp_autodiff, alg.jvp_autodiff, stats, internalnorm, kwargs...
224230
)
225231
globalization = Val(:TrustRegion)
@@ -229,7 +235,7 @@ function SciMLBase.__init(
229235
NonlinearSolveBase.supports_line_search(alg.descent) ||
230236
error("Line Search not supported by $(alg.descent).")
231237
linesearch_cache = CommonSolve.init(
232-
prob, alg.linesearch, fu, u; stats, internalnorm,
238+
_ad_prob, alg.linesearch, fu, u; stats, internalnorm,
233239
autodiff = ifelse(
234240
provided_jvp_autodiff, alg.jvp_autodiff, alg.vjp_autodiff
235241
),
@@ -240,7 +246,7 @@ function SciMLBase.__init(
240246

241247
if has_forcing
242248
forcing_cache = InternalAPI.init(
243-
prob, alg.forcing, fu, u, u, prob.p; stats, internalnorm,
249+
_ad_prob, alg.forcing, fu, u, u, _ad_prob.p; stats, internalnorm,
244250
autodiff = ifelse(
245251
provided_jvp_autodiff, alg.jvp_autodiff, alg.vjp_autodiff
246252
),

lib/NonlinearSolveFirstOrder/src/trust_region.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,20 +223,11 @@ function InternalAPI.init(
223223
p1, p2, p3, p4 = get_parameters(T, alg.method)
224224
ϵ = T(1.0e-8)
225225

226-
# Enzyme cannot differentiate through FunctionWrappers' llvmcall.
227-
# Unwrap AutoSpecializeCallable so DI sees the raw user function.
228-
_ad_prob = prob
229-
if NonlinearSolveBase.is_fw_wrapped(prob.f.f) &&
230-
(NonlinearSolveBase._uses_enzyme_ad(vjp_autodiff) ||
231-
NonlinearSolveBase._uses_enzyme_ad(jvp_autodiff))
232-
@set! _ad_prob.f.f = NonlinearSolveBase.get_raw_f(prob.f.f)
233-
end
234-
235226
vjp_operator = alg.method isa RUS.__Yuan || alg.method isa RUS.__Bastin ?
236-
VecJacOperator(_ad_prob, fu, u; autodiff = vjp_autodiff) : nothing
227+
VecJacOperator(prob, fu, u; autodiff = vjp_autodiff) : nothing
237228

238229
jvp_operator = alg.method isa RUS.__Bastin ?
239-
JacVecOperator(_ad_prob, fu, u; autodiff = jvp_autodiff) : nothing
230+
JacVecOperator(prob, fu, u; autodiff = jvp_autodiff) : nothing
240231

241232
if alg.method isa RUS.__Yuan
242233
Jᵀfu_cache = StatefulJacobianOperator(vjp_operator, u, prob.p) * Utils.safe_vec(fu)

lib/NonlinearSolveQuasiNewton/src/solve.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ function SciMLBase.__init(
160160
alias = SciMLBase.NonlinearAliasSpecifier(alias_u0 = kwargs[:alias_u0])
161161
end
162162
alias_u0 = alias.alias_u0
163+
# Enzyme cannot differentiate through FunctionWrappers' llvmcall.
164+
# QuasiNewton doesn't have alg.autodiff fields; autodiff may come through kwargs
165+
# or from the linesearch/trustregion algorithm's own autodiff field.
166+
_ad_autodiffs = Any[
167+
get(kwargs, :autodiff, nothing),
168+
get(kwargs, :jvp_autodiff, nothing),
169+
get(kwargs, :vjp_autodiff, nothing),
170+
]
171+
if alg.linesearch !== missing && alg.linesearch !== nothing &&
172+
hasfield(typeof(alg.linesearch), :autodiff)
173+
push!(_ad_autodiffs, alg.linesearch.autodiff)
174+
end
175+
if alg.trustregion !== missing && alg.trustregion !== nothing &&
176+
hasfield(typeof(alg.trustregion), :autodiff)
177+
push!(_ad_autodiffs, alg.trustregion.autodiff)
178+
end
179+
_ad_prob = NonlinearSolveBase.maybe_unwrap_prob_for_enzyme(prob, _ad_autodiffs...)
180+
163181
timer = get_timer_output()
164182
@static_timeit timer "cache construction" begin
165183

@@ -222,7 +240,7 @@ function SciMLBase.__init(
222240
NonlinearSolveBase.supports_trust_region(alg.descent) ||
223241
error("Trust Region not supported by $(alg.descent).")
224242
trustregion_cache = InternalAPI.init(
225-
prob, alg.trustregion, fu, u, p; stats, internalnorm, kwargs...
243+
_ad_prob, alg.trustregion, fu, u, _ad_prob.p; stats, internalnorm, kwargs...
226244
)
227245
globalization = Val(:TrustRegion)
228246
end
@@ -231,7 +249,7 @@ function SciMLBase.__init(
231249
NonlinearSolveBase.supports_line_search(alg.descent) ||
232250
error("Line Search not supported by $(alg.descent).")
233251
linesearch_cache = CommonSolve.init(
234-
prob, alg.linesearch, fu, u; stats, internalnorm, kwargs...
252+
_ad_prob, alg.linesearch, fu, u; stats, internalnorm, kwargs...
235253
)
236254
globalization = Val(:LineSearch)
237255
end

0 commit comments

Comments
 (0)