Skip to content
This repository was archived by the owner on Apr 21, 2026. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -805,12 +805,79 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, spe
zeros(uElType, n, n)
f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t))
end
# Wrap the initialization data functions if present, so the
# NonlinearSolve dispatch also benefits from compilation caching.
if f.initialization_data isa SciMLBase.OverrideInitData
f = _wrap_initialization_data(f)
end
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t)))
else
return f
end
end

# Wrap functions inside OverrideInitData so their types are erased:
# - NonlinearFunction.f in initializeprob (IIP, same pattern as jac wrapping)
# - initializeprobmap (OOP, 1-arg)
# - initializeprobpmap (OOP, 2-arg)
function _wrap_initialization_data(f)
initdata = f.initialization_data
initprob = initdata.initializeprob
nlf = initprob.f

# Wrap NLF's f if IIP and not already wrapped
initprob2 = initprob
if SciMLBase.isinplace(nlf) && !(nlf.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
init_u0 = SciMLBase.state_values(initprob)
init_p = SciMLBase.parameter_values(initprob)
if init_u0 !== nothing
wrapped_nlf_f = wrapfun_jac_iip(nlf.f, (similar(init_u0), init_u0, init_p))
nlf2 = SciMLBase.NonlinearFunction{
SciMLBase.isinplace(nlf), SciMLBase.specialization(nlf),
typeof(wrapped_nlf_f),
typeof(nlf.mass_matrix), Nothing, Nothing,
Nothing, Nothing, Nothing, typeof(nlf.jac_prototype),
typeof(nlf.sparsity), Nothing, Nothing,
Nothing,
Any, typeof(nlf.colorvec),
typeof(nlf.sys), typeof(nlf.resid_prototype), Nothing,
}(
wrapped_nlf_f, nlf.mass_matrix,
nlf.analytic, nlf.tgrad, nlf.jac,
nlf.jvp, nlf.vjp, nlf.jac_prototype,
nlf.sparsity, nlf.Wfact,
nlf.Wfact_t, nlf.paramjac,
nlf.observed, nlf.colorvec, nlf.sys,
nlf.resid_prototype, nlf.initialization_data
)
initprob2 = @set initprob.f = nlf2
end
end

# Wrap initializeprobmap(nlsol) -> u0 (OOP, 1-arg)
wrapped_map = initdata.initializeprobmap
if wrapped_map !== nothing && !(wrapped_map isa FunctionWrappersWrappers.FunctionWrappersWrapper)
wrapped_map = FunctionWrappersWrappers.FunctionWrappersWrapper(
wrapped_map, (Tuple{Any},), (Any,)
)
end

# Wrap initializeprobpmap(valp, nlsol) -> p (OOP, 2-arg)
wrapped_pmap = initdata.initializeprobpmap
if wrapped_pmap !== nothing && !(wrapped_pmap isa FunctionWrappersWrappers.FunctionWrappersWrapper)
wrapped_pmap = FunctionWrappersWrappers.FunctionWrappersWrapper(
wrapped_pmap, (Tuple{Any, Any},), (Any,)
)
end

f = @set f.initialization_data = SciMLBase.OverrideInitData(
initprob2, initdata.update_initializeprob!,
wrapped_map, wrapped_pmap,
initdata.metadata, initdata.is_update_oop
)
return f
end

# Simple path for algorithms that do NOT use ForwardDiff internally (e.g. Tsit5, Verner).
# Avoids calling hasdualpromote/wrapfun_iip which have extension overrides in
# DiffEqBaseForwardDiffExt that would create invalidating method table backedges.
Expand Down
Loading