@@ -4,36 +4,40 @@ using AbstractPPL: AbstractPPL, DerivativeOrder
44using ADTypes: AutoEnzyme
55using Enzyme: Enzyme
66
7- struct EnzymePrepared{E}
7+ struct EnzymePrepared{E,G }
88 evaluator:: E
9+ gradient:: G
910end
1011
1112AbstractPPL. capabilities (:: Type{<:EnzymePrepared} ) = DerivativeOrder {1} ()
1213AbstractPPL. dimension (p:: EnzymePrepared ) = AbstractPPL. dimension (p. evaluator)
1314
14- function (p:: EnzymePrepared )(x)
15- return p. evaluator (x)
16- end
15+ (p:: EnzymePrepared )(x) = p. evaluator (x)
1716
1817function AbstractPPL. prepare (:: AutoEnzyme , problem, x:: AbstractVector{<:AbstractFloat} )
1918 evaluator = AbstractPPL. ADProblems. VectorEvaluator (
2019 AbstractPPL. prepare (problem, x), length (x)
2120 )
22- return EnzymePrepared (evaluator)
21+ return EnzymePrepared (evaluator, similar (x) )
2322end
2423
2524@inline function AbstractPPL. value_and_gradient (
2625 p:: EnzymePrepared , x:: AbstractVector{<:AbstractFloat}
2726)
28- dx = zero (x)
27+ dx = p. gradient
28+ length (dx) == length (x) || throw (
29+ DimensionMismatch (
30+ " Expected a vector of length $(length (dx)) , but got length $(length (x)) ."
31+ ),
32+ )
33+ fill! (dx, zero (eltype (dx)))
2934 result = Enzyme. autodiff (
3035 Enzyme. set_runtime_activity (Enzyme. ReverseWithPrimal),
3136 Enzyme. Const (p. evaluator),
3237 Enzyme. Active,
3338 Enzyme. Duplicated (x, dx),
3439 )
35- val = result[2 ] # The primal value is returned in the second tuple entry.
36- return (val, dx)
40+ return (result[2 ], copy (dx))
3741end
3842
3943end # module
0 commit comments