|
| 1 | +module AbstractPPLForwardDiffExt |
| 2 | + |
| 3 | +using AbstractPPL: AbstractPPL |
| 4 | +using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_arity |
| 5 | +using ADTypes: AutoForwardDiff |
| 6 | +using ForwardDiff: ForwardDiff |
| 7 | +using DiffResults: DiffResults |
| 8 | + |
| 9 | +# `AutoForwardDiff{CS}` carries the chunk size as a type parameter; `nothing` |
| 10 | +# defers the choice to ForwardDiff. |
# Resolve the ForwardDiff chunk for `x`. When the AD type carries no chunk size
# (`nothing`), ForwardDiff picks one from `x`; otherwise the size stored in the
# type parameter is used and the input is ignored.
function _fd_chunk(::AutoForwardDiff{nothing}, x)
    return ForwardDiff.Chunk(x)
end
function _fd_chunk(::AutoForwardDiff{N}, _) where {N}
    return ForwardDiff.Chunk{N}()
end
| 13 | + |
| 14 | +# A user-supplied `adtype.tag` (for nested differentiation) is threaded into the |
| 15 | +# `*Config` constructors; `nothing` (the ADTypes default) reproduces |
| 16 | +# ForwardDiff's per-constructor default of `Tag(target, eltype(x))`. |
# Select the tag passed to the `*Config` constructors: the user's `adtype.tag`
# when one is set, otherwise `ForwardDiff.Tag(target, eltype(x))`.
@inline function _fd_tag(adtype::AutoForwardDiff, target, x)
    user_tag = adtype.tag
    return user_tag === nothing ? ForwardDiff.Tag(target, eltype(x)) : user_tag
end
| 19 | + |
| 20 | +# `A::Symbol` ∈ `(:scalar, :vector, :hessian)` encodes both output arity |
| 21 | +# (order=1) and order (order=2 ≡ `:hessian`), so dispatch resolves the hot path |
| 22 | +# and the arity-mismatch failure modes at compile time without a runtime branch. |
| 23 | +# `gradient_result` / `gradient_config` are populated only on `:hessian` caches |
| 24 | +# so `value_and_gradient!!` on an order=2 prep skips the O(n²) Hessian work. |
| 25 | +# `result::Nothing` is the empty-input sentinel: hot paths dispatch on |
| 26 | +# `FDCache{A,Nothing}` to short-circuit before any ForwardDiff call (chunk |
| 27 | +# selection `BoundsError`s on length-zero inputs). The stored `result` aliases |
| 28 | +# the arrays returned by `value_and_*!!`, per the `!!` contract. |
struct FDCache{A,R,C,GR,GC}
    # Primary result buffer and its matching config (`nothing` for both on the
    # empty-input caches built by `prepare`).
    result::R
    config::C
    # Gradient-only buffer/config; default to `nothing` when not supplied.
    gradient_result::GR
    gradient_config::GC

    # `A` is given explicitly (`FDCache{:scalar}(...)`); the remaining type
    # parameters are taken from the argument types.
    FDCache{A}(
        result::R, config::C, gradient_result::GR=nothing, gradient_config::GC=nothing
    ) where {A,R,C,GR,GC} =
        new{A,R,C,GR,GC}(result, config, gradient_result, gradient_config)
