Skip to content

Commit 2fa0b36

Browse files
committed
fix test_grad
1 parent b7684e2 commit 2fa0b36

4 files changed

Lines changed: 23 additions & 18 deletions

File tree

ext/AbstractPPLFiniteDifferencesExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ end
1414

1515
# Advertise that FDPrepared objects support first-order derivatives
# (gradients only); presumably queried by callers before requesting
# `value_and_gradient` — NOTE(review): confirm against the evaluator interface.
AbstractPPL.capabilities(::Type{<:FDPrepared}) = DerivativeOrder{1}()
1616

17+
"""
    AbstractPPL.test_grad(f, x::AbstractVector{<:AbstractFloat})

Concrete implementation of the finite-difference reference gradient, provided
by the FiniteDifferences extension. Uses a 5-point, first-derivative central
scheme and returns the gradient vector (the first element of the tuple that
`FiniteDifferences.grad` produces).
"""
function AbstractPPL.test_grad(f, x::AbstractVector{<:AbstractFloat})
    fdm = FiniteDifferences.central_fdm(5, 1)
    return first(FiniteDifferences.grad(fdm, f, x))
end
20+
1721
function AbstractPPL.dimension(::FDPrepared{<:Any,<:Any,<:NamedTuple})
1822
throw(
1923
ArgumentError(

src/AbstractPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ export AbstractProbabilisticProgram,
88
export AbstractModelTrace
99

1010
# Evaluator interface
11-
export DerivativeOrder, capabilities, prepare, value_and_gradient, dimension
11+
export DerivativeOrder, capabilities, prepare, value_and_gradient, test_grad, dimension
1212

1313
include("abstractmodeltrace.jl")
1414
include("abstractprobprog.jl")

src/evaluator.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@ function value_and_gradient(prepared, x::AbstractVector{<:AbstractFloat})
6262
)
6363
end
6464

65+
"""
    test_grad(f, x::AbstractVector{<:AbstractFloat})

Return a finite-difference reference gradient of the scalar-valued callable
`f` at the vector input `x`.

The concrete implementation lives in the FiniteDifferences package extension;
without it loaded, this emits a warning and returns `nothing`.
"""
function test_grad end

# Generic fallback: no extension loaded, so no reference gradient is available.
function test_grad(f, x)
    @warn "Finite-difference reference gradients require `using FiniteDifferences`; skipping test_grad."
    nothing
end
79+
6580
"""
6681
dimension(prepared)::Int
6782

test/test_utils.jl

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,15 @@ Include this file inside `@testset` blocks in `test/ext/*/` tests after loading
55
AbstractPPL and Test.
66
"""
77

8-
# Central finite-difference gradient of scalar `f` at `x`, one coordinate at a
# time, with step size cbrt(eps) — the standard choice balancing truncation
# and rounding error for a second-order central difference.
function _fd_gradient(f, x::AbstractVector)
    T = float(eltype(x))
    step = cbrt(eps(T))
    return T[
        begin
            hi = copy(x)
            hi[i] += step
            lo = copy(x)
            lo[i] -= step
            (f(hi) - f(lo)) / (2step)
        end for i in eachindex(x)
    ]
end
21-
228
"""
    test_autograd(prepared, x::AbstractVector; atol=1e-5, rtol=1e-5)

Compare `value_and_gradient(prepared, x)` against a finite-difference
reference. Calls `@test` internally; use inside a `@testset` block.
"""
function test_autograd(prepared, x::AbstractVector; atol=1e-5, rtol=1e-5)
    val_ad, grad_ad = AbstractPPL.value_and_gradient(prepared, x)
    # `test_grad` returns `nothing` when the FiniteDifferences extension is
    # not loaded; in that case only the value check runs.
    grad_fd = AbstractPPL.test_grad(prepared, x)
    # The AD value must match a direct evaluation of the prepared callable.
    @test val_ad ≈ prepared(x)
    return isnothing(grad_fd) || @test grad_ad ≈ grad_fd atol = atol rtol = rtol
end

0 commit comments

Comments
 (0)