Skip to content

Commit 504b1b4

Browse files
Fix allocation regression: replace try-catch with type-based dispatch
The try-catch in AutoSpecializeCallable prevented inlining and added ~32 bytes per call, exceeding the 64-byte @ballocated budget in NonlinearSolveFirstOrder, QuasiNewton, and SpectralMethods tests.

Replace with explicit dispatch methods for known argument types (Vector{Float64}, Float64, NullParameters, and ForwardDiff duals), routing to f.fw for zero-allocation calls. Unsupported types fall back to f.orig via vararg dispatch.

Also fix @test_broken -> @test for @inferred solve(prob) which now passes.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
1 parent 223b9d5 commit 504b1b4

3 files changed

Lines changed: 54 additions & 22 deletions

File tree

lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
1313

1414
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
1515
NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache,
16-
NonlinearSolveTag, is_fw_wrapped
16+
NonlinearSolveTag, AutoSpecializeCallable, is_fw_wrapped
1717

1818
import NonlinearSolveBase: wrapfun_iip, standardize_forwarddiff_tag
1919

@@ -28,6 +28,35 @@ dualgen(::Type{T}) where {T} = ForwardDiff.Dual{
2828
ForwardDiff.Tag{NonlinearSolveTag, T}, T, 1,
2929
}
3030

31+
# Fast-path dispatch for IIP calls with NonlinearSolveTag duals.
32+
# These bypass the generic fallback path, calling directly into FunctionWrappersWrapper
33+
# for zero-allocation dispatch.
34+
@inline function (f::AutoSpecializeCallable)(
35+
du::Vector{dualT}, u::Vector{dualT}, p::Vector{Float64},
36+
)
37+
return f.fw(du, u, p)
38+
end
39+
@inline function (f::AutoSpecializeCallable)(
40+
du::Vector{dualT}, u::Vector{dualT}, p::SciMLBase.NullParameters,
41+
)
42+
return f.fw(du, u, p)
43+
end
44+
@inline function (f::AutoSpecializeCallable)(
45+
du::Vector{dualT}, u::Vector{dualT}, p::Vector{dualT},
46+
)
47+
return f.fw(du, u, p)
48+
end
49+
@inline function (f::AutoSpecializeCallable)(
50+
du::Vector{dualT}, u::Vector{Float64}, p::Vector{dualT},
51+
)
52+
return f.fw(du, u, p)
53+
end
54+
@inline function (f::AutoSpecializeCallable)(
55+
du::Vector{dualT}, u::Vector{dualT}, p::Float64,
56+
)
57+
return f.fw(du, u, p)
58+
end
59+
3160
# Helper: build the canonical AutoForwardDiff for wrapped functions (chunksize=1 + tag).
3261
function _wrapped_forwarddiff_ad()
3362
tag = ForwardDiff.Tag(NonlinearSolveTag(), Float64)

lib/NonlinearSolveBase/src/autospecialize.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,29 @@ struct AutoSpecializeCallable{FW} <: Function
4040
orig::Any # type-erased: all wrapped functions share the same Julia type
4141
end
4242

43-
# FunctionWrappersWrapper throws NoFunctionWrapperFoundError (or ErrorException
44-
# in older versions) when no signature matches. Fall back to the original function
45-
# for unsupported argument types (e.g., external packages like LeastSquaresOptim
46-
# doing their own ForwardDiff with different tags/chunksizes, or JVP paths that
47-
# bypass tag standardization).
48-
function (f::AutoSpecializeCallable)(args...)
49-
try
50-
return f.fw(args...)
51-
catch e
52-
if e isa FunctionWrappersWrappers.NoFunctionWrapperFoundError ||
53-
(e isa ErrorException && contains(e.msg, "No matching function wrapper"))
54-
return f.orig(args...)
55-
end
56-
rethrow()
57-
end
43+
# Fast-path dispatch for IIP calls on Vector{Float64} `du`/`u` with common `p` types.
44+
# These call directly through the FunctionWrappersWrapper for zero-allocation dispatch.
45+
# The ForwardDiff extension adds analogous methods for dual-number argument types.
46+
@inline function (f::AutoSpecializeCallable)(
47+
du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64},
48+
)
49+
return f.fw(du, u, p)
50+
end
51+
@inline function (f::AutoSpecializeCallable)(
52+
du::Vector{Float64}, u::Vector{Float64}, p::Float64,
53+
)
54+
return f.fw(du, u, p)
5855
end
56+
@inline function (f::AutoSpecializeCallable)(
57+
du::Vector{Float64}, u::Vector{Float64}, p::SciMLBase.NullParameters,
58+
)
59+
return f.fw(du, u, p)
60+
end
61+
62+
# Fallback: call original function for unsupported argument types (e.g., external
63+
# packages like LeastSquaresOptim doing their own ForwardDiff with different
64+
# tags/chunksizes, or JVP paths that bypass tag standardization).
65+
@inline (f::AutoSpecializeCallable)(args...) = f.orig(args...)
5966

6067
"""
6168
is_fw_wrapped(f) -> Bool

test/core_tests.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -317,19 +317,15 @@ end
317317

318318
@test all(!isnan, sol.u)
319319
@test !SciMLBase.successful_retcode(sol.retcode)
320-
# IIP Vector{Float64} problems go through AutoSpecialize FunctionWrapper wrapping,
321-
# which changes the function type in the solution, breaking type inference.
322-
@test_broken (@inferred solve(prob)) isa Any
320+
@test (@inferred solve(prob)) isa Any
323321

324322
u0 = [0.0, 0.0, 0.0]
325323
prob = NonlinearProblem(f1_infeasible, u0)
326324
sol = solve(prob)
327325

328326
@test all(!isnan, sol.u)
329327
@test !SciMLBase.successful_retcode(sol.retcode)
330-
# OOP Vector{Float64} problems go through AutoSpecialize FunctionWrapper wrapping,
331-
# which uses try-catch fallback for mismatched dual tags, breaking type inference.
332-
@test_broken (@inferred solve(prob)) isa Any
328+
@test (@inferred solve(prob)) isa Any
333329

334330
u0 = @SVector [0.0, 0.0, 0.0]
335331
prob = NonlinearProblem(f1_infeasible, u0)

0 commit comments

Comments
 (0)