Skip to content

Commit 32c6a0d

Browse files
yebaiclaude
andcommitted
Add order=2 Hessian preparation via value_gradient_and_hessian!!
Extends `prepare(adtype, problem, x; order=2)` to build Hessian machinery for scalar-valued problems on the DI and Mooncake extensions, returning `(value, gradient, hessian)` from a new `value_gradient_and_hessian!!` generic. Unifies the per-extension caches (`DICache`, `MooncakeCache`) so one struct carries every derivative order, with explicit cross-arity error messages replacing prior `MethodError`s. DI uses the in-place `DI.value_gradient_and_hessian!` with caller-owned buffers; Mooncake uses its native `prepare_hessian_cache` API. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1599516 commit 32c6a0d

8 files changed

Lines changed: 398 additions & 67 deletions

File tree

docs/src/evaluators.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,56 @@ library invokes the inner callable many times with same-length dual arrays
138138
derived from a single user-supplied `x`; re-validating on each invocation
139139
would be redundant work in the hot path.
140140

141+
## Hessian (`order=2`)
142+
143+
Pass `order=2` to `prepare` to build a Hessian-capable evaluator. The
144+
returned object answers `value_gradient_and_hessian!!`, which returns
145+
`(value, gradient, hessian)` in a single call. `order=2` requires
146+
`problem` to be scalar-valued; a vector-valued probe throws at preparation
147+
time.
148+
149+
```julia
150+
using AbstractPPL: prepare, value_gradient_and_hessian!!
151+
using ADTypes: AutoForwardDiff
152+
using ForwardDiff, DifferentiationInterface
153+
154+
quadratic(x) = sum(abs2, x)
155+
prepared = prepare(AutoForwardDiff(), quadratic, zeros(3); order=2)
156+
val, grad, hess = value_gradient_and_hessian!!(prepared, [1.0, 2.0, 3.0])
157+
# val == 14.0
158+
# grad == [2.0, 4.0, 6.0]
159+
# hess == [2 0 0; 0 2 0; 0 0 2]
160+
```
161+
162+
Both `context=` and `check_dims=` apply to `order=2` preps with the same
163+
semantics as for `order=1`. The `!!` aliasing contract also extends: the
164+
returned gradient and Hessian may alias internal cache buffers of
165+
`prepared`, so copy before retaining them past the next call. NamedTuple
166+
inputs are not supported at `order=2`.
167+
168+
For DifferentiationInterface, `adtype` can be either a single backend
169+
(letting DI pick its own Hessian strategy) or a
170+
[`DifferentiationInterface.SecondOrder(outer, inner)`](https://juliadiff.org/DifferentiationInterface.jl/stable/api/#DifferentiationInterface.SecondOrder)
171+
composition that selects the outer differentiator and the inner gradient
172+
backend independently — typically forward-over-reverse:
173+
174+
```julia
175+
using DifferentiationInterface: SecondOrder
176+
using ADTypes: AutoForwardDiff, AutoReverseDiff
177+
178+
adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff())
179+
prepared = prepare(adtype, quadratic, zeros(3); order=2)
180+
```
181+
182+
`SecondOrder <: AbstractADType`, so the same `prepare(adtype, problem, x; order=2)`
183+
entry handles it.
184+
185+
Calling `value_gradient_and_hessian!!` on an `order=1` prep throws an
186+
`ArgumentError` — re-prepare with `order=2` instead. Likewise, calling
187+
`value_and_gradient!!` or `value_and_jacobian!!` on an `order=2` prep is
188+
unsupported; use `value_gradient_and_hessian!!` and discard the unused
189+
return value.
190+
141191
## Constant context arguments
142192

143193
When the underlying callable naturally takes the form `f(x, context...)`
@@ -177,4 +227,5 @@ p([1.0, 2.0, 3.0])
177227
AbstractPPL.prepare
178228
AbstractPPL.value_and_gradient!!
179229
AbstractPPL.value_and_jacobian!!
230+
AbstractPPL.value_gradient_and_hessian!!
180231
```

ext/AbstractPPLDifferentiationInterfaceExt.jl

Lines changed: 94 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,39 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_
55
using ADTypes: AbstractADType, AutoReverseDiff
66
using DifferentiationInterface: DifferentiationInterface as DI
77

8-
# AD target used by both `DICache` modes. `Vararg{Any,N}` with a free `N`
8+
# AD target used by every `DICache` mode. `Vararg{Any,N}` with a free `N`
99
# forces specialization on the trailing arity (a bare `Vararg{Any}` would
1010
# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the
1111
# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on
1212
# the closure path — empty `ctx` then makes the splat a no-op.
1313
@inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...)
1414

1515
# `Mode` tags the cache shape:
16-
# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure,
17-
# the AD call passes **0** `DI.Constant`s.
18-
# * `N::Int` — constants path: `N == length(evaluator.context)`, the
19-
# AD call passes **N + 1** `DI.Constant`s (`f` plus the
20-
# `N` context values).
21-
# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*`
22-
# at compile time without a runtime branch.
23-
struct DICache{Mode,F,GP,JP}
16+
# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, the
17+
# AD call passes **0** `DI.Constant`s.
18+
# * `N::Int` — constants path: `N == length(evaluator.context)`, the AD
19+
# call passes **N + 1** `DI.Constant`s (`f` plus the `N`
20+
# context values).
21+
# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` at
22+
# compile time without a runtime branch.
23+
#
24+
# Single cache for every derivative order. At most one of `gradient_prep`,
25+
# `jacobian_prep`, `hessian_prep` is non-`Nothing` at any time; the hot-path
26+
# methods discriminate via `=== nothing` checks (folded at compile time since
27+
# field types are concrete in each instantiation). `grad_buf` / `hess_buf` are
28+
# non-`Nothing` only for order=2 — caller-owned output buffers handed to
29+
# `DI.value_gradient_and_hessian!`. Returned arrays alias them (`!!` contract).
30+
struct DICache{Mode,F,GP,JP,HP,G,H}
2431
target::F
2532
gradient_prep::GP
2633
jacobian_prep::JP
27-
function DICache{Mode}(target::F, gp::GP, jp::JP) where {Mode,F,GP,JP}
28-
return new{Mode,F,GP,JP}(target, gp, jp)
34+
hessian_prep::HP
35+
grad_buf::G
36+
hess_buf::H
37+
function DICache{Mode}(
38+
target::F, gp::GP, jp::JP, hp::HP, g::G, h::H
39+
) where {Mode,F,GP,JP,HP,G,H}
40+
return new{Mode,F,GP,JP,HP,G,H}(target, gp, jp, hp, g, h)
2941
end
3042
end
3143

@@ -48,18 +60,42 @@ function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F}
4860
)
4961
end
5062

