Skip to content

Commit 2cc3790

Browse files
yebaiclaude
andcommitted
Add order=2 Hessian preparation via value_gradient_and_hessian!!
Extends `prepare(adtype, problem, x; order=2)` to build Hessian machinery for scalar-valued problems on the DI and Mooncake extensions, returning `(value, gradient, hessian)` from a new `value_gradient_and_hessian!!` generic. Unifies the per-extension caches (`DICache`, `MooncakeCache`) so one struct carries every derivative order, with explicit cross-arity error messages replacing prior `MethodError`s. DI uses the in-place `DI.value_gradient_and_hessian!` with caller-owned buffers; Mooncake uses its native `prepare_hessian_cache` API. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1599516 commit 2cc3790

8 files changed

Lines changed: 387 additions & 50 deletions

File tree

docs/src/evaluators.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,56 @@ library invokes the inner callable many times with same-length dual arrays
138138
derived from a single user-supplied `x`; re-validating on each invocation
139139
would be redundant work in the hot path.
140140

141+
## Hessian (`order=2`)
142+
143+
Pass `order=2` to `prepare` to build a Hessian-capable evaluator. The
144+
returned object answers `value_gradient_and_hessian!!`, which returns
145+
`(value, gradient, hessian)` in a single call. `order=2` requires
146+
`problem` to be scalar-valued; a vector-valued probe throws at preparation
147+
time.
148+
149+
```julia
150+
using AbstractPPL: prepare, value_gradient_and_hessian!!
151+
using ADTypes: AutoForwardDiff
152+
using ForwardDiff, DifferentiationInterface
153+
154+
quadratic(x) = sum(abs2, x)
155+
prepared = prepare(AutoForwardDiff(), quadratic, zeros(3); order=2)
156+
val, grad, hess = value_gradient_and_hessian!!(prepared, [1.0, 2.0, 3.0])
157+
# val == 14.0
158+
# grad == [2.0, 4.0, 6.0]
159+
# hess == [2 0 0; 0 2 0; 0 0 2]
160+
```
161+
162+
Both `context=` and `check_dims=` apply to `order=2` preps with the same
163+
semantics as for `order=1`. The `!!` aliasing contract also extends: the
164+
returned gradient and Hessian may alias internal cache buffers of
165+
`prepared`, so copy before retaining them past the next call. NamedTuple
166+
inputs are not supported at `order=2`.
167+
168+
For DifferentiationInterface, `adtype` can be either a single backend
169+
(letting DI pick its own Hessian strategy) or a
170+
[`DifferentiationInterface.SecondOrder(outer, inner)`](https://juliadiff.org/DifferentiationInterface.jl/stable/api/#DifferentiationInterface.SecondOrder)
171+
composition that selects the outer differentiator and the inner gradient
172+
backend independently — typically forward-over-reverse:
173+
174+
```julia
175+
using DifferentiationInterface: SecondOrder
176+
using ADTypes: AutoForwardDiff, AutoReverseDiff
177+
178+
adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff())
179+
prepared = prepare(adtype, quadratic, zeros(3); order=2)
180+
```
181+
182+
`SecondOrder <: AbstractADType`, so the same `prepare(adtype, problem, x; order=2)`
183+
entry handles it.
184+
185+
Calling `value_gradient_and_hessian!!` on an `order=1` prep throws an
186+
`ArgumentError` — re-prepare with `order=2` instead. Likewise, calling
187+
`value_and_gradient!!` or `value_and_jacobian!!` on an `order=2` prep is
188+
unsupported; use `value_gradient_and_hessian!!` and discard the unused
189+
return value.
190+
141191
## Constant context arguments
142192

143193
When the underlying callable naturally takes the form `f(x, context...)`
@@ -177,4 +227,5 @@ p([1.0, 2.0, 3.0])
177227
AbstractPPL.prepare
178228
AbstractPPL.value_and_gradient!!
179229
AbstractPPL.value_and_jacobian!!
230+
AbstractPPL.value_gradient_and_hessian!!
180231
```

ext/AbstractPPLDifferentiationInterfaceExt.jl

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,39 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_
55
using ADTypes: AbstractADType, AutoReverseDiff
66
using DifferentiationInterface: DifferentiationInterface as DI
77

8-
# AD target used by both `DICache` modes. `Vararg{Any,N}` with a free `N`
8+
# AD target used by every `DICache` mode. `Vararg{Any,N}` with a free `N`
99
# forces specialization on the trailing arity (a bare `Vararg{Any}` would
1010
# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the
1111
# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on
1212
# the closure path — empty `ctx` then makes the splat a no-op.
1313
@inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...)
1414

1515
# `Mode` tags the cache shape:
16-
# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure,
17-
# the AD call passes **0** `DI.Constant`s.
18-
# * `N::Int` — constants path: `N == length(evaluator.context)`, the
19-
# AD call passes **N + 1** `DI.Constant`s (`f` plus the
20-
# `N` context values).
21-
# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*`
22-
# at compile time without a runtime branch.
23-
struct DICache{Mode,F,GP,JP}
16+
# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, the
17+
# AD call passes **0** `DI.Constant`s.
18+
# * `N::Int` — constants path: `N == length(evaluator.context)`, the AD
19+
# call passes **N + 1** `DI.Constant`s (`f` plus the `N`
20+
# context values).
21+
# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` at
22+
# compile time without a runtime branch.
23+
#
24+
# Single cache for every derivative order. At most one of `gradient_prep`,
25+
# `jacobian_prep`, `hessian_prep` is non-`Nothing` at any time; the hot-path
26+
# methods discriminate via `=== nothing` checks (folded at compile time since
27+
# field types are concrete in each instantiation). `grad_buf` / `hess_buf` are
28+
# non-`Nothing` only for order=2 — caller-owned output buffers handed to
29+
# `DI.value_gradient_and_hessian!`. Returned arrays alias them (`!!` contract).
30+
struct DICache{Mode,F,GP,JP,HP,G,H}
2431
target::F
2532
gradient_prep::GP
2633
jacobian_prep::JP
27-
function DICache{Mode}(target::F, gp::GP, jp::JP) where {Mode,F,GP,JP}
28-
return new{Mode,F,GP,JP}(target, gp, jp)
34+
hessian_prep::HP
35+
grad_buf::G
36+
hess_buf::H
37+
function DICache{Mode}(
38+
target::F, gp::GP, jp::JP, hp::HP, g::G, h::H
39+
) where {Mode,F,GP,JP,HP,G,H}
40+
return new{Mode,F,GP,JP,HP,G,H}(target, gp, jp, hp, g, h)
2941
end
3042
end
3143

@@ -49,17 +61,40 @@ function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F}
4961
end
5062

