@@ -5,27 +5,39 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_
55using ADTypes: AbstractADType, AutoReverseDiff
66using DifferentiationInterface: DifferentiationInterface as DI
77
# Shared AD target for every `DICache` mode. The trailing `Vararg{Any,N}` keeps
# `N` free so Julia specializes on the number of context arguments (a bare
# `Vararg{Any}` would skip that). DI calls it as
# `_call_evaluator(x, f, c1, …, cN)` on the constants path and as
# `_call_evaluator(x, evaluator)` (via `Fix2`) on the closure path, where the
# empty `ctx` makes the splat a no-op.
@inline function _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N}
    return f(x, ctx...)
end
1414
# `DICache{Mode,…}` — one cache type shared by every derivative order.
#
# `Mode` records how the AD target is invoked:
#   * `:closure` — compiled-tape ReverseDiff: the target is a `Fix2` closure
#                  and the AD call passes **0** `DI.Constant`s.
#   * `N::Int`   — constants path with `N == length(evaluator.context)`: the
#                  AD call passes **N + 1** `DI.Constant`s (`f` plus the `N`
#                  context values).
# Because `Mode` lives in the type, dispatch in `_di_value_and_*` is settled at
# compile time without a runtime branch.
#
# At most one of `gradient_prep`, `jacobian_prep`, `hessian_prep` is
# non-`Nothing` at any time; hot-path methods tell the cases apart with
# `=== nothing` checks, which fold at compile time because field types are
# concrete in each instantiation. `grad_buf` / `hess_buf` are non-`Nothing`
# only for order=2: they are the caller-owned output buffers handed to
# `DI.value_gradient_and_hessian!`, and returned arrays alias them
# (`!!` contract).
struct DICache{Mode,F,GP,JP,HP,G,H}
    target::F
    gradient_prep::GP
    jacobian_prep::JP
    hessian_prep::HP
    grad_buf::G
    hess_buf::H

    # Callers spell out only `Mode`; every field type is taken from the
    # corresponding argument.
    function DICache{Mode}(
        target::F,
        gradient_prep::GP,
        jacobian_prep::JP,
        hessian_prep::HP,
        grad_buf::G,
        hess_buf::H,
    ) where {Mode,F,GP,JP,HP,G,H}
        cache = new{Mode,F,GP,JP,HP,G,H}(
            target, gradient_prep, jacobian_prep, hessian_prep, grad_buf, hess_buf
        )
        return cache
    end
end
3143
@@ -49,17 +61,40 @@ function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F}
4961end
5062
# Build a `DICache{Mode}` from a gradient prep and a Jacobian prep; the
# Hessian prep and both output-buffer slots are filled with `nothing`.
@inline function _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode}
    return DICache{Mode}(target, gp, jp, nothing, nothing, nothing)
