Skip to content

Commit 9ccf137

Browse files
committed
Use Mooncake value_and_gradient!! helpers
1 parent 5e3d21e commit 9ccf137

3 files changed

Lines changed: 6 additions & 24 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ IrrationalConstants = "0.1, 0.2"
6363
LazyArrays = "2"
6464
LogExpFunctions = "0.3.3"
6565
MappedArrays = "0.2.2, 0.3, 0.4"
66-
Mooncake = "0.4.95, 0.5"
66+
Mooncake = "0.5.26"
6767
Reexport = "0.2, 1"
6868
ReverseDiff = "1"
6969
Roots = "1.3.15, 2"

ext/BijectorsMooncakeExt.jl

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,28 +68,10 @@ end
6868

6969
## Forward-mode implementations (column-by-column JVPs)
7070

71-
function _value_and_gradient(
72-
f, backend::AutoMooncakeForward, x::AbstractVector{T}
73-
) where {T}
74-
val = f(x)
75-
n = length(x)
76-
cache = prepare_derivative_cache(f, x; config=_mooncake_config(backend))
77-
df = _mooncake_zero_tangent_or_primal(f, backend)
78-
if n == 0
79-
return val, Vector{eltype(x)}(undef, 0)
80-
end
81-
dx = zeros(T, n)
82-
dx[1] = one(T)
83-
_, first_jvp = value_and_derivative!!(cache, (f, df), (x, dx))
84-
grad = Vector{typeof(first_jvp)}(undef, n)
85-
grad[1] = first_jvp
86-
for j in 2:n
87-
fill!(dx, zero(T))
88-
dx[j] = one(T)
89-
_, jvp = value_and_derivative!!(cache, (f, df), (x, dx))
90-
grad[j] = jvp
91-
end
92-
return val, grad
71+
function _value_and_gradient(f, backend::AutoMooncakeForward, x::AbstractVector)
72+
cache = prepare_gradient_cache(f, x; config=_mooncake_config(backend))
73+
val, (_, x_grad) = value_and_gradient!!(cache, f, x)
74+
return val, x_grad
9375
end
9476

9577
function _value_and_jacobian(

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ LazyArrays = "1, 2"
5353
LogDensityProblems = "2"
5454
LogExpFunctions = "0.3.1"
5555
MCMCDiagnosticTools = "0.3"
56-
Mooncake = "0.4, 0.5"
56+
Mooncake = "0.5.26"
5757
PDMats = "0.11"
5858
ReverseDiff = "1.4.2"
5959
StableRNGs = "1"

0 commit comments

Comments
 (0)