|
1 | 1 | module DynamicPPLEnzymeExt |
2 | 2 |
|
3 | | -using DynamicPPL: ADTypes, DynamicPPL |
| 3 | +using DynamicPPL: ADTypes, DynamicPPL, logdensity_at |
4 | 4 | using Enzyme: Enzyme |
5 | 5 |
|
6 | 6 | _enzyme_gradient_mode(::ADTypes.AutoEnzyme{Nothing}) = Enzyme.ReverseWithPrimal |
@@ -30,21 +30,19 @@ function DynamicPPL._value_and_gradient( |
30 | 30 | transform_strategy::DynamicPPL.AbstractTransformStrategy, |
31 | 31 | accs::DynamicPPL.AccumulatorTuple, |
32 | 32 | ) |
33 | | - f = DynamicPPL.LogDensityAt( |
34 | | - model, getlogdensity, varname_ranges, transform_strategy, accs |
35 | | - ) |
36 | 33 | dx = prep.dx |
37 | 34 | fill!(dx, zero(eltype(dx))) |
38 | | - # Const(f): LogDensityAt is not being differentiated; without Const, Enzyme errors |
39 | | - # because it cannot prove the function argument is readonly. |
40 | | - # We always use reverse mode to obtain the full gradient in one pass, but preserve |
41 | | - # runtime-activity settings from `adtype.mode` when they were requested. |
42 | | - # autodiff(ReverseWithPrimal, ...) returns ((), val); dx is mutated in-place. |
| 35 | + # Pass the plain function plus Const arguments; Enzyme is brittle with closure-like callables. |
43 | 36 | _, val = Enzyme.autodiff( |
44 | 37 | _enzyme_gradient_mode(adtype), |
45 | | - Enzyme.Const(f), |
| 38 | + logdensity_at, |
46 | 39 | Enzyme.Active, |
47 | 40 | Enzyme.Duplicated(params, dx), |
| 41 | + Enzyme.Const(model), |
| 42 | + Enzyme.Const(getlogdensity), |
| 43 | + Enzyme.Const(varname_ranges), |
| 44 | + Enzyme.Const(transform_strategy), |
| 45 | + Enzyme.Const(accs), |
48 | 46 | ) |
49 | 47 | return val, copy(dx) |
50 | 48 | end |
|
0 commit comments