Skip to content

Commit 9ebf18f

Browse files
Copilotyebaiclaude
authored
Add direct ForwardDiff extension (#166)
* Initial plan * Add ForwardDiff extension using ForwardDiff's public API Adds `AbstractPPLForwardDiffExt` that directly uses ForwardDiff's public API (gradient!, jacobian!, hessian! with DiffResults and pre-allocated configs), mirroring the Mooncake extension pattern. - ext/AbstractPPLForwardDiffExt.jl: full extension with prepare, value_and_gradient!!, value_and_jacobian!!, and value_gradient_and_hessian!! implementations - Project.toml: ForwardDiff + DiffResults as weakdeps with extension registration and compat entries - test/ext/forwarddiff/: dedicated test environment running all standard test cases plus context and empty-input tests Agent-Logs-Url: https://github.com/TuringLang/AbstractPPL.jl/sessions/2f9552bb-c72d-4891-a973-8ecc68959e06 Co-authored-by: yebai <3279477+yebai@users.noreply.github.com> * Fix CI: format ext files and route DI cache structural test through AutoReverseDiff Run JuliaFormatter on ext/AbstractPPLForwardDiffExt.jl and test/ext/forwarddiff/main.jl to satisfy the Format CI job. The "DI cache encodes the call mode as a type parameter" testset asserted `DIGradientCache{0}` and `DIGradientCache{2}` cache types for `AutoForwardDiff`, but the new direct `AbstractPPLForwardDiffExt` path now takes precedence over DI when both extensions are loaded. `AutoReverseDiff()` (non-compiled) exercises the same DI constants path and keeps the structural assertion meaningful. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Apply scrutinise findings to ForwardDiff extension - Drop section banners and WHAT-comments in the extension; keep WHYs (chunk-size dispatch, fresh `Fix2` per call is Tag-type-stable, separate gradient cache on the order=2 prep). - Tighten the `prepare` docstring's second paragraph. - Remove the "empty input" testset: `run_testcases(Val(:vector))` and `run_testcases(Val(:hessian))` already cover zero-length input for every arity / order combination via `AbstractPPLTestExt`. - Remove the trailing arity-mismatch `@test_throws` from the "context-lowered gradient" testset: `run_testcases(Val(:edge))` already covers "jacobian of scalar output". - Drop now-unused imports (`value_and_jacobian!!`, `value_gradient_and_hessian!!`, `order`, `DiffResults`). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Wire ext/forwarddiff into CI Add `ext/forwarddiff` to `VALID_LABELS` in `test/run_extras.jl` and to the CI ext matrix so the chunk-size and context tests this branch introduces actually run on CI (they were silently skipped before — `AutoForwardDiff` was only exercised via the DI test env). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Unify FD gradient/Jacobian/Hessian caches into FDCache{A} Replace `FDGradientCache`, `FDJacobianCache`, and `FDHessianCache` with one parametric `FDCache{A,R,C,GR,GC}` keyed on an arity/order `Symbol` `A ∈ (:scalar, :vector, :hessian)`, mirroring the `MooncakeCache{A}` pattern. Hot paths and arity-mismatch rejections dispatch on the tag at compile time exactly as before; `result::Nothing` remains the empty-input sentinel. Verified type-stable on all four `!!` entries. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Honor AutoForwardDiff tag and probe the problem once in prepare `ext/AbstractPPLForwardDiffExt.jl`: * Thread `adtype.tag` into every `*Config` constructor via a small `_fd_tag` helper; `nothing` (the ADTypes default) reproduces ForwardDiff's per-constructor default of `Tag(target, eltype(x))`, so callers can now use `AutoForwardDiff(; tag=…)` for nested differentiation through AbstractPPL. * Hoist the arity-probe `evaluator(x)` to a single `y_probe` local and reuse it as the Jacobian-result prototype on the vector branch. The base `prepare` contract promises one prep-time call into `problem`; the vector path was invoking it twice. * Cache `target = _fd_target(evaluator)` once locally rather than reconstructing the `Fix2` per config. `test/ext/forwarddiff/main.jl`: add a regression test asserting the user-supplied tag flows into the stored config's first type parameter. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Tidy two minor redundancies in the FD extension * Drop the unused `where {T<:Real}` binding on the empty-input Jacobian method; the non-empty sibling already uses `x::AbstractVector{<:Real}` directly. * Pass two `nothing`s to `FDCache{:hessian}(nothing, nothing)` for the empty-input order=2 cache instead of four — the constructor defaults `gradient_result` and `gradient_config` to `nothing`, so the resulting type is identical and the line is consistent with the `:scalar` / `:vector` empty-input shortcuts above. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Bump to 0.15.2 with HISTORY entry for the FD extension Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Inline _fd_target and clarify why _fd_call must stay top-level `_fd_target(e)` was just `Base.Fix2(_fd_call, e)` — inline the five call sites and drop the helper. `_fd_call` stays as a top-level named function: ForwardDiff's `*Config` keys its `Tag` on the target type, and a closure built inside one method would have a different type from one built inside another, desyncing the per-call target from the config captured at prep time. Reworded the comment to make that constraint (not the harmless cost of per-call `Fix2`) the WHY. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Share :allocations and :type_stability groups via AbstractPPLTestExt The pending changes in `test/ext/forwarddiff/main.jl` and `test/ext/mooncake/main.jl` had duplicated the same helper functions and testset bodies for "allocation-free hot paths" and "type-stable hot paths". Lift the shared logic into `AbstractPPLTestExt`: * `IdentityProblem`: allocation-free vector-output problem (avoids `VectorValuedProblem`'s result allocation masking AD-path allocations). * `_inferred_*` helpers wrap `@inferred` so it can be marked broken via `@test_broken`. * `run_testcases(Val(:allocations); ...)`: `@allocated == 0` checks on scalar gradient and vector Jacobian, with `gradient_broken` / `jacobian_broken` kwargs for backends with known regressions. * `run_testcases(Val(:type_stability); ...)`: `@inferred` checks on gradient/Jacobian/Hessian hot paths, with matching `*_broken` kwargs. Both extension test files now invoke the shared groups; Mooncake passes `jacobian_broken=true` for `:allocations` (both modes) and for `:type_stability` only on `AutoMooncakeForward` (`Tuple{Any, Union{Array{T,3}, Matrix}}` inference). Docstring on `generate_testcases` updated to list the new keys. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Apply scrutinise findings to the shared test groups * Generalize the `*_broken` comment: previously cited only "Mooncake's forward-mode Jacobian", but the kwargs cover other regressions too (Mooncake's `value_and_jacobian!!` allocates on every call across both modes; only the forward-mode Jacobian *inference* is broken). * Unify the `:allocations` vs `:type_stability` branch style — both now use the same `if/else` form rather than the ternary the former was using inconsistently. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Fix CI: format the ext, gate :allocations on Julia 1.10, add :context group * Format: match the CI JuliaFormatter v1.0.62 baseline (the local one I was using is on 2.x and disagrees on `return`-keyword placement). * `:allocations` group: Julia 1.10 heap-allocates `Fix2`/closure captures that 1.11+ elides. Mark `gradient_broken=VERSION < v"1.11"` (and `jacobian_broken=VERSION < v"1.11"` on FD) so min CI doesn't flag the older runtime as a regression. * New `:context` group: lifts the inline "context-lowered gradient" testset from the FD test into `AbstractPPLTestExt`. Verifies `prepare(adtype, f, x; context=(c,))` lowers the context out of the gradient. FD calls it in place of the inline testset; Mooncake adds it alongside its richer Mooncake-specific context testset (forward parity, vector arity rejection, empty input with context). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Skip ForwardDiff.checktag so custom Tag sentinels work in hot paths The custom-tag path was only structurally tested (the tag flowed into the config type parameter), not exercised through an actual AD call. DynamicPPL's downstream tests caught the gap: `AutoForwardDiff(; tag=Tag{DynamicPPLTag,Float64}())` carries a sentinel tag whose first type parameter is *not* `typeof(target)`, so ForwardDiff's default `checktag` errors when the hot path calls `ForwardDiff.gradient!`. Pass `Val(false)` to skip `checktag` in all four hot-path calls (this is what DifferentiationInterface does). The tag is purely a label for the outer Dual scope; the config built at prep time already encodes the right tag, so the check is redundant and harmful in the custom-tag case. Strengthen the regression test to actually run `value_and_gradient!!` on a prep built with a custom sentinel tag and assert the gradient matches the analytic value — would have caught the original bug. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Honor caller atol in the :context group's gradient assertion The gradient comparison was hardcoding `atol = 1e-10` while the value comparison above it (and every other group in this file) used `atol = atol`. The hardcoded value silently overrode the caller's kwarg — Mooncake calls with `atol = 1e-6` were getting the tighter 1e-10. Use `atol = atol` to match the surrounding pattern. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Unify the conformance harness into TestCase + run_testcase Collapse the four case structs (`ValueCase`, `HessianCase`, `ErrorCase`, `CacheReuseCase`) and seven `run_testcases(Val(:group); ...)` methods into a single tagged `TestCase` and a single `run_testcase(case; ...)` that dispatches on `case.tag` via `Val`. Each backend's test script is now a single uniform loop: ```julia for case in generate_testcases() run_testcase(case; adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6, allocations=:test, type_stability=:test) end ``` Tags are `:vector`, `:hessian`, `:context`, `:edge`, `:cache_reuse`, `:namedtuple`. NamedTuple-input cases live in `generate_namedtuple_testcases()` so backends that don't support that input shape don't need to filter at the call site. `allocations` / `type_stability` accept `:skip` / `:test` / `:broken` (`:broken` wraps as `@test_broken` for known regressions). Per-case `allocations_safe::Bool` defaults to `true`; cases with allocating primals (`VectorValuedProblem` result vector, empty-input shortcuts, hessian scratch, cache-reuse loops) opt out so the runner skips the alloc check regardless of caller intent. Case types and stubs (`TestCase`, `generate_testcases`, `generate_namedtuple_testcases`, `run_testcase`) live in `src/AbstractPPL.jl`; the generators and dispatched runner live in `ext/AbstractPPLTestExt.jl`. The old `Val{group}` API and the standalone `IdentityProblem` fixture are gone. Backend-specific broken predicates (`_mooncake_alloc`, `_mooncake_inferred`) sit next to the loop they drive — they encode Mooncake's known issues (allocating Jacobian, forward-mode Jacobian and context inference) without touching the shared harness. Local: FD 111/111, Mooncake 149 pass + 3 broken, DI 115/115. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Move TestCase to the Test extension; one Val-dispatched generator `src/AbstractPPL.jl` keeps only the two function stubs `generate_testcases` and `run_testcase`. The `TestCase` struct (and its keyword-arg constructor) moves into `ext/AbstractPPLTestExt.jl` — test scripts only access `case.tag` via field access, so the type itself doesn't need to live in main. Collapse the two separate generators into a single Val-dispatched function: generate_testcases(Val(:vector)) — all vector-input cases generate_testcases(Val(:namedtuple)) — NamedTuple-input cases Backends iterate `generate_testcases(Val(:vector))` (and Mooncake also `Val(:namedtuple)`). Local: FD 111/111, Mooncake 149 pass + 3 broken, DI 115/115 — no behavioural change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Merge :vector and :context runners The two `_run(Val{...})` methods differed only by passing `context=case.context` to `prepare` (no-op for `:vector` cases since `case.context` defaults to `()`). Collapse into one method with `Union{Val{:vector},Val{:context}}` dispatch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Document the remaining allocations_safe=false reasons The empty-input hessian case and both cache-reuse cases set `allocations_safe=false` without an inline reason, while the other four instances do. Add brief explanations to match. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Drop Julia 1.10 broken-marker on scalar-gradient allocations CI reported `Unexpected Pass` on Julia 1.10 for `quadratic (scalar output)` (FD + Mooncake) and `scalar gradient with context` (FD): the recent FD-ext tweaks (skip-checktag, hoisted target/tag locals) made these paths alloc-free on 1.10 too. Drop the `VERSION < v"1.11"` gating; the per-case `allocations_safe=false` still filters the genuinely-allocating paths (vector jacobian, empty-input shortcuts, hessian, cache-reuse loops). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Skip Mooncake :alloc checks on Julia 1.10 — they're resolver-flaky CI reported back-to-back inconsistent results on Julia 1.10 for the same code: one run had Mooncake's scalar-gradient `@allocated` come out 0 (an Unexpected Pass when marked `:broken`), the next run had it at 256 (a Test Failed when marked `:test`). The dependency resolver picks slightly different Mooncake versions between runs, and Mooncake 0.5.x's allocation behaviour on 1.10 isn't stable across them. Set `_mooncake_alloc` to return `:skip` on Julia 1.10 instead of either `:test` or `:broken` — that way the check doesn't fire on min, regardless of which Mooncake version the resolver picked. Latest-Julia coverage is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: yebai <3279477+yebai@users.noreply.github.com> Co-authored-by: Hong Ge <hg344@cam.ac.uk> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2d35e51 commit 9ebf18f

11 files changed

Lines changed: 763 additions & 270 deletions

File tree

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ jobs:
6161
matrix:
6262
label:
6363
- ext/differentiationinterface
64+
- ext/forwarddiff
6465
- ext/mooncake
6566
version:
6667
- '1'

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.15.2
2+
3+
Added `AbstractPPLForwardDiffExt`, a direct ForwardDiff path for `AutoForwardDiff` (gradient, Jacobian, Hessian, `context`, chunk size, custom `tag`).
4+
15
## 0.15.1
26

37
Added Hessian support to the AD interface. Pass `order=2` to `prepare(adtype, problem, x)` to build a Hessian-capable evaluator. The new `value_gradient_and_hessian!!(prepared, x)` then returns `(value, gradient, hessian)` in a single call. Both the DifferentiationInterface and Mooncake extensions implement this.

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probabilistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.15.1"
6+
version = "0.15.2"
77

88
[deps]
99
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -19,14 +19,17 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2020

2121
[weakdeps]
22+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
2223
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
2324
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
25+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2426
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2527
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2628

2729
[extensions]
2830
AbstractPPLDifferentiationInterfaceExt = ["DifferentiationInterface"]
2931
AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"]
32+
AbstractPPLForwardDiffExt = ["ForwardDiff", "DiffResults"]
3033
AbstractPPLMooncakeExt = ["Mooncake"]
3134
AbstractPPLTestExt = ["Test"]
3235

@@ -36,8 +39,10 @@ AbstractMCMC = "2, 3, 4, 5"
3639
Accessors = "0.1"
3740
BangBang = "0.4"
3841
DensityInterface = "0.4"
42+
DiffResults = "1"
3943
DifferentiationInterface = "0.6, 0.7"
4044
Distributions = "0.25"
45+
ForwardDiff = "0.10, 1"
4146
JSON = "0.19 - 0.21, 1"
4247
LinearAlgebra = "<0.0.1, 1"
4348
MacroTools = "0.5"

ext/AbstractPPLForwardDiffExt.jl

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
module AbstractPPLForwardDiffExt
2+
3+
using AbstractPPL: AbstractPPL
4+
using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_arity
5+
using ADTypes: AutoForwardDiff
6+
using ForwardDiff: ForwardDiff
7+
using DiffResults: DiffResults
8+
9+
# `AutoForwardDiff{CS}` carries the chunk size as a type parameter; `nothing`
10+
# defers the choice to ForwardDiff.
11+
_fd_chunk(::AutoForwardDiff{nothing}, x) = ForwardDiff.Chunk(x)
12+
_fd_chunk(::AutoForwardDiff{CS}, _) where {CS} = ForwardDiff.Chunk{CS}()
13+
14+
# A user-supplied `adtype.tag` (for nested differentiation) is threaded into the
15+
# `*Config` constructors; `nothing` (the ADTypes default) reproduces
16+
# ForwardDiff's per-constructor default of `Tag(target, eltype(x))`.
17+
@inline _fd_tag(adtype::AutoForwardDiff, target, x) =
18+
adtype.tag === nothing ? ForwardDiff.Tag(target, eltype(x)) : adtype.tag
19+
20+
# `A::Symbol` ∈ `(:scalar, :vector, :hessian)` encodes both output arity
21+
# (order=1) and order (order=2 ≡ `:hessian`), so dispatch resolves the hot path
22+
# and the arity-mismatch failure modes at compile time without a runtime branch.
23+
# `gradient_result` / `gradient_config` are populated only on `:hessian` caches
24+
# so `value_and_gradient!!` on an order=2 prep skips the O(n²) Hessian work.
25+
# `result::Nothing` is the empty-input sentinel: hot paths dispatch on
26+
# `FDCache{A,Nothing}` to short-circuit before any ForwardDiff call (chunk
27+
# selection `BoundsError`s on length-zero inputs). The stored `result` aliases
28+
# the arrays returned by `value_and_*!!`, per the `!!` contract.
29+
struct FDCache{A,R,C,GR,GC}
30+
result::R
31+
config::C
32+
gradient_result::GR
33+
gradient_config::GC
34+
function FDCache{A}(
35+
result::R, config::C, gradient_result::GR=nothing, gradient_config::GC=nothing
36+
) where {A,R,C,GR,GC}
37+
return new{A,R,C,GR,GC}(result, config, gradient_result, gradient_config)
38+
end
39+
end
40+
41+
"""
42+
prepare(adtype::AutoForwardDiff, problem, x; check_dims=true, context::Tuple=(), order=1)
43+
44+
Prepare a ForwardDiff gradient, Jacobian, or Hessian evaluator for a vector
45+
input. `order=1` (default) picks gradient/Jacobian by output arity; `order=2`
46+
builds Hessian machinery and requires a scalar-valued problem. `context` and
47+
`check_dims` follow the base `prepare` contract.
48+
"""
49+
function AbstractPPL.prepare(
50+
adtype::AutoForwardDiff,
51+
problem,
52+
x::AbstractVector{<:Real};
53+
check_dims::Bool=true,
54+
context::Tuple=(),
55+
order::Int=1,
56+
)
57+
Evaluators._validate_ad_order(order)
58+
evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator
59+
# Probe the output once: the value classifies arity, and the vector branch
60+
# reuses it as the Jacobian-result prototype. The base `prepare` contract
61+
# promises one prep-time call into `problem`.
62+
y_probe = evaluator(x)
63+
arity = _ad_output_arity(y_probe)
64+
chunk = _fd_chunk(adtype, x)
65+
target = Base.Fix2(_fd_call, evaluator)
66+
tag = _fd_tag(adtype, target, x)
67+
68+
if order == 2
69+
arity === :scalar || Evaluators._throw_hessian_needs_scalar()
70+
length(x) == 0 &&
71+
return Prepared(adtype, evaluator, FDCache{:hessian}(nothing, nothing), Val(2))
72+
hess_result = DiffResults.MutableDiffResult(
73+
zero(eltype(x)), (similar(x), similar(x, length(x), length(x)))
74+
)
75+
hess_config = ForwardDiff.HessianConfig(target, hess_result, x, chunk, tag)
76+
grad_result = DiffResults.MutableDiffResult(zero(eltype(x)), (similar(x),))
77+
grad_config = ForwardDiff.GradientConfig(target, x, chunk, tag)
78+
cache = FDCache{:hessian}(hess_result, hess_config, grad_result, grad_config)
79+
return Prepared(adtype, evaluator, cache, Val(2))
80+
end
81+
82+
if arity === :scalar
83+
length(x) == 0 &&
84+
return Prepared(adtype, evaluator, FDCache{:scalar}(nothing, nothing))
85+
result = DiffResults.MutableDiffResult(zero(eltype(x)), (similar(x),))
86+
config = ForwardDiff.GradientConfig(target, x, chunk, tag)
87+
return Prepared(adtype, evaluator, FDCache{:scalar}(result, config))
88+
else
89+
length(x) == 0 &&
90+
return Prepared(adtype, evaluator, FDCache{:vector}(nothing, nothing))
91+
result = DiffResults.MutableDiffResult(
92+
similar(y_probe), (similar(y_probe, length(y_probe), length(x)),)
93+
)
94+
config = ForwardDiff.JacobianConfig(target, x, chunk, tag)
95+
return Prepared(adtype, evaluator, FDCache{:vector}(result, config))
96+
end
97+
end
98+
99+
# Top-level so `typeof(_fd_call)` is stable across `prepare` and the hot paths.
100+
# ForwardDiff's `*Config` keys its `Tag` on the target type; a closure built
101+
# inside one method would have a different type from one built inside another,
102+
# desyncing the per-call `Base.Fix2(_fd_call, evaluator)` target from the
103+
# config captured at prep time.
104+
@inline _fd_call(x, e::VectorEvaluator) = e.f(x, e.context...)
105+
106+
# `Val(false)` on every hot-path call below skips `ForwardDiff.checktag`. A
107+
# user-supplied `adtype.tag` (e.g. DynamicPPL's `DynamicPPLTag` sentinel for
108+
# nested AD) has a tag-type parameter that does not equal `typeof(target)`, so
109+
# the default check would error. The tag's role is only to label the outer
110+
# Dual scope; the config we built at prep time already encodes the right tag.
111+
112+
@inline function AbstractPPL.value_and_gradient!!(
113+
p::Prepared{
114+
<:AutoForwardDiff,
115+
<:VectorEvaluator,
116+
<:Union{FDCache{:scalar,Nothing},FDCache{:hessian,Nothing}},
117+
},
118+
x::AbstractVector{T},
119+
) where {T<:Real}
120+
Evaluators._check_ad_input(p.evaluator, x)
121+
return (p.evaluator(x), T[])
122+
end
123+
124+
@inline function AbstractPPL.value_and_gradient!!(
125+
p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:scalar}},
126+
x::AbstractVector{<:Real},
127+
)
128+
Evaluators._check_ad_input(p.evaluator, x)
129+
ForwardDiff.gradient!(
130+
p.cache.result, Base.Fix2(_fd_call, p.evaluator), x, p.cache.config, Val(false)
131+
)
132+
return (DiffResults.value(p.cache.result), DiffResults.gradient(p.cache.result))
133+
end
134+
135+
# Order=2 prep also satisfies the order=1 gradient contract via the dedicated
136+
# gradient cache built at prep time — skips the O(n²) Hessian work.
137+
@inline function AbstractPPL.value_and_gradient!!(
138+
p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:hessian}},
139+
x::AbstractVector{<:Real},
140+
)
141+
Evaluators._check_ad_input(p.evaluator, x)
142+
ForwardDiff.gradient!(
143+
p.cache.gradient_result,
144+
Base.Fix2(_fd_call, p.evaluator),
145+
x,
146+
p.cache.gradient_config,
147+
Val(false),
148+
)
149+
return (
150+
DiffResults.value(p.cache.gradient_result),
151+
DiffResults.gradient(p.cache.gradient_result),
152+
)
153+
end
154+
155+
# Arity-mismatch rejections live on dedicated cache tags so dispatch resolves
156+
# the failure mode at compile time.
157+
@inline function AbstractPPL.value_and_gradient!!(
158+
::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:vector}},
159+
::AbstractVector{<:Real},
160+
)
161+
return Evaluators._throw_gradient_needs_scalar()
162+
end
163+
164+
@inline function AbstractPPL.value_and_jacobian!!(
165+
::Prepared{
166+
<:AutoForwardDiff,<:VectorEvaluator,<:Union{FDCache{:scalar},FDCache{:hessian}}
167+
},
168+
::AbstractVector{<:Real},
169+
)
170+
return Evaluators._throw_jacobian_needs_vector()
171+
end
172+
173+
@inline function AbstractPPL.value_and_jacobian!!(
174+
p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:vector,Nothing}},
175+
x::AbstractVector{<:Real},
176+
)
177+
Evaluators._check_ad_input(p.evaluator, x)
178+
val = p.evaluator(x)
179+
return (val, similar(x, length(val), 0))
180+
end
181+
182+
@inline function AbstractPPL.value_and_jacobian!!(
183+
p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:vector}},
184+
x::AbstractVector{<:Real},
185+
)
186+
Evaluators._check_ad_input(p.evaluator, x)
187+
ForwardDiff.jacobian!(
188+
p.cache.result, Base.Fix2(_fd_call, p.evaluator), x, p.cache.config, Val(false)
189+
)
190+
return (DiffResults.value(p.cache.result), DiffResults.jacobian(p.cache.result))
191+
end
192+
193+
@inline function AbstractPPL.value_gradient_and_hessian!!(
194+
::Prepared{
195+
<:AutoForwardDiff,<:VectorEvaluator,<:Union{FDCache{:scalar},FDCache{:vector}}
196+
},
197+
::AbstractVector{<:Real},
198+
)
199+
return Evaluators._throw_hessian_needs_order_2_prep()
200+
end
201+
202+
@inline function AbstractPPL.value_gradient_and_hessian!!(
203+
p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:hessian,Nothing}},
204+
x::AbstractVector{T},
205+
) where {T<:Real}
206+
Evaluators._check_ad_input(p.evaluator, x)
207+
return (p.evaluator(x), T[], similar(x, 0, 0))
208+
end
209+
210+
@inline function AbstractPPL.value_gradient_and_hessian!!(
211+
p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:hessian}},
212+
x::AbstractVector{<:Real},
213+
)
214+
Evaluators._check_ad_input(p.evaluator, x)
215+
ForwardDiff.hessian!(
216+
p.cache.result, Base.Fix2(_fd_call, p.evaluator), x, p.cache.config, Val(false)
217+
)
218+
return (
219+
DiffResults.value(p.cache.result),
220+
DiffResults.gradient(p.cache.result),
221+
DiffResults.hessian(p.cache.result),
222+
)
223+
end
224+
225+
end # module

0 commit comments

Comments
 (0)