Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[sources.BracketingNonlinearSolve]
Expand All @@ -37,6 +38,7 @@ path = "../NonlinearSolveBase"
[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveTaylorDiffExt = "TaylorDiff"
SimpleNonlinearSolveTrackerExt = "Tracker"

[compat]
Expand Down Expand Up @@ -69,6 +71,7 @@ SciMLBase = "2.153, 3"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.3"
TaylorDiff = "0.3"
Test = "1.10"
TestItemRunner = "1"
Tracker = "0.2.35"
Expand All @@ -86,6 +89,7 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand Down
74 changes: 74 additions & 0 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
module SimpleNonlinearSolveTaylorDiffExt
using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleHouseholder, Utils
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearSolveAlgorithm
using MaybeInplace: @bb
using FastClosures: @closure
import SciMLBase
import TaylorDiff

SimpleNonlinearSolve.is_extension_loaded(::Val{:TaylorDiff}) = true

const NLBUtils = NonlinearSolveBase.Utils

@inline function __get_higher_order_derivatives(
::SimpleHouseholder{N}, prob, x, fx) where {N}
vN = Val(N)
l = map(one, x)
t = TaylorDiff.make_seed(x, l, vN)

if SciMLBase.isinplace(prob)
bundle = similar(fx, TaylorDiff.TaylorScalar{eltype(fx), N})
prob.f(bundle, t, prob.p)
map!(TaylorDiff.value, fx, bundle)
else
bundle = prob.f(t, prob.p)
fx = map(TaylorDiff.value, bundle)
end
invbundle = inv.(bundle)
num = N == 1 ? map(TaylorDiff.value, invbundle) :
TaylorDiff.extract_derivative(invbundle, Val(N - 1))
den = TaylorDiff.extract_derivative(invbundle, vN)
return num, den, fx
end

function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHouseholder{N},
args...; abstol = nothing, reltol = nothing, maxiters = 1000,
termination_condition = nothing, alias_u0 = false, kwargs...) where {N}
length(prob.u0) == 1 ||
throw(ArgumentError("SimpleHouseholder only supports scalar problems"))
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
fx = NLBUtils.evaluate_f(prob, x)

iszero(fx) &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

@bb xo = similar(x)

for i in 1:maxiters
@bb copyto!(xo, x)
num, den, fx = __get_higher_order_derivatives(alg, prob, x, fx)
@bb x .+= N .* num ./ den
solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
end

return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

function SimpleNonlinearSolve.evaluate_hvvp_internal(
hvvp, prob::ImmutableNonlinearProblem, u, a)
if SciMLBase.isinplace(prob)
binary_f = @closure (y, x) -> prob.f(y, x, prob.p)
TaylorDiff.derivative!(hvvp, binary_f, cache.fu, u, a, Val(2))
else
unary_f = Base.Fix2(prob.f, prob.p)
hvvp = TaylorDiff.derivative(unary_f, u, a, Val(2))
end
hvvp
end

end
3 changes: 2 additions & 1 deletion lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ include("utils.jl")
include("broyden.jl")
include("dfsane.jl")
include("halley.jl")
include("householder.jl")
include("klement.jl")
include("lbroyden.jl")
include("raphson.jl")
Expand Down Expand Up @@ -165,7 +166,7 @@ end
export SimpleBroyden, SimpleKlement, SimpleLimitedMemoryBroyden
export SimpleDFSane
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
export SimpleHalley
export SimpleHalley, SimpleHouseholder

export solve

Expand Down
1 change: 1 addition & 0 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ A low-overhead implementation of Halley's Method.
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
automatic backend selection). Valid choices include jacobian backends from
`DifferentiationInterface.jl`.
In addition, `AutoTaylorDiff` can be used to enable Taylor mode for computing the Hessian-vector-vector product more efficiently; in this case, the Jacobian would still be calculated using the default backend. You need to have `TaylorDiff.jl` loaded to use this option.
"""
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
autodiff = nothing
Expand Down
16 changes: 16 additions & 0 deletions lib/SimpleNonlinearSolve/src/householder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
SimpleHouseholder{order}()

A low-overhead implementation of Householder's method to arbitrary order.
This method is non-allocating on scalar and static array problems.

!!! warning

Needs `TaylorDiff.jl` to be explicitly loaded before using this functionality.
Internally, this uses TaylorDiff.jl for automatic differentiation.

### Type Parameters

- `order`: the order of the Householder method. `order = 1` is the same as Newton's method, `order = 2` is the same as Halley's method, etc.
"""
struct SimpleHouseholder{order} <: AbstractSimpleNonlinearSolveAlgorithm end
Loading