@@ -5,79 +5,62 @@ using AbstractPPL.Utils: flatten_to!!, unflatten_to!!
55using ADTypes: AutoForwardDiff
66using ForwardDiff: ForwardDiff
77
8- struct ForwardDiffPrepared{F,C,R,P }
9- evaluator
8+ struct ForwardDiffPrepared{E, F,C,R}
9+ evaluator:: E
1010 f:: F
1111 config:: C
1212 result:: R
13- inputspec:: P
1413end
1514
1615AbstractPPL. capabilities (:: Type{<:ForwardDiffPrepared} ) = DerivativeOrder {1} ()
16+ AbstractPPL. dimension (p:: ForwardDiffPrepared ) = AbstractPPL. dimension (p. evaluator)
1717
18- function AbstractPPL. dimension (:: ForwardDiffPrepared{<:Any,<:Any,<:Any,<:NamedTuple} )
19- throw (
20- ArgumentError (
21- " `dimension` is only available for evaluators prepared with a vector of floating-point numbers." ,
22- ),
23- )
24- end
25- function AbstractPPL. dimension (p:: ForwardDiffPrepared{<:Any,<:Any,<:Any,<:AbstractVector} )
26- return length (p. inputspec)
18+ function (p:: ForwardDiffPrepared )(x:: AbstractVector{<:Integer} )
19+ throw (MethodError (p, (x,)))
2720end
2821
29- function (p:: ForwardDiffPrepared{<:Any,<:Any,<:Any,<:NamedTuple} )(values:: NamedTuple )
30- typeof (values) === typeof (p. inputspec) || throw (
22+ function (p:: ForwardDiffPrepared{<:AbstractPPL.ADProblems.NamedTupleEvaluator} )(
23+ values:: NamedTuple
24+ )
25+ typeof (values) === typeof (p. evaluator. inputspec) || throw (
3126 ArgumentError (
3227 " Expected the same NamedTuple structure that was used to prepare this evaluator." ,
3328 ),
3429 )
3530 return p. evaluator (values)
3631end
3732
38- function (p:: ForwardDiffPrepared{<:Any,<:Any,<:Any,<:AbstractVector} )(
39- x:: AbstractVector{<:Integer}
40- )
41- throw (MethodError (p, (x,)))
42- end
43-
44- function (p:: ForwardDiffPrepared{<:Any,<:Any,<:Any,<:AbstractVector} )(x:: AbstractVector )
45- length (x) == length (p. inputspec) || throw (
46- DimensionMismatch (
47- " Expected a vector of length $(length (p. inputspec)) , but got length $(length (x)) ." ,
48- ),
49- )
50- return p. evaluator (x)
51- end
52-
5333function (p:: ForwardDiffPrepared )(x)
54- throw ( MethodError (p, (x,)) )
34+ return p . evaluator (x )
5535end
5636
5737function AbstractPPL. prepare (:: AutoForwardDiff , problem, values:: NamedTuple )
58- evaluator = AbstractPPL. prepare (problem, values)
38+ evaluator = AbstractPPL. ADProblems. NamedTupleEvaluator (
39+ AbstractPPL. prepare (problem, values), values
40+ )
5941 x = flatten_to!! (nothing , values)
6042 f = let evaluator = evaluator, values = values
6143 x -> evaluator (unflatten_to!! (values, x))
6244 end
6345 result = ForwardDiff. DiffResults. MutableDiffResult (zero (eltype (x)), (similar (x),))
6446 cfg = ForwardDiff. GradientConfig (f, x)
65- return ForwardDiffPrepared (evaluator, f, cfg, result, values )
47+ return ForwardDiffPrepared (evaluator, f, cfg, result)
6648end
6749
6850function AbstractPPL. prepare (:: AutoForwardDiff , problem, x:: AbstractVector{<:AbstractFloat} )
69- evaluator = AbstractPPL. prepare (problem, x)
70- f = evaluator
71- cfg = ForwardDiff. GradientConfig (f, x)
51+ evaluator = AbstractPPL. ADProblems. VectorEvaluator (
52+ AbstractPPL. prepare (problem, x), length (x)
53+ )
54+ cfg = ForwardDiff. GradientConfig (evaluator, x)
7255 grad_buf = similar (x)
7356 result = ForwardDiff. DiffResults. MutableDiffResult (zero (eltype (x)), (grad_buf,))
74- return ForwardDiffPrepared (evaluator, f , cfg, result, x )
57+ return ForwardDiffPrepared (evaluator, evaluator , cfg, result)
7558end
7659
7760@inline function AbstractPPL. value_and_gradient (
78- p:: ForwardDiffPrepared{<:Any,<:Any,<:Any,<:NamedTuple } , values:: NamedTuple
61+ p:: ForwardDiffPrepared{<:AbstractPPL.ADProblems.NamedTupleEvaluator } , values:: NamedTuple
7962)
80- typeof (values) === typeof (p. inputspec) || throw (
63+ typeof (values) === typeof (p. evaluator . inputspec) || throw (
8164 ArgumentError (
8265 " Expected the same NamedTuple structure that was used to prepare this evaluator." ,
8366 ),
8669 ForwardDiff. gradient! (p. result, p. f, x, p. config)
8770 val = ForwardDiff. DiffResults. value (p. result)
8871 grad = copy (ForwardDiff. DiffResults. gradient (p. result))
89- return (val, unflatten_to!! (p. inputspec, grad))
72+ return (val, unflatten_to!! (p. evaluator . inputspec, grad))
9073end
9174
9275@inline function AbstractPPL. value_and_gradient (
93- p:: ForwardDiffPrepared{<:Any,<:Any,<:Any,<:AbstractVector } ,
76+ p:: ForwardDiffPrepared{<:AbstractPPL.ADProblems.VectorEvaluator } ,
9477 x:: AbstractVector{<:AbstractFloat} ,
9578)
96- length (x) == length (p. inputspec) || throw (
97- DimensionMismatch (
98- " Expected a vector of length $(length (p. inputspec)) , but got length $(length (x)) ." ,
99- ),
100- )
10179 ForwardDiff. gradient! (p. result, p. f, x, p. config)
10280 val = ForwardDiff. DiffResults. value (p. result)
10381 grad = copy (ForwardDiff. DiffResults. gradient (p. result))
0 commit comments