Skip to content

Commit 7554ef9

Browse files
Unwrap AutoSpecializeCallable for Enzyme in TrustRegion VecJac/JacVec operators
The TrustRegion scheme creates VecJacOperator and JacVecOperator directly from the problem, bypassing construct_jacobian_cache. When Enzyme is the AD backend, these operators need the unwrapped function (without FunctionWrappers) to avoid EnzymeMutabilityException. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ff45dbd commit 7554ef9

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

lib/NonlinearSolveFirstOrder/src/trust_region.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,20 @@ function InternalAPI.init(
223223
p1, p2, p3, p4 = get_parameters(T, alg.method)
224224
ϵ = T(1.0e-8)
225225

226+
# Enzyme cannot differentiate through FunctionWrappers' llvmcall.
227+
# Unwrap AutoSpecializeCallable so DI sees the raw user function.
228+
_ad_prob = prob
229+
if NonlinearSolveBase.is_fw_wrapped(prob.f.f) &&
230+
(NonlinearSolveBase._uses_enzyme_ad(vjp_autodiff) ||
231+
NonlinearSolveBase._uses_enzyme_ad(jvp_autodiff))
232+
@set! _ad_prob.f.f = NonlinearSolveBase.get_raw_f(prob.f.f)
233+
end
234+
226235
vjp_operator = alg.method isa RUS.__Yuan || alg.method isa RUS.__Bastin ?
227-
VecJacOperator(prob, fu, u; autodiff = vjp_autodiff) : nothing
236+
VecJacOperator(_ad_prob, fu, u; autodiff = vjp_autodiff) : nothing
228237

229238
jvp_operator = alg.method isa RUS.__Bastin ?
230-
JacVecOperator(prob, fu, u; autodiff = jvp_autodiff) : nothing
239+
JacVecOperator(_ad_prob, fu, u; autodiff = jvp_autodiff) : nothing
231240

232241
if alg.method isa RUS.__Yuan
233242
Jᵀfu_cache = StatefulJacobianOperator(vjp_operator, u, prob.p) * Utils.safe_vec(fu)

0 commit comments

Comments
 (0)