5163
@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} =
52-
DICache{Mode}(target, gp, jp)
64+
DICache{Mode}(target, gp, jp, nothing, nothing, nothing)
5365

5466
function AbstractPPL.prepare(
5567
adtype::AbstractADType,
5668
problem,
5769
x::AbstractVector{<:Real};
5870
check_dims::Bool=true,
5971
context::Tuple=(),
72+
order::Int=1,
6073
)
6174
evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator
6275
arity = _ad_output_arity(evaluator(x))
76+
if order == 2
77+
arity === :scalar || Evaluators._throw_hessian_needs_scalar()
78+
if length(x) == 0
79+
# DI Hessian prep crashes on length-0 input; the AD entry
80+
# short-circuits before any DI call. `Val(0)` is a non-`Nothing`
81+
# sentinel for `hessian_prep` so dispatch recognises this as an
82+
# order=2 prep (mirrors the order=1 empty-input pattern below).
83+
cache = _wrap_hessian_cache(
84+
_call_evaluator, Val(0), nothing, nothing, Val(length(context))
85+
)
86+
return Prepared(adtype, evaluator, cache)
87+
end
88+
target, hessian_prep, mode = _prepare_di(DI.prepare_hessian, adtype, x, evaluator)
89+
# Buffers pre-allocated from `x` (shape and eltype): the hot path is
90+
# zero-allocation on the gradient/Hessian outputs, and the returned
91+
# arrays alias these slots — copy if you need to retain them.
92+
grad_buf = similar(x)
93+
hess_buf = similar(x, length(x), length(x))
94+
cache = _wrap_hessian_cache(target, hessian_prep, grad_buf, hess_buf, mode)
95+
return Prepared(adtype, evaluator, cache)
96+
end
97+
order == 1 || throw(ArgumentError("`order` must be 1 or 2, got $order."))
6398
if length(x) == 0
6499
# DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`).
65100
# `Val(0)` is an arity sentinel for the `gradient_prep === nothing`
@@ -78,6 +113,9 @@ function AbstractPPL.prepare(
78113
return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode))
79114
end
80115

