Skip to content

Commit e58de0a

Browse files
yebaiclaude
andcommitted
Refine AD extension wrappers and buffer reuse.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4932058 commit e58de0a

6 files changed

Lines changed: 17 additions & 26 deletions

ext/AbstractPPLDifferentiationInterfaceExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ end
1313
AbstractPPL.capabilities(::Type{<:DIPrepared}) = DerivativeOrder{1}()
1414
AbstractPPL.dimension(p::DIPrepared) = AbstractPPL.dimension(p.evaluator)
1515

16-
function (p::DIPrepared)(x)
17-
return p.evaluator(x)
18-
end
16+
(p::DIPrepared)(x) = p.evaluator(x)
1917

2018
function AbstractPPL.prepare(
2119
adtype::ADTypes.AbstractADType, problem, x::AbstractVector{<:AbstractFloat}

ext/AbstractPPLEnzymeExt.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,40 @@ using AbstractPPL: AbstractPPL, DerivativeOrder
44
using ADTypes: AutoEnzyme
55
using Enzyme: Enzyme
66

7-
struct EnzymePrepared{E}
7+
struct EnzymePrepared{E,G}
88
evaluator::E
9+
gradient::G
910
end
1011

1112
AbstractPPL.capabilities(::Type{<:EnzymePrepared}) = DerivativeOrder{1}()
1213
AbstractPPL.dimension(p::EnzymePrepared) = AbstractPPL.dimension(p.evaluator)
1314

14-
function (p::EnzymePrepared)(x)
15-
return p.evaluator(x)
16-
end
15+
(p::EnzymePrepared)(x) = p.evaluator(x)
1716

1817
function AbstractPPL.prepare(::AutoEnzyme, problem, x::AbstractVector{<:AbstractFloat})
1918
evaluator = AbstractPPL.ADProblems.VectorEvaluator(
2019
AbstractPPL.prepare(problem, x), length(x)
2120
)
22-
return EnzymePrepared(evaluator)
21+
return EnzymePrepared(evaluator, similar(x))
2322
end
2423

2524
@inline function AbstractPPL.value_and_gradient(
2625
p::EnzymePrepared, x::AbstractVector{<:AbstractFloat}
2726
)
28-
dx = zero(x)
27+
dx = p.gradient
28+
length(dx) == length(x) || throw(
29+
DimensionMismatch(
30+
"Expected a vector of length $(length(dx)), but got length $(length(x))."
31+
),
32+
)
33+
fill!(dx, zero(eltype(dx)))
2934
result = Enzyme.autodiff(
3035
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal),
3136
Enzyme.Const(p.evaluator),
3237
Enzyme.Active,
3338
Enzyme.Duplicated(x, dx),
3439
)
35-
val = result[2] # The primal value is returned in the second tuple entry.
36-
return (val, dx)
40+
return (result[2], copy(dx))
3741
end
3842

3943
end # module

ext/AbstractPPLFiniteDifferencesExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ function (p::FDPrepared{<:AbstractPPL.ADProblems.NamedTupleEvaluator})(values::N
2323
return p.evaluator(values)
2424
end
2525

26-
function (p::FDPrepared)(x)
27-
return p.evaluator(x)
28-
end
26+
(p::FDPrepared)(x) = p.evaluator(x)
2927

3028
function AbstractPPL.prepare(adtype::AutoFiniteDifferences, problem, values::NamedTuple)
3129
evaluator = AbstractPPL.ADProblems.NamedTupleEvaluator(

ext/AbstractPPLForwardDiffExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ function (p::ForwardDiffPrepared{<:AbstractPPL.ADProblems.NamedTupleEvaluator})(
3030
return p.evaluator(values)
3131
end
3232

33-
function (p::ForwardDiffPrepared)(x)
34-
return p.evaluator(x)
35-
end
33+
(p::ForwardDiffPrepared)(x) = p.evaluator(x)
3634

3735
function AbstractPPL.prepare(::AutoForwardDiff, problem, values::NamedTuple)
3836
evaluator = AbstractPPL.ADProblems.NamedTupleEvaluator(

ext/AbstractPPLMooncakeExt.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ end
1212
AbstractPPL.capabilities(::Type{<:MooncakePrepared}) = DerivativeOrder{1}()
1313
AbstractPPL.dimension(p::MooncakePrepared) = AbstractPPL.dimension(p.evaluator)
1414

15-
function (p::MooncakePrepared)(x)
16-
return p.evaluator(x)
17-
end
15+
(p::MooncakePrepared)(x) = p.evaluator(x)
1816

1917
function _prepare(adtype, problem, values::NamedTuple)
2018
evaluator = AbstractPPL.ADProblems.NamedTupleEvaluator(
@@ -68,11 +66,6 @@ end
6866
p::MooncakePrepared{<:AbstractPPL.ADProblems.VectorEvaluator},
6967
x::AbstractVector{<:AbstractFloat},
7068
)
71-
AbstractPPL.dimension(p.evaluator) == length(x) || throw(
72-
DimensionMismatch(
73-
"Expected a vector of length $(AbstractPPL.dimension(p.evaluator)), but got length $(length(x)).",
74-
),
75-
)
7669
val, (_, grad) = Mooncake.value_and_gradient!!(p.cache, p.evaluator, x)
7770
return (val, grad)
7871
end

test/ext/mooncake/mooncake.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ const config_forward = ADTypes.AutoMooncakeForward(; config=Mooncake.Config())
114114
sprint(showerror, err),
115115
)
116116
@test_throws MethodError prepared([3, 1, 2])
117-
@test_throws DimensionMismatch AbstractPPL.value_and_gradient(
117+
@test_throws Mooncake.PreparedCacheSpecError AbstractPPL.value_and_gradient(
118118
prepared, [3.0, 1.0, 2.0, 3.0]
119119
)
120120
end

0 commit comments

Comments
 (0)