51-
@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} =
52-
DICache{Mode}(target, gp, jp)
63+
@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} = DICache{Mode}(
64+
target, gp, jp, nothing, nothing, nothing
65+
)
5366

5467
function AbstractPPL.prepare(
5568
adtype::AbstractADType,
5669
problem,
5770
x::AbstractVector{<:Real};
5871
check_dims::Bool=true,
5972
context::Tuple=(),
73+
order::Int=1,
6074
)
6175
evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator
6276
arity = _ad_output_arity(evaluator(x))
77+
if order == 2
78+
arity === :scalar || Evaluators._throw_hessian_needs_scalar()
79+
if length(x) == 0
80+
# DI Hessian prep crashes on length-0 input; the AD entry
81+
# short-circuits before any DI call. `Val(0)` is a non-`Nothing`
82+
# sentinel for `hessian_prep` so dispatch recognises this as an
83+
# order=2 prep (mirrors the order=1 empty-input pattern below).
84+
cache = _wrap_hessian_cache(
85+
_call_evaluator, Val(0), nothing, nothing, Val(length(context))
86+
)
87+
return Prepared(adtype, evaluator, cache)
88+
end
89+
target, hessian_prep, mode = _prepare_di(DI.prepare_hessian, adtype, x, evaluator)
90+
# Buffers pre-allocated from `x` (shape and eltype): the hot path is
91+
# zero-allocation on the gradient/Hessian outputs, and the returned
92+
# arrays alias these slots — copy if you need to retain them.
93+
grad_buf = similar(x)
94+
hess_buf = similar(x, length(x), length(x))
95+
cache = _wrap_hessian_cache(target, hessian_prep, grad_buf, hess_buf, mode)
96+
return Prepared(adtype, evaluator, cache)
97+
end
98+
order == 1 || throw(ArgumentError("`order` must be 1 or 2, got $order."))
6399
if length(x) == 0
64100
# DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`).
65101
# `Val(0)` is an arity sentinel for the `gradient_prep === nothing`
@@ -78,36 +114,35 @@ function AbstractPPL.prepare(
78114
return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode))
79115
end
80116