116+
@inline _wrap_hessian_cache(target, hp, g, h, ::Val{Mode}) where {Mode} =
117+
DICache{Mode}(target, nothing, nothing, hp, g, h)
118+
81119
# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile
82120
# time. The unconstrained method matches every non-`:closure` `Mode` (i.e.
83121
# any `Int N`); `:closure` is strictly more specific and wins for compiled
@@ -108,6 +146,9 @@ end
108146
@inline function AbstractPPL.value_and_gradient!!(
109147
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
110148
) where {T<:Real}
149+
# Both `=== nothing` branches fold at compile time: each instantiation
150+
# has concrete field types, so only the relevant branch survives.
151+
p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian()
111152
p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar()
112153
Evaluators._check_ad_input(p.evaluator, x)
113154
# Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff
@@ -119,6 +160,7 @@ end
119160
@inline function AbstractPPL.value_and_jacobian!!(
120161
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
121162
) where {T<:Real}
163+
p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian()
122164
p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector()
123165
Evaluators._check_ad_input(p.evaluator, x)
124166
if length(x) == 0
@@ -128,4 +170,33 @@ end
128170
return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator)
129171
end
130172

173+
# Hessian hot-path dispatch mirrors the gradient/jacobian helpers above:
174+
# `:closure` (compiled-tape) vs constants `Mode`, resolved at compile time.
175+
# Uses DI's in-place variant `value_gradient_and_hessian!` with caller-owned
176+
# buffers; the returned `(val, grad, hess)` aliases `c.grad_buf` / `c.hess_buf`.
177+
@inline _di_value_gradient_and_hessian(c::DICache{:closure}, ad, x, _) =
178+
DI.value_gradient_and_hessian!(c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x)
179+
@inline _di_value_gradient_and_hessian(c::DICache, ad, x, eval) =
180+
DI.value_gradient_and_hessian!(
181+
c.target,
182+
c.grad_buf,
183+
c.hess_buf,
184+
c.hessian_prep,
185+
ad,
186+
x,
187+
DI.Constant(eval.f),
188+
map(DI.Constant, eval.context)...,
189+
)
190+
191+
@inline function AbstractPPL.value_gradient_and_hessian!!(
192+
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
193+
) where {T<:Real}
194+
# Order=1 preps have `hessian_prep === nothing` (compile-folded check).
195+
p.cache.hessian_prep === nothing && Evaluators._throw_hessian_needs_order_2_prep()
196+
Evaluators._check_ad_input(p.evaluator, x)
197+
# Empty-input shortcut — same reasoning as the order=1 path.
198+
length(x) == 0 && return (p.evaluator(x), T[], similar(x, 0, 0))
199+
return _di_value_gradient_and_hessian(p.cache, p.adtype, x, p.evaluator)
200+
end
201+
131202
end # module

0 commit comments

Comments
 (0)