end
| 40 | + |
"""
    prepare(adtype::AutoForwardDiff, problem, x; check_dims=true, context::Tuple=(), order=1)

Prepare a ForwardDiff gradient, Jacobian, or Hessian evaluator for a vector
input. `order=1` (default) picks gradient/Jacobian by output arity; `order=2`
builds Hessian machinery and requires a scalar-valued problem. `context` and
`check_dims` follow the base `prepare` contract.

`problem` is evaluated once at `x` to classify its output. When `x` is empty,
the returned `Prepared` carries an `FDCache` whose `result` and `config` are
`nothing`, and no ForwardDiff chunk, tag, or config is constructed.
"""
function AbstractPPL.prepare(
    adtype::AutoForwardDiff,
    problem,
    x::AbstractVector{<:Real};
    check_dims::Bool=true,
    context::Tuple=(),
    order::Int=1,
)
    Evaluators._validate_ad_order(order)
    evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator
    # Probe the output once: the value classifies arity, and the vector branch
    # reuses it as the Jacobian-result prototype. The base `prepare` contract
    # promises one prep-time call into `problem`.
    y_probe = evaluator(x)
    arity = _ad_output_arity(y_probe)
    # Hessian prep is only defined for scalar-valued problems; reject before
    # looking at the input length so empty and non-empty inputs agree.
    order == 2 && arity !== :scalar && Evaluators._throw_hessian_needs_scalar()

    # Empty input: hand back the `Nothing`-result sentinel cache before touching
    # ForwardDiff at all, so chunk and tag construction are skipped entirely.
    if isempty(x)
        order == 2 &&
            return Prepared(adtype, evaluator, FDCache{:hessian}(nothing, nothing), Val(2))
        arity === :scalar &&
            return Prepared(adtype, evaluator, FDCache{:scalar}(nothing, nothing))
        return Prepared(adtype, evaluator, FDCache{:vector}(nothing, nothing))
    end

    chunk = _fd_chunk(adtype, x)
    target = Base.Fix2(_fd_call, evaluator)
    tag = _fd_tag(adtype, target, x)

    if order == 2
        # Hessian cache plus a dedicated gradient buffer/config, so a gradient
        # request on this prep does not have to run the Hessian.
        hess_result = DiffResults.MutableDiffResult(
            zero(eltype(x)), (similar(x), similar(x, length(x), length(x)))
        )
        hess_config = ForwardDiff.HessianConfig(target, hess_result, x, chunk, tag)
        grad_result = DiffResults.MutableDiffResult(zero(eltype(x)), (similar(x),))
        grad_config = ForwardDiff.GradientConfig(target, x, chunk, tag)
        cache = FDCache{:hessian}(hess_result, hess_config, grad_result, grad_config)
        return Prepared(adtype, evaluator, cache, Val(2))
    elseif arity === :scalar
        result = DiffResults.MutableDiffResult(zero(eltype(x)), (similar(x),))
        config = ForwardDiff.GradientConfig(target, x, chunk, tag)
        return Prepared(adtype, evaluator, FDCache{:scalar}(result, config))
    else
        # Value and Jacobian buffers are shaped from the probed output.
        result = DiffResults.MutableDiffResult(
            similar(y_probe), (similar(y_probe, length(y_probe), length(x)),)
        )
        config = ForwardDiff.JacobianConfig(target, x, chunk, tag)
        return Prepared(adtype, evaluator, FDCache{:vector}(result, config))
    end
end
| 98 | + |
| 99 | +# Top-level so `typeof(_fd_call)` is stable across `prepare` and the hot paths. |
| 100 | +# ForwardDiff's `*Config` keys its `Tag` on the target type; a closure built |
| 101 | +# inside one method would have a different type from one built inside another, |
| 102 | +# desyncing the per-call `Base.Fix2(_fd_call, evaluator)` target from the |
| 103 | +# config captured at prep time. |
# Differentiation target: call the evaluator's wrapped function on `x`, followed
# by the evaluator's stored context arguments.
@inline function _fd_call(x, e::VectorEvaluator)
    return e.f(x, e.context...)
end
| 105 | + |
| 106 | +# `Val(false)` on every hot-path call below skips `ForwardDiff.checktag`. A |
| 107 | +# user-supplied `adtype.tag` (e.g. DynamicPPL's `DynamicPPLTag` sentinel for |
| 108 | +# nested AD) has a tag-type parameter that does not equal `typeof(target)`, so |
| 109 | +# the default check would error. The tag's role is only to label the outer |
| 110 | +# Dual scope; the config we built at prep time already encodes the right tag. |
| 111 | + |
# Empty-input gradient path (`result::Nothing` caches): evaluate the problem
# directly and pair it with a zero-length gradient; ForwardDiff is never called.
@inline function AbstractPPL.value_and_gradient!!(
    p::Prepared{
        <:AutoForwardDiff,
        <:VectorEvaluator,
        <:Union{FDCache{:scalar,Nothing},FDCache{:hessian,Nothing}},
    },
    x::AbstractVector{T},
) where {T<:Real}
    evaluator = p.evaluator
    Evaluators._check_ad_input(evaluator, x)
    value = evaluator(x)
    empty_gradient = T[]
    return (value, empty_gradient)
end
| 123 | + |
# Gradient for a scalar-valued order-1 prep. The returned value and gradient are
# read straight from the cached result buffer, which ForwardDiff fills in place.
@inline function AbstractPPL.value_and_gradient!!(
    p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:scalar}},
    x::AbstractVector{<:Real},
)
    Evaluators._check_ad_input(p.evaluator, x)
    cache = p.cache
    buffer = cache.result
    target = Base.Fix2(_fd_call, p.evaluator)
    # `Val(false)` skips ForwardDiff's tag check.
    ForwardDiff.gradient!(buffer, target, x, cache.config, Val(false))
    return (DiffResults.value(buffer), DiffResults.gradient(buffer))
