@@ -5,27 +5,39 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_
55using ADTypes: AbstractADType, AutoReverseDiff
66using DifferentiationInterface: DifferentiationInterface as DI
77
8- # AD target used by both `DICache` modes . `Vararg{Any,N}` with a free `N`
8+ # AD target used by every `DICache` mode . `Vararg{Any,N}` with a free `N`
99# forces specialization on the trailing arity (a bare `Vararg{Any}` would
1010# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the
1111# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on
1212# the closure path — empty `ctx` then makes the splat a no-op.
1313@inline _call_evaluator (x, f:: F , ctx:: Vararg{Any,N} ) where {F,N} = f (x, ctx... )
1414
1515# `Mode` tags the cache shape:
16- # * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure,
17- # the AD call passes **0** `DI.Constant`s.
18- # * `N::Int` — constants path: `N == length(evaluator.context)`, the
19- # AD call passes **N + 1** `DI.Constant`s (`f` plus the
20- # `N` context values).
21- # Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*`
22- # at compile time without a runtime branch.
23- struct DICache{Mode,F,GP,JP}
16+ # * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, the
17+ # AD call passes **0** `DI.Constant`s.
18+ # * `N::Int` — constants path: `N == length(evaluator.context)`, the AD
19+ # call passes **N + 1** `DI.Constant`s (`f` plus the `N`
20+ # context values).
21+ # Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` at
22+ # compile time without a runtime branch.
23+ #
24+ # Single cache for every derivative order. At most one of `gradient_prep`,
25+ # `jacobian_prep`, `hessian_prep` is non-`Nothing` at any time; the hot-path
26+ # methods discriminate via `=== nothing` checks (folded at compile time since
27+ # field types are concrete in each instantiation). `grad_buf` / `hess_buf` are
28+ # non-`Nothing` only for order=2 — caller-owned output buffers handed to
29+ # `DI.value_gradient_and_hessian!`. Returned arrays alias them (`!!` contract).
30+ struct DICache{Mode,F,GP,JP,HP,G,H}
2431 target:: F
2532 gradient_prep:: GP
2633 jacobian_prep:: JP
27- function DICache {Mode} (target:: F , gp:: GP , jp:: JP ) where {Mode,F,GP,JP}
28- return new {Mode,F,GP,JP} (target, gp, jp)
34+ hessian_prep:: HP
35+ grad_buf:: G
36+ hess_buf:: H
37+ function DICache {Mode} (
38+ target:: F , gp:: GP , jp:: JP , hp:: HP , g:: G , h:: H
39+ ) where {Mode,F,GP,JP,HP,G,H}
40+ return new {Mode,F,GP,JP,HP,G,H} (target, gp, jp, hp, g, h)
2941 end
3042end
3143
@@ -48,18 +60,42 @@ function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F}
4860 )
4961end
5062
51- @inline _wrap_cache (target, gp, jp, :: Val{Mode} ) where {Mode} =
52- DICache {Mode} (target, gp, jp)
63+ @inline _wrap_cache (target, gp, jp, :: Val{Mode} ) where {Mode} = DICache {Mode} (
64+ target, gp, jp, nothing , nothing , nothing
65+ )
5366
5467function AbstractPPL. prepare (
5568 adtype:: AbstractADType ,
5669 problem,
5770 x:: AbstractVector{<:Real} ;
5871 check_dims:: Bool = true ,
5972 context:: Tuple = (),
73+ order:: Int = 1 ,
6074)
6175 evaluator = AbstractPPL. prepare (problem, x; check_dims, context):: VectorEvaluator
6276 arity = _ad_output_arity (evaluator (x))
77+ if order == 2
78+ arity === :scalar || Evaluators. _throw_hessian_needs_scalar ()
79+ if length (x) == 0
80+ # DI Hessian prep crashes on length-0 input; the AD entry
81+ # short-circuits before any DI call. `Val(0)` is a non-`Nothing`
82+ # sentinel for `hessian_prep` so dispatch recognises this as an
83+ # order=2 prep (mirrors the order=1 empty-input pattern below).
84+ cache = _wrap_hessian_cache (
85+ _call_evaluator, Val (0 ), nothing , nothing , Val (length (context))
86+ )
87+ return Prepared (adtype, evaluator, cache)
88+ end
89+ target, hessian_prep, mode = _prepare_di (DI. prepare_hessian, adtype, x, evaluator)
90+ # Buffers pre-allocated from `x` (shape and eltype): the hot path is
91+ # zero-allocation on the gradient/Hessian outputs, and the returned
92+ # arrays alias these slots — copy if you need to retain them.
93+ grad_buf = similar (x)
94+ hess_buf = similar (x, length (x), length (x))
95+ cache = _wrap_hessian_cache (target, hessian_prep, grad_buf, hess_buf, mode)
96+ return Prepared (adtype, evaluator, cache)
97+ end
98+ order == 1 || throw (ArgumentError (" `order` must be 1 or 2, got $order ." ))
6399 if length (x) == 0
64100 # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`).
65101 # `Val(0)` is an arity sentinel for the `gradient_prep === nothing`
@@ -78,36 +114,35 @@ function AbstractPPL.prepare(
78114 return Prepared (adtype, evaluator, _wrap_cache (target, nothing , jacobian_prep, mode))
79115end
80116
117+ @inline _wrap_hessian_cache (target, hp, g, h, :: Val{Mode} ) where {Mode} = DICache {Mode} (
118+ target, nothing , nothing , hp, g, h
119+ )
120+
81121# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile
82122# time. The unconstrained method matches every non-`:closure` `Mode` (i.e.
83123# any `Int N`); `:closure` is strictly more specific and wins for compiled
84124# tapes. On the constants path we always pass `DI.Constant(eval.f)` plus the
85125# `N` context constants — `N == 0` collapses the `map` splat to nothing.
86- @inline _di_value_and_gradient (c:: DICache{:closure} , ad, x, _) =
87- DI. value_and_gradient (c. target, c. gradient_prep, ad, x)
126+ @inline _di_value_and_gradient (c:: DICache{:closure} , ad, x, _) = DI. value_and_gradient (
127+ c. target, c. gradient_prep, ad, x
128+ )
88129@inline _di_value_and_gradient (c:: DICache , ad, x, eval) = DI. value_and_gradient (
89- c. target,
90- c. gradient_prep,
91- ad,
92- x,
93- DI. Constant (eval. f),
94- map (DI. Constant, eval. context)... ,
130+ c. target, c. gradient_prep, ad, x, DI. Constant (eval. f), map (DI. Constant, eval. context)...
95131)
96132
97- @inline _di_value_and_jacobian (c:: DICache{:closure} , ad, x, _) =
98- DI. value_and_jacobian (c. target, c. jacobian_prep, ad, x)
133+ @inline _di_value_and_jacobian (c:: DICache{:closure} , ad, x, _) = DI. value_and_jacobian (
134+ c. target, c. jacobian_prep, ad, x
135+ )
99136@inline _di_value_and_jacobian (c:: DICache , ad, x, eval) = DI. value_and_jacobian (
100- c. target,
101- c. jacobian_prep,
102- ad,
103- x,
104- DI. Constant (eval. f),
105- map (DI. Constant, eval. context)... ,
137+ c. target, c. jacobian_prep, ad, x, DI. Constant (eval. f), map (DI. Constant, eval. context)...
106138)
107139
108140@inline function AbstractPPL. value_and_gradient!! (
109141 p:: Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache} , x:: AbstractVector{T}
110142) where {T<: Real }
143+ # Both `=== nothing` branches fold at compile time: each instantiation
144+ # has concrete field types, so only the relevant branch survives.
145+ p. cache. hessian_prep === nothing || Evaluators. _throw_use_value_gradient_and_hessian ()
111146 p. cache. gradient_prep === nothing && Evaluators. _throw_gradient_needs_scalar ()
112147 Evaluators. _check_ad_input (p. evaluator, x)
113148 # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff
119154@inline function AbstractPPL. value_and_jacobian!! (
120155 p:: Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache} , x:: AbstractVector{T}
121156) where {T<: Real }
157+ p. cache. hessian_prep === nothing || Evaluators. _throw_use_value_gradient_and_hessian ()
122158 p. cache. jacobian_prep === nothing && Evaluators. _throw_jacobian_needs_vector ()
123159 Evaluators. _check_ad_input (p. evaluator, x)
124160 if length (x) == 0
128164 return _di_value_and_jacobian (p. cache, p. adtype, x, p. evaluator)
129165end
130166
167+ # Hessian hot-path dispatch mirrors the gradient/jacobian helpers above:
168+ # `:closure` (compiled-tape) vs constants `Mode`, resolved at compile time.
169+ # Uses DI's in-place variant `value_gradient_and_hessian!` with caller-owned
170+ # buffers; the returned `(val, grad, hess)` aliases `c.grad_buf` / `c.hess_buf`.
171+ @inline _di_value_gradient_and_hessian (c:: DICache{:closure} , ad, x, _) = DI. value_gradient_and_hessian! (
172+ c. target, c. grad_buf, c. hess_buf, c. hessian_prep, ad, x
173+ )
174+ @inline _di_value_gradient_and_hessian (c:: DICache , ad, x, eval) = DI. value_gradient_and_hessian! (
175+ c. target,
176+ c. grad_buf,
177+ c. hess_buf,
178+ c. hessian_prep,
179+ ad,
180+ x,
181+ DI. Constant (eval. f),
182+ map (DI. Constant, eval. context)... ,
183+ )
184+
185+ @inline function AbstractPPL. value_gradient_and_hessian!! (
186+ p:: Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache} , x:: AbstractVector{T}
187+ ) where {T<: Real }
188+ # Order=1 preps have `hessian_prep === nothing` (compile-folded check).
189+ p. cache. hessian_prep === nothing && Evaluators. _throw_hessian_needs_order_2_prep ()
190+ Evaluators. _check_ad_input (p. evaluator, x)
191+ # Empty-input shortcut — same reasoning as the order=1 path.
192+ length (x) == 0 && return (p. evaluator (x), T[], similar (x, 0 , 0 ))
193+ return _di_value_gradient_and_hessian (p. cache, p. adtype, x, p. evaluator)
194+ end
195+
131196end # module
0 commit comments