117+
@inline _wrap_hessian_cache(target, hp, g, h, ::Val{Mode}) where {Mode} = DICache{Mode}(
118+
target, nothing, nothing, hp, g, h
119+
)
120+
81121
# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile
82122
# time. The unconstrained method matches every non-`:closure` `Mode` (i.e.
83123
# any `Int N`); `:closure` is strictly more specific and wins for compiled
84124
# tapes. On the constants path we always pass `DI.Constant(eval.f)` plus the
85125
# `N` context constants — `N == 0` collapses the `map` splat to nothing.
86-
@inline _di_value_and_gradient(c::DICache{:closure}, ad, x, _) =
87-
DI.value_and_gradient(c.target, c.gradient_prep, ad, x)
126+
@inline _di_value_and_gradient(c::DICache{:closure}, ad, x, _) = DI.value_and_gradient(
127+
c.target, c.gradient_prep, ad, x
128+
)
88129
@inline _di_value_and_gradient(c::DICache, ad, x, eval) = DI.value_and_gradient(
89-
c.target,
90-
c.gradient_prep,
91-
ad,
92-
x,
93-
DI.Constant(eval.f),
94-
map(DI.Constant, eval.context)...,
130+
c.target, c.gradient_prep, ad, x, DI.Constant(eval.f), map(DI.Constant, eval.context)...
95131
)
96132

97-
@inline _di_value_and_jacobian(c::DICache{:closure}, ad, x, _) =
98-
DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x)
133+
@inline _di_value_and_jacobian(c::DICache{:closure}, ad, x, _) = DI.value_and_jacobian(
134+
c.target, c.jacobian_prep, ad, x
135+
)
99136
@inline _di_value_and_jacobian(c::DICache, ad, x, eval) = DI.value_and_jacobian(
100-
c.target,
101-
c.jacobian_prep,
102-
ad,
103-
x,
104-
DI.Constant(eval.f),
105-
map(DI.Constant, eval.context)...,
137+
c.target, c.jacobian_prep, ad, x, DI.Constant(eval.f), map(DI.Constant, eval.context)...
106138
)
107139

108140
@inline function AbstractPPL.value_and_gradient!!(
109141
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
110142
) where {T<:Real}
143+
# Both `=== nothing` branches fold at compile time: each instantiation
144+
# has concrete field types, so only the relevant branch survives.
145+
p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian()
111146
p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar()
112147
Evaluators._check_ad_input(p.evaluator, x)
113148
# Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff
@@ -119,6 +154,7 @@ end
119154
@inline function AbstractPPL.value_and_jacobian!!(
120155
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
121156
) where {T<:Real}
157+
p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian()
122158
p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector()
123159
Evaluators._check_ad_input(p.evaluator, x)
124160
if length(x) == 0
@@ -128,4 +164,33 @@ end
128164
return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator)
129165
end
130166

167+
# Hessian hot-path dispatch mirrors the gradient/jacobian helpers above:
168+
# `:closure` (compiled-tape) vs constants `Mode`, resolved at compile time.
169+
# Uses DI's in-place variant `value_gradient_and_hessian!` with caller-owned
170+
# buffers; the returned `(val, grad, hess)` aliases `c.grad_buf` / `c.hess_buf`.
171+
@inline _di_value_gradient_and_hessian(c::DICache{:closure}, ad, x, _) = DI.value_gradient_and_hessian!(
172+
c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x
173+
)
174+
@inline _di_value_gradient_and_hessian(c::DICache, ad, x, eval) = DI.value_gradient_and_hessian!(
175+
c.target,
176+
c.grad_buf,
177+
c.hess_buf,
178+
c.hessian_prep,
179+
ad,
180+
x,
181+
DI.Constant(eval.f),
182+
map(DI.Constant, eval.context)...,
183+
)
184+
185+
@inline function AbstractPPL.value_gradient_and_hessian!!(
186+
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
187+
) where {T<:Real}
188+
# Order=1 preps have `hessian_prep === nothing` (compile-folded check).
189+
p.cache.hessian_prep === nothing && Evaluators._throw_hessian_needs_order_2_prep()
190+
Evaluators._check_ad_input(p.evaluator, x)
191+
# Empty-input shortcut — same reasoning as the order=1 path.
192+
length(x) == 0 && return (p.evaluator(x), T[], similar(x, 0, 0))
193+
return _di_value_gradient_and_hessian(p.cache, p.adtype, x, p.evaluator)
194+
end
195+
131196
end # module

0 commit comments

Comments
 (0)