end
| 134 | + |
| 135 | +# Order=2 prep also satisfies the order=1 gradient contract via the dedicated |
| 136 | +# gradient cache built at prep time — skips the O(n²) Hessian work. |
# Gradient on an order-2 prep: uses the cache's `gradient_result` /
# `gradient_config` pair rather than the Hessian buffers.
@inline function AbstractPPL.value_and_gradient!!(
    p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:hessian}},
    x::AbstractVector{<:Real},
)
    Evaluators._check_ad_input(p.evaluator, x)
    cache = p.cache
    buffer = cache.gradient_result
    target = Base.Fix2(_fd_call, p.evaluator)
    # `Val(false)` skips ForwardDiff's tag check.
    ForwardDiff.gradient!(buffer, target, x, cache.gradient_config, Val(false))
    return (DiffResults.value(buffer), DiffResults.gradient(buffer))
end
| 154 | + |
| 155 | +# Arity-mismatch rejections live on dedicated cache tags so dispatch resolves |
| 156 | +# the failure mode at compile time. |
# A `:vector` cache cannot serve a gradient request; delegate to the shared
# error helper.
@inline AbstractPPL.value_and_gradient!!(
    ::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:vector}},
    ::AbstractVector{<:Real},
) = Evaluators._throw_gradient_needs_scalar()
| 163 | + |
# `:scalar` and `:hessian` caches cannot serve a Jacobian request; delegate to
# the shared error helper.
@inline AbstractPPL.value_and_jacobian!!(
    ::Prepared{
        <:AutoForwardDiff,<:VectorEvaluator,<:Union{FDCache{:scalar},FDCache{:hessian}}
    },
    ::AbstractVector{<:Real},
) = Evaluators._throw_jacobian_needs_vector()
| 172 | + |
# Empty-input Jacobian path (`result::Nothing` cache): evaluate the problem
# directly and return a `length(value) × 0` matrix allocated via `similar(x, …)`.
@inline function AbstractPPL.value_and_jacobian!!(
    p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:vector,Nothing}},
    x::AbstractVector{<:Real},
)
    Evaluators._check_ad_input(p.evaluator, x)
    value = p.evaluator(x)
    empty_jacobian = similar(x, length(value), 0)
    return (value, empty_jacobian)
end
| 181 | + |
# Jacobian for a vector-valued prep. The returned value and Jacobian are read
# straight from the cached result buffer, which ForwardDiff fills in place.
@inline function AbstractPPL.value_and_jacobian!!(
    p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:vector}},
    x::AbstractVector{<:Real},
)
    Evaluators._check_ad_input(p.evaluator, x)
    cache = p.cache
    buffer = cache.result
    target = Base.Fix2(_fd_call, p.evaluator)
    # `Val(false)` skips ForwardDiff's tag check.
    ForwardDiff.jacobian!(buffer, target, x, cache.config, Val(false))
    return (DiffResults.value(buffer), DiffResults.jacobian(buffer))
end
| 192 | + |
# Order-1 caches (`:scalar`, `:vector`) cannot serve a Hessian request; delegate
# to the shared error helper.
@inline AbstractPPL.value_gradient_and_hessian!!(
    ::Prepared{
        <:AutoForwardDiff,<:VectorEvaluator,<:Union{FDCache{:scalar},FDCache{:vector}}
    },
    ::AbstractVector{<:Real},
) = Evaluators._throw_hessian_needs_order_2_prep()
| 201 | + |
# Empty-input Hessian path (`result::Nothing` cache): evaluate the problem
# directly and return a zero-length gradient and a 0×0 Hessian.
@inline function AbstractPPL.value_gradient_and_hessian!!(
    p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:hessian,Nothing}},
    x::AbstractVector{T},
) where {T<:Real}
    evaluator = p.evaluator
    Evaluators._check_ad_input(evaluator, x)
    value = evaluator(x)
    empty_gradient = T[]
    empty_hessian = similar(x, 0, 0)
    return (value, empty_gradient, empty_hessian)
end
| 209 | + |
# Value, gradient, and Hessian for an order-2 prep. All three outputs are read
# from the cached result buffer, which ForwardDiff fills in place.
@inline function AbstractPPL.value_gradient_and_hessian!!(
    p::Prepared{<:AutoForwardDiff,<:VectorEvaluator,<:FDCache{:hessian}},
    x::AbstractVector{<:Real},
)
    Evaluators._check_ad_input(p.evaluator, x)
    cache = p.cache
    buffer = cache.result
    target = Base.Fix2(_fd_call, p.evaluator)
    # `Val(false)` skips ForwardDiff's tag check.
    ForwardDiff.hessian!(buffer, target, x, cache.config, Val(false))
    value = DiffResults.value(buffer)
    gradient = DiffResults.gradient(buffer)
    hessian = DiffResults.hessian(buffer)
    return (value, gradient, hessian)
end
| 224 | + |
| 225 | +end # module |
0 commit comments