end
5365
5466function AbstractPPL. prepare (
5567 adtype:: AbstractADType ,
5668 problem,
5769 x:: AbstractVector{<:Real} ;
5870 check_dims:: Bool = true ,
5971 context:: Tuple = (),
72+ order:: Int = 1 ,
6073)
6174 evaluator = AbstractPPL. prepare (problem, x; check_dims, context):: VectorEvaluator
6275 arity = _ad_output_arity (evaluator (x))
76+ if order == 2
77+ arity === :scalar || Evaluators. _throw_hessian_needs_scalar ()
78+ if length (x) == 0
79+ # DI Hessian prep crashes on length-0 input; the AD entry
80+ # short-circuits before any DI call. `Val(0)` is a non-`Nothing`
81+ # sentinel for `hessian_prep` so dispatch recognises this as an
82+ # order=2 prep (mirrors the order=1 empty-input pattern below).
83+ cache = _wrap_hessian_cache (
84+ _call_evaluator, Val (0 ), nothing , nothing , Val (length (context))
85+ )
86+ return Prepared (adtype, evaluator, cache)
87+ end
88+ target, hessian_prep, mode = _prepare_di (DI. prepare_hessian, adtype, x, evaluator)
89+ # Buffers pre-allocated from `x` (shape and eltype): the hot path is
90+ # zero-allocation on the gradient/Hessian outputs, and the returned
91+ # arrays alias these slots — copy if you need to retain them.
92+ grad_buf = similar (x)
93+ hess_buf = similar (x, length (x), length (x))
94+ cache = _wrap_hessian_cache (target, hessian_prep, grad_buf, hess_buf, mode)
95+ return Prepared (adtype, evaluator, cache)
96+ end
97+ order == 1 || throw (ArgumentError (" `order` must be 1 or 2, got $order ." ))
6398 if length (x) == 0
6499 # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`).
65100 # `Val(0)` is an arity sentinel for the `gradient_prep === nothing`
@@ -78,6 +113,9 @@ function AbstractPPL.prepare(
78113 return Prepared (adtype, evaluator, _wrap_cache (target, nothing , jacobian_prep, mode))
79114end
80115
# Build a `DICache{Mode}` for an order=2 prep: the gradient and Jacobian prep
# slots are `nothing`, while the Hessian prep `hp` and the gradient/Hessian
# buffers `g` / `h` are stored as given (the buffers may themselves be
# `nothing`, as in the empty-input prep).
@inline function _wrap_hessian_cache(target, hp, g, h, ::Val{Mode}) where {Mode}
    return DICache{Mode}(target, nothing, nothing, hp, g, h)
end
118+
81119# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile
82120# time. The unconstrained method matches every non-`:closure` `Mode` (i.e.
83121# any `Int N`); `:closure` is strictly more specific and wins for compiled
108146@inline function AbstractPPL. value_and_gradient!! (
109147 p:: Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache} , x:: AbstractVector{T}
110148) where {T<: Real }
149+ # Both `=== nothing` branches fold at compile time: each instantiation
150+ # has concrete field types, so only the relevant branch survives.
151+ p. cache. hessian_prep === nothing || Evaluators. _throw_use_value_gradient_and_hessian ()
111152 p. cache. gradient_prep === nothing && Evaluators. _throw_gradient_needs_scalar ()
112153 Evaluators. _check_ad_input (p. evaluator, x)
113154 # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff
119160@inline function AbstractPPL. value_and_jacobian!! (
120161 p:: Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache} , x:: AbstractVector{T}
121162) where {T<: Real }
163+ p. cache. hessian_prep === nothing || Evaluators. _throw_use_value_gradient_and_hessian ()
122164 p. cache. jacobian_prep === nothing && Evaluators. _throw_jacobian_needs_vector ()
123165 Evaluators. _check_ad_input (p. evaluator, x)
124166 if length (x) == 0
128170 return _di_value_and_jacobian (p. cache, p. adtype, x, p. evaluator)
129171end
130172
# Hessian hot-path helpers. Dispatch is on `Mode` — `:closure` (compiled-tape)
# versus the constants path — so it is resolved at compile time, mirroring the
# gradient/Jacobian helpers. Both methods call DI's in-place
# `value_gradient_and_hessian!` with the cache's own buffers, so the returned
# `(val, grad, hess)` aliases `c.grad_buf` / `c.hess_buf`.

# Closure path: the target already carries the evaluator, so no `DI.Constant`s
# are passed and the evaluator argument is ignored.
@inline function _di_value_gradient_and_hessian(c::DICache{:closure}, ad, x, _)
    return DI.value_gradient_and_hessian!(
        c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x
    )
end

# Constants path: forward `f` and each context value as a `DI.Constant`.
@inline function _di_value_gradient_and_hessian(c::DICache, ad, x, evaluator)
    consts = (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...)
    return DI.value_gradient_and_hessian!(
        c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x, consts...
    )
end
190+
# Return `(value, gradient, hessian)` of the prepared evaluator at `x`.
#
# Requires a prep whose `hessian_prep` is set; otherwise
# `Evaluators._throw_hessian_needs_order_2_prep()` is called. On the DI path
# the returned gradient and Hessian are whatever
# `_di_value_gradient_and_hessian` hands back for this cache.
@inline function AbstractPPL.value_gradient_and_hessian!!(
    p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
) where {T<:Real}
    cache = p.cache
    # Order=1 preps store `nothing` here; the check folds at compile time.
    if cache.hessian_prep === nothing
        Evaluators._throw_hessian_needs_order_2_prep()
    end
    Evaluators._check_ad_input(p.evaluator, x)
    # Empty input never reaches DI: evaluate directly and return an empty
    # gradient and a 0×0 Hessian.
    if length(x) == 0
        return (p.evaluator(x), T[], similar(x, 0, 0))
    end
    return _di_value_gradient_and_hessian(cache, p.adtype, x, p.evaluator)
end
201+
131202end # module
0 commit comments