Skip to content

Commit 11fae89

Browse files
committed
Use non-closure Enzyme logdensity call
1 parent de0e6b4 commit 11fae89

2 files changed

Lines changed: 11 additions & 11 deletions

File tree

ext/DynamicPPLEnzymeExt.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module DynamicPPLEnzymeExt
22

3-
using DynamicPPL: ADTypes, DynamicPPL
3+
using DynamicPPL: ADTypes, DynamicPPL, logdensity_at
44
using Enzyme: Enzyme
55

66
_enzyme_gradient_mode(::ADTypes.AutoEnzyme{Nothing}) = Enzyme.ReverseWithPrimal
@@ -30,21 +30,19 @@ function DynamicPPL._value_and_gradient(
3030
transform_strategy::DynamicPPL.AbstractTransformStrategy,
3131
accs::DynamicPPL.AccumulatorTuple,
3232
)
33-
f = DynamicPPL.LogDensityAt(
34-
model, getlogdensity, varname_ranges, transform_strategy, accs
35-
)
3633
dx = prep.dx
3734
fill!(dx, zero(eltype(dx)))
38-
# Const(f): LogDensityAt is not being differentiated; without Const, Enzyme errors
39-
# because it cannot prove the function argument is readonly.
40-
# We always use reverse mode to obtain the full gradient in one pass, but preserve
41-
# runtime-activity settings from `adtype.mode` when they were requested.
42-
# autodiff(ReverseWithPrimal, ...) returns ((), val); dx is mutated in-place.
35+
# Pass the plain function plus Const arguments; Enzyme is brittle with closure-like callables.
4336
_, val = Enzyme.autodiff(
4437
_enzyme_gradient_mode(adtype),
45-
Enzyme.Const(f),
38+
logdensity_at,
4639
Enzyme.Active,
4740
Enzyme.Duplicated(params, dx),
41+
Enzyme.Const(model),
42+
Enzyme.Const(getlogdensity),
43+
Enzyme.Const(varname_ranges),
44+
Enzyme.Const(transform_strategy),
45+
Enzyme.Const(accs),
4846
)
4947
return val, copy(dx)
5048
end

ext/DynamicPPLMooncakeExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ end
111111
@compile_workload begin
112112
for dist in (Normal(), InverseGamma(2, 3), Beta(2, 2))
113113
@model f() = x ~ dist
114-
ldf = LogDensityFunction(f(), getlogjoint_internal, LinkAll(); adtype=AutoMooncake())
114+
ldf = LogDensityFunction(
115+
f(), getlogjoint_internal, LinkAll(); adtype=AutoMooncake()
116+
)
115117
DynamicPPL.LogDensityProblems.logdensity_and_gradient(ldf, [0.5])
116118
end
117119
end

0 commit comments

Comments
 (0)