
Commit 216232b

yebai and claude committed
Rename evaluator API to ADProblems and fold autograd testing into the core interface.
This moves the evaluator surface into a self-contained ADProblems module and replaces the old finite-difference test helper with test_autograd backed by AutoFiniteDifferences.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2fa0b36 commit 216232b

13 files changed, 204 additions and 32 deletions

AGENTS.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ AbstractPPL.jl is a Julia interface package for probabilistic programming. It is
 
 - Preserve the model invariants documented in `src/abstractprobprog.jl`: `condition`/`decondition` and `fix`/`unfix` are intended to round-trip when supported.
 - `rand(model)` and `predict(model, params)` have default RNG/type forwarding behaviour covered by tests; changes here should stay consistent with `AbstractMCMC` expectations.
-- The evaluator API in `src/evaluator.jl` is structural. `prepare(..., prototype::NamedTuple)` fixes field structure, `capabilities` defaults conservatively to `DerivativeOrder{0}()`, and AD-aware prepared objects are expected to return gradients with the same named structure as inputs.
+- The ADProblem API in `src/ADProblems.jl` is structural. `prepare(..., prototype::NamedTuple)` fixes field structure, `capabilities` defaults conservatively to `DerivativeOrder{0}()`, and AD-aware prepared objects are expected to return gradients with the same named structure as inputs.
 - `VarName` and optics are the main complexity in this repo. Preserve equality, hashing, pretty-printing, composition/decomposition, and type-stability behaviour.
 - Dynamic indices (`begin`, `end`, expressions containing them) are intentionally deferred until `concretize`; do not silently erase that distinction.
 - Unconcretized dynamic indices must not be serialised. If serialization changes, keep `varname_to_string` / `string_to_varname` round-tripping for supported index types.
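To make the "structural" contract in the renamed bullet concrete, here is a minimal sketch with a hypothetical toy backend. `ToyProblem` and `ToyPrepared` are illustrative names, not part of this commit; the sketch only shows the shape of the `prepare`/`capabilities` contract it describes.

```julia
using AbstractPPL: AbstractPPL, prepare, capabilities, DerivativeOrder

# Hypothetical problem and prepared types, for illustration only.
struct ToyProblem end

struct ToyPrepared
    prototype_keys::Tuple  # the field structure fixed at preparation time
end

# `prepare(problem, prototype::NamedTuple)` fixes the named structure once;
# later evaluations must supply the same fields.
AbstractPPL.prepare(::ToyProblem, prototype::NamedTuple) = ToyPrepared(keys(prototype))

prepared = prepare(ToyProblem(), (mu=0.0, sigma=1.0))

# With no AD-aware preparation defined, `capabilities` stays at the
# conservative default of zeroth-order derivatives.
@assert capabilities(prepared) == DerivativeOrder{0}()
@assert prepared.prototype_keys == (:mu, :sigma)
```

An AD-aware backend would additionally implement `prepare(adtype, problem, x)` and `value_and_gradient`, returning gradients with the same named structure as the inputs.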

docs/src/pplapi.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ evaluate!!
 AbstractModelTrace
 ```
 
-## Evaluator interface
+## ADProblem interface
 
 ```@docs
 DerivativeOrder

ext/AbstractPPLDifferentiationInterfaceExt.jl

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ end
 
 # This extension handles the generic `AbstractADType` vector path directly so
 # DifferentiationInterface backends can opt in without a fallback method in
-# `src/evaluator.jl` forcing a precompile-time method overwrite.
+# `src/ADProblems.jl` forcing a precompile-time method overwrite.
 function AbstractPPL.prepare(
     adtype::ADTypes.AbstractADType, problem, x::AbstractVector{<:AbstractFloat}
 )

ext/AbstractPPLFiniteDifferencesExt.jl

Lines changed: 0 additions & 4 deletions
@@ -14,10 +14,6 @@ end
 
 AbstractPPL.capabilities(::Type{<:FDPrepared}) = DerivativeOrder{1}()
 
-function AbstractPPL.test_grad(f, x::AbstractVector{<:AbstractFloat})
-    return FiniteDifferences.grad(FiniteDifferences.central_fdm(5, 1), f, x)[1]
-end
-
 function AbstractPPL.dimension(::FDPrepared{<:Any,<:Any,<:NamedTuple})
     throw(
         ArgumentError(

src/evaluator.jl renamed to src/ADProblems.jl

Lines changed: 44 additions & 7 deletions
@@ -1,5 +1,9 @@
+module ADProblems
+
 using ADTypes: ADTypes
 
+export DerivativeOrder, capabilities, prepare, value_and_gradient, test_autograd, dimension
+
 """
     DerivativeOrder{K}
 
@@ -63,17 +67,48 @@ function value_and_gradient(prepared, x::AbstractVector{<:AbstractFloat})
 end
 
 """
-    test_grad(f, x::AbstractVector{<:AbstractFloat})
+    test_autograd(prepared, x::AbstractVector; atol=1e-5, rtol=1e-5, finite_difference_kwargs...)
 
-Return a finite-difference reference gradient for a scalar-valued callable `f`
-evaluated at the vector input `x`.
+Compare `value_and_gradient(prepared, x)` against a finite-difference reference
+computed via `value_and_gradient(prepare(AutoFiniteDifferences(...), problem, x), x)`.
+Throws an informative error on mismatch. Returns `nothing`.
 
-If the FiniteDifferences extension is not loaded, this warns and returns `nothing`.
+Backends that want this helper should define `prepare_for_test_autograd(prepared, x)`
+to return a pair `(problem, prototype)` suitable for `prepare(AutoFiniteDifferences(...), ...)`.
+Additional keyword arguments are forwarded to `ADTypes.AutoFiniteDifferences`.
 """
-function test_grad end
+function test_autograd end
+
+function prepare_for_test_autograd end
+
+function prepare_for_test_autograd(prepared, x)
+    throw(
+        ArgumentError(
+            "`test_autograd` needs a finite-difference preparation path for $(typeof(prepared)). Define `prepare_for_test_autograd(prepared, x)` to return `(problem, prototype)`.",
+        ),
+    )
+end
 
-function test_grad(f, x)
-    @warn "Finite-difference reference gradients require `using FiniteDifferences`; skipping test_grad."
+function test_autograd(
+    prepared, x::AbstractVector; atol=1e-5, rtol=1e-5, finite_difference_kwargs...
+)
+    val_ad, grad_ad = value_and_gradient(prepared, x)
+    problem, prototype = prepare_for_test_autograd(prepared, x)
+    fd_prepared = prepare(
+        ADTypes.AutoFiniteDifferences(; finite_difference_kwargs...), problem, prototype
+    )
+    val_fd, grad_fd = value_and_gradient(fd_prepared, x)
+
+    isapprox(val_ad, val_fd) || throw(
+        ArgumentError(
+            "Value mismatch against finite differences: got $val_ad, expected $val_fd."
+        ),
+    )
+    isapprox(grad_ad, grad_fd; atol=atol, rtol=rtol) || throw(
+        ArgumentError(
+            "Gradient mismatch against finite differences with atol=$atol and rtol=$rtol.",
+        ),
+    )
     return nothing
 end
 
@@ -83,3 +118,5 @@ end
 Return the number of scalar entries in the vector input expected by a prepared evaluator.
 """
 function dimension end
+
+end # module
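As a sketch of how a backend opts into the new `test_autograd` flow: define `prepare_for_test_autograd` to hand back the underlying problem and a prototype, then call `test_autograd`. The `QuadraticProblem`/`MyADPrepared` names and the hand-written gradient below are illustrative only, and the final call assumes a FiniteDifferences-backed `prepare(::AutoFiniteDifferences, problem, x)` method is in scope (as in this commit's extension tests), so it is left commented.

```julia
using AbstractPPL
using ADTypes: ADTypes

# Illustrative problem whose log density is a simple sum of squares.
struct QuadraticProblem end

# Illustrative AD-aware prepared object with a hand-written gradient.
struct MyADPrepared end
(::MyADPrepared)(x::AbstractVector{<:AbstractFloat}) = sum(xi -> xi^2, x)

function AbstractPPL.value_and_gradient(::MyADPrepared, x::AbstractVector{<:AbstractFloat})
    return (sum(xi -> xi^2, x), 2 .* x)
end

# Opt into `test_autograd` by describing how to rebuild a finite-difference
# preparation: return the underlying problem and a vector prototype.
function AbstractPPL.ADProblems.prepare_for_test_autograd(::MyADPrepared, x::AbstractVector)
    return (QuadraticProblem(), x)
end

# With `using FiniteDifferences` loaded so the extension provides the
# `AutoFiniteDifferences` preparation path, this compares the hand-written
# gradient against central finite differences, throwing on mismatch:
# test_autograd(MyADPrepared(), [0.5, 1.5, 2.5]; atol=1e-5, rtol=1e-5)
```

Without the `prepare_for_test_autograd` method, the fallback in `ADProblems` throws an `ArgumentError` explaining what to define.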

src/AbstractPPL.jl

Lines changed: 5 additions & 4 deletions
@@ -7,13 +7,14 @@ export AbstractProbabilisticProgram,
 # Abstract traces
 export AbstractModelTrace
 
-# Evaluator interface
-export DerivativeOrder, capabilities, prepare, value_and_gradient, test_grad, dimension
-
+# ADProblem interface
 include("abstractmodeltrace.jl")
 include("abstractprobprog.jl")
 include("evaluate.jl")
-include("evaluator.jl")
+include("ADProblems.jl")
+using .ADProblems:
+    DerivativeOrder, capabilities, prepare, value_and_gradient, test_autograd, dimension
+export DerivativeOrder, capabilities, prepare, value_and_gradient, test_autograd, dimension
 include("utils.jl")
 
 include("varname/optic.jl")

test/ADProblems.jl

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+using AbstractPPL
+using ADTypes: ADTypes
+using Test
+
+struct DummyProblem end
+
+struct DummyPrepared
+    prototype_keys::Tuple
+end
+
+function AbstractPPL.prepare(problem::DummyProblem, values::NamedTuple)
+    return DummyPrepared(keys(values))
+end
+
+function (p::DummyPrepared)(values::NamedTuple)
+    keys(values) == p.prototype_keys ||
+        error("expected fields $(p.prototype_keys), got $(keys(values))")
+    return sum(x -> x isa AbstractArray ? sum(x) : x, values)
+end
+
+struct DummyADPrepared
+    dim::Int
+end
+
+function AbstractPPL.prepare(
+    ::ADTypes.AbstractADType, problem::DummyProblem, x::AbstractVector{<:AbstractFloat}
+)
+    return DummyADPrepared(length(x))
+end
+
+function (p::DummyADPrepared)(x::AbstractVector{<:AbstractFloat})
+    length(x) == p.dim || error("expected vector of length $(p.dim)")
+    return sum(x)
+end
+
+AbstractPPL.capabilities(::Type{DummyADPrepared}) = DerivativeOrder{1}()
+
+function AbstractPPL.value_and_gradient(
+    p::DummyADPrepared, x::AbstractVector{<:AbstractFloat}
+)
+    return (sum(x), ones(length(x)))
+end
+
+struct DummyVectorPrepared
+    dim::Int
+end
+
+AbstractPPL.dimension(p::DummyVectorPrepared) = p.dim
+
+function (p::DummyVectorPrepared)(x::AbstractVector)
+    length(x) == p.dim || error("expected vector of length $(p.dim)")
+    return sum(x)
+end
+
+@testset "ADProblem interface" begin
+    @testset "DerivativeOrder" begin
+        err = try
+            DerivativeOrder{3}()
+            nothing
+        catch err
+            err
+        end
+        @test err isa ArgumentError
+        @test occursin("must be 0, 1, or 2", sprint(showerror, err))
+        @test_throws ArgumentError DerivativeOrder{-1}()
+        @test DerivativeOrder{0}() < DerivativeOrder{1}()
+        @test DerivativeOrder{1}() >= DerivativeOrder{1}()
+        @test DerivativeOrder{1}() < DerivativeOrder{2}()
+        @test !(DerivativeOrder{2}() < DerivativeOrder{1}())
+    end
+
+    @testset "capabilities default" begin
+        @test capabilities(Int) == DerivativeOrder{0}()
+        @test capabilities(42) == DerivativeOrder{0}()
+        @test capabilities(DummyPrepared((:x,))) == DerivativeOrder{0}()
+        @test capabilities(DummyPrepared((:x,))) < DerivativeOrder{1}()
+    end
+
+    @testset "prepare (structural)" begin
+        problem = DummyProblem()
+        values = (x=0.0, y=[1.0, 2.0])
+        prepared = prepare(problem, values)
+        @test prepared isa DummyPrepared
+        @test prepared.prototype_keys == (:x, :y)
+
+        lp = prepared((x=0.5, y=[1.5, 2.5]))
+        @test lp ≈ 0.5 + 1.5 + 2.5
+
+        @test_throws Exception prepared((a=1.0, b=2.0))
+    end
+
+    @testset "prepare (AD-aware)" begin
+        problem = DummyProblem()
+        x0 = zeros(3)
+        adtype = ADTypes.AutoForwardDiff()
+        prepared = prepare(adtype, problem, x0)
+        @test prepared isa DummyADPrepared
+        @test capabilities(prepared) == DerivativeOrder{1}()
+
+        x = [0.5, 1.5, 2.5]
+        @test prepared(x) ≈ 0.5 + 1.5 + 2.5
+
+        val, grad = value_and_gradient(prepared, x)
+        @test val ≈ 0.5 + 1.5 + 2.5
+        @test grad ≈ [1.0, 1.0, 1.0]
+    end
+
+    @testset "dimension and vector adapter" begin
+        prepared = DummyVectorPrepared(3)
+        @test dimension(prepared) == 3
+        @test prepared(ones(3)) ≈ 3.0
+        @test_throws Exception prepared(ones(5))
+    end
+
+    @testset "flatten / unflatten edge cases" begin
+        empty = NamedTuple()
+        @test AbstractPPL.Utils.flatten_to!!(nothing, empty) == Float64[]
+        @test AbstractPPL.Utils.unflatten_to!!(empty, Float64[]) == empty
+
+        view_values = (x=@view([1.0, 2.0, 3.0][2:3]),)
+        flat = AbstractPPL.Utils.flatten_to!!(nothing, view_values)
+        rebuilt = AbstractPPL.Utils.unflatten_to!!(view_values, flat)
+        @test collect(rebuilt.x) == [2.0, 3.0]
+        @test axes(rebuilt.x) == axes(view_values.x)
+        @test parent(rebuilt.x) == [2.0, 3.0]
+    end
+end

test/ext/differentiation_interface/differentiation_interface.jl

Lines changed: 6 additions & 0 deletions
@@ -23,6 +23,12 @@ function (::QuadraticPrepared)(x::AbstractVector{<:AbstractFloat})
     return sum(xi -> xi^2, x)
 end
 
+function AbstractPPL.ADProblems.prepare_for_test_autograd(
+    ::AbstractPPL.DIPrepared, x::AbstractVector
+)
+    return (QuadraticProblem(), x)
+end
+
 # Use a backend without a native AbstractPPL extension so this test exercises
 # AbstractPPLDifferentiationInterfaceExt dispatch directly.
 const fdm = FiniteDifferences.central_fdm(5, 1)

test/ext/enzyme/enzyme.jl

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,12 @@ function (::QuadraticPrepared)(x::AbstractVector{<:AbstractFloat})
     return sum(xi -> xi^2, x)
 end
 
+function AbstractPPL.ADProblems.prepare_for_test_autograd(
+    ::AbstractPPL.EnzymePrepared, x::AbstractVector
+)
+    return (QuadraticProblem(), x)
+end
+
 @testset "AbstractPPLEnzymeExt" begin
     problem = QuadraticProblem()
     x0 = zeros(3)

test/ext/forward_diff/forward_diff.jl

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ function (::QuadraticVecPrepared)(x::AbstractVector{<:Real})
     return sum(xi -> xi^2, x)
 end
 
+function AbstractPPL.ADProblems.prepare_for_test_autograd(
+    ::AbstractPPL.ForwardDiffPrepared{<:Any,<:Any,<:Any,<:AbstractVector}, x::AbstractVector
+)
+    return (QuadraticProblem(), x)
+end
+
 @testset "AbstractPPLForwardDiffExt" begin
     @testset "NamedTuple path" begin
         problem = QuadraticProblem()
