Skip to content

Commit a75036b

Browse files
Unwrap AutoSpecializeCallable for Enzyme AD and fix Runic formatting
Enzyme cannot differentiate through FunctionWrappers' llvmcall, causing EnzymeMutabilityException in all IIP Vector{Float64} tests with AutoEnzyme. Unwrap the function in construct_jacobian_cache when the AD backend is Enzyme-based (including AutoSparse(AutoEnzyme(...))), so DI sees the raw user function. Also apply Runic formatting to SCCNonlinearSolve files. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 504b1b4 commit a75036b

4 files changed

Lines changed: 34 additions & 10 deletions

File tree

lib/NonlinearSolveBase/src/autospecialize.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ Otherwise return `f` unchanged.
8181
get_raw_f(f) = f
8282
get_raw_f(f::AutoSpecializeCallable) = f.orig
8383

84+
"""
85+
_uses_enzyme_ad(ad) -> Bool
86+
87+
Return `true` if `ad` is an Enzyme-based AD backend (possibly wrapped in `AutoSparse`).
88+
Enzyme cannot differentiate through FunctionWrappers' `llvmcall`, so
89+
`AutoSpecializeCallable` must be unwrapped before passing to DI with Enzyme.
90+
"""
91+
_uses_enzyme_ad(::ADTypes.AutoEnzyme) = true
92+
_uses_enzyme_ad(ad::AutoSparse) = _uses_enzyme_ad(ADTypes.dense_ad(ad))
93+
_uses_enzyme_ad(_) = false
94+
8495
# Default dispatch assumes no ForwardDiff loaded.
8596
# The ForwardDiff extension overrides these with dual-aware versions.
8697

lib/NonlinearSolveBase/src/jacobian.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ function construct_jacobian_cache(
5353
end
5454
autodiff = standardize_forwarddiff_tag(autodiff, prob)
5555
autodiff = construct_concrete_adtype(f, autodiff)
56+
# Enzyme cannot differentiate through FunctionWrappers' llvmcall.
57+
# Unwrap AutoSpecializeCallable so DI sees the raw user function.
58+
if is_fw_wrapped(f.f) && _uses_enzyme_ad(autodiff)
59+
f = @set f.f = get_raw_f(f.f)
60+
end
5661
di_extras = if SciMLBase.isinplace(f)
5762
DI.prepare_jacobian(f, fu_cache, autodiff, u, Constant(p), strict = Val(false))
5863
else

lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ function solve_single_scc(alg, prob, explicitfun, sols; kwargs...)
8585
)
8686
else
8787
sol = SciMLBase.solve(prob, alg.nlalg; kwargs...)
88-
SciMLBase.strip_solution(SciMLBase.build_solution(
89-
prob, nothing, sol.u, sol.resid, retcode = sol.retcode
90-
))
88+
SciMLBase.strip_solution(
89+
SciMLBase.build_solution(
90+
prob, nothing, sol.u, sol.resid, retcode = sol.retcode
91+
)
92+
)
9193
end
9294

9395
return _sol
@@ -101,20 +103,24 @@ function iteratively_build_sols(alg, probs::AbstractVector, explicitfuns::Abstra
101103
uType = typeof(probvec(prob1))
102104
T = eltype(uType)
103105
rType = uType # resid has same type as u for nonlinear problems
104-
ST = SciMLBase.NonlinearSolution{T, 1, uType, rType,
105-
NamedTuple{(:p,), Tuple{Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing}
106+
ST = SciMLBase.NonlinearSolution{
107+
T, 1, uType, rType,
108+
NamedTuple{(:p,), Tuple{Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing,
109+
}
106110
sols = Vector{ST}(undef, length(probs))
107111
for i in eachindex(probs)
108-
sols[i] = solve_single_scc(alg, probs[i], explicitfuns[i], view(sols, 1:i-1); kwargs...)
112+
sols[i] = solve_single_scc(alg, probs[i], explicitfuns[i], view(sols, 1:(i - 1)); kwargs...)
109113
end
110114
return sols
111115
end
112116

113117
@generated function iteratively_build_sols(alg, probs::Tuple, explicitfuns::Tuple, ::Val{N}; kwargs...) where {N}
114118
return quote
115119
Base.Cartesian.@nexprs $N i -> begin
116-
prob_i = solve_single_scc(alg, probs[i], explicitfuns[i],
117-
Base.Cartesian.@ntuple((i - 1), j -> prob_j); kwargs...)
120+
prob_i = solve_single_scc(
121+
alg, probs[i], explicitfuns[i],
122+
Base.Cartesian.@ntuple((i - 1), j -> prob_j); kwargs...
123+
)
118124
end
119125
return Base.Cartesian.@ntuple $N i -> prob_i
120126
end

lib/SCCNonlinearSolve/test/core_tests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,10 @@ end
248248
# Wrap explicitfuns with FunctionWrapper for type unification.
249249
# The stripped solution type is deterministic — compute it from u0 type.
250250
uType = Vector{Float64}
251-
SSol = SciMLBase.NonlinearSolution{Float64, 1, uType, uType,
252-
NamedTuple{(:p,), Tuple{Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing}
251+
SSol = SciMLBase.NonlinearSolution{
252+
Float64, 1, uType, uType,
253+
NamedTuple{(:p,), Tuple{Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing,
254+
}
253255
SolsView = SubArray{SSol, 1, Vector{SSol}, Tuple{UnitRange{Int64}}, true}
254256
EFW = FunctionWrapper{Nothing, Tuple{Vector{Float64}, SolsView}}
255257
ef1_wrapped = EFW(explicitfun1_raw)

0 commit comments

Comments
 (0)