diff --git a/src/solve.jl b/src/solve.jl index fd120bde4..dcf8afbd4 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -805,12 +805,79 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, spe zeros(uElType, n, n) f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t)) end + # Wrap the initialization data functions if present, so the + # NonlinearSolve dispatch also benefits from compilation caching. + if f.initialization_data isa SciMLBase.OverrideInitData + f = _wrap_initialization_data(f) + end return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t))) else return f end end +# Wrap functions inside OverrideInitData so their types are erased: +# - NonlinearFunction.f in initializeprob (IIP, same pattern as jac wrapping) +# - initializeprobmap (OOP, 1-arg) +# - initializeprobpmap (OOP, 2-arg) +function _wrap_initialization_data(f) + initdata = f.initialization_data + initprob = initdata.initializeprob + nlf = initprob.f + + # Wrap NLF's f if IIP and not already wrapped + initprob2 = initprob + if SciMLBase.isinplace(nlf) && !(nlf.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) + init_u0 = SciMLBase.state_values(initprob) + init_p = SciMLBase.parameter_values(initprob) + if init_u0 !== nothing + wrapped_nlf_f = wrapfun_jac_iip(nlf.f, (similar(init_u0), init_u0, init_p)) + nlf2 = SciMLBase.NonlinearFunction{ + SciMLBase.isinplace(nlf), SciMLBase.specialization(nlf), + typeof(wrapped_nlf_f), + typeof(nlf.mass_matrix), Nothing, Nothing, + Nothing, Nothing, Nothing, typeof(nlf.jac_prototype), + typeof(nlf.sparsity), Nothing, Nothing, + Nothing, + Any, typeof(nlf.colorvec), + typeof(nlf.sys), typeof(nlf.resid_prototype), Nothing, + }( + wrapped_nlf_f, nlf.mass_matrix, + nlf.analytic, nlf.tgrad, nlf.jac, + nlf.jvp, nlf.vjp, nlf.jac_prototype, + nlf.sparsity, nlf.Wfact, + nlf.Wfact_t, nlf.paramjac, + nlf.observed, nlf.colorvec, nlf.sys, + nlf.resid_prototype, nlf.initialization_data + ) + initprob2 = @set initprob.f = nlf2 + end + end + + # Wrap initializeprobmap(nlsol) -> u0 (OOP, 1-arg) + wrapped_map = initdata.initializeprobmap + if wrapped_map !== nothing && !(wrapped_map isa FunctionWrappersWrappers.FunctionWrappersWrapper) + wrapped_map = FunctionWrappersWrappers.FunctionWrappersWrapper( + wrapped_map, (Tuple{Any},), (Any,) + ) + end + + # Wrap initializeprobpmap(valp, nlsol) -> p (OOP, 2-arg) + wrapped_pmap = initdata.initializeprobpmap + if wrapped_pmap !== nothing && !(wrapped_pmap isa FunctionWrappersWrappers.FunctionWrappersWrapper) + wrapped_pmap = FunctionWrappersWrappers.FunctionWrappersWrapper( + wrapped_pmap, (Tuple{Any, Any},), (Any,) + ) + end + + f = @set f.initialization_data = SciMLBase.OverrideInitData( + initprob2, initdata.update_initializeprob!, + wrapped_map, wrapped_pmap, + initdata.metadata, initdata.is_update_oop + ) + return f +end + # Simple path for algorithms that do NOT use ForwardDiff internally (e.g. Tsit5, Verner). # Avoids calling hasdualpromote/wrapfun_iip which have extension overrides in # DiffEqBaseForwardDiffExt that would create invalidating method table backedges.