
Commit e70bf72

yebaiclaude committed
Use native AD APIs for ForwardDiff, Enzyme, and Mooncake; make DI optional
- Move DifferentiationInterface to [weakdeps]; add DynamicPPLDifferentiationInterfaceExt as a fallback for backends without native implementations
- Add native ForwardDiff gradient via GradientConfig (DynamicPPLForwardDiffExt)
- Add native Enzyme gradient via autodiff(ReverseWithPrimal, ...) (new DynamicPPLEnzymeExt)
- Keep native Mooncake reverse/forward gradient (DynamicPPLMooncakeExt)
- Add Enzyme to the test env; drop DI from the test env

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 31cc13b commit e70bf72

9 files changed

Lines changed: 189 additions & 70 deletions

Project.toml

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
-DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -29,6 +28,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [weakdeps]
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
@@ -38,11 +39,13 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

 [extensions]
+DynamicPPLDifferentiationInterfaceExt = ["DifferentiationInterface"]
 DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
+DynamicPPLEnzymeExt = ["Enzyme"]
 DynamicPPLForwardDiffExt = ["ForwardDiff"]
 DynamicPPLMCMCChainsExt = ["MCMCChains"]
 DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
-DynamicPPLMooncakeExt = ["Mooncake", "DifferentiationInterface"]
+DynamicPPLMooncakeExt = ["Mooncake"]
 DynamicPPLReverseDiffExt = ["ReverseDiff"]

 [compat]
@@ -55,9 +58,9 @@ Bijectors = "0.15.17"
 Chairmarks = "1.3.1"
 Compat = "4"
 ConstructionBase = "1.5.4"
-DifferentiationInterface = "0.6.41, 0.7"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
+Enzyme = "0.13"
 EnzymeCore = "0.6 - 0.8"
 FillArrays = "1.16.0"
 ForwardDiff = "0.10.12, 1"
ext/DynamicPPLDifferentiationInterfaceExt.jl

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+module DynamicPPLDifferentiationInterfaceExt
+
+import DifferentiationInterface as DI
+using DynamicPPL:
+    DynamicPPL,
+    AccumulatorTuple,
+    LogDensityAt,
+    Model,
+    VarNamedTuple,
+    AbstractTransformStrategy,
+    _use_closure,
+    logdensity_at
+using ADTypes: ADTypes
+
+function DynamicPPL._prepare_gradient(
+    adtype::ADTypes.AbstractADType,
+    x::AbstractVector{<:Real},
+    model::Model,
+    getlogdensity::Any,
+    varname_ranges::VarNamedTuple,
+    transform_strategy::AbstractTransformStrategy,
+    accs::AccumulatorTuple,
+)
+    args = (model, getlogdensity, varname_ranges, transform_strategy, accs)
+    return if _use_closure(adtype)
+        DI.prepare_gradient(LogDensityAt(args...), adtype, x)
+    else
+        DI.prepare_gradient(logdensity_at, adtype, x, map(DI.Constant, args)...)
+    end
+end
+
+function DynamicPPL._value_and_gradient(
+    adtype::ADTypes.AbstractADType,
+    prep,
+    params::AbstractVector{<:Real},
+    model::Model,
+    getlogdensity::Any,
+    varname_ranges::VarNamedTuple,
+    transform_strategy::AbstractTransformStrategy,
+    accs::AccumulatorTuple,
+)
+    return if _use_closure(adtype)
+        DI.value_and_gradient(
+            LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs),
+            prep,
+            adtype,
+            params,
+        )
+    else
+        DI.value_and_gradient(
+            logdensity_at,
+            prep,
+            adtype,
+            params,
+            DI.Constant(model),
+            DI.Constant(getlogdensity),
+            DI.Constant(varname_ranges),
+            DI.Constant(transform_strategy),
+            DI.Constant(accs),
+        )
+    end
+end
+
+end # module
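For readers unfamiliar with DifferentiationInterface's Constant contexts (the non-closure branch above), here is a minimal self-contained sketch on a toy function; f, c, and x are illustrative names, not DynamicPPL code:

import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff

f(x, c) = c * sum(abs2, x)  # c is a non-differentiated "context" argument
x = randn(3)
prep = DI.prepare_gradient(f, AutoForwardDiff(), x, DI.Constant(2.0))
val, grad = DI.value_and_gradient(f, prep, AutoForwardDiff(), x, DI.Constant(2.0))
# grad ≈ 4 .* x, i.e. c * 2x with c = 2.0

Passing constants as DI.Constant instead of capturing them in a closure lets backends that want explicit activity annotations (the _use_closure(adtype) switch above) avoid differentiating through them.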

ext/DynamicPPLEnzymeExt.jl

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+module DynamicPPLEnzymeExt
+
+using DynamicPPL: ADTypes, DynamicPPL
+using Enzyme: Enzyme
+
+function DynamicPPL._prepare_gradient(
+    ::ADTypes.AutoEnzyme,
+    x::AbstractVector{<:Real},
+    model::DynamicPPL.Model,
+    getlogdensity::Any,
+    varname_ranges::DynamicPPL.VarNamedTuple,
+    transform_strategy::DynamicPPL.AbstractTransformStrategy,
+    accs::DynamicPPL.AccumulatorTuple,
+)
+    return (; dx=similar(x))
+end
+
+function DynamicPPL._value_and_gradient(
+    ::ADTypes.AutoEnzyme,
+    prep,
+    params::AbstractVector{<:Real},
+    model::DynamicPPL.Model,
+    getlogdensity::Any,
+    varname_ranges::DynamicPPL.VarNamedTuple,
+    transform_strategy::DynamicPPL.AbstractTransformStrategy,
+    accs::DynamicPPL.AccumulatorTuple,
+)
+    f = DynamicPPL.LogDensityAt(
+        model, getlogdensity, varname_ranges, transform_strategy, accs
+    )
+    dx = prep.dx
+    fill!(dx, zero(eltype(dx)))
+    # Const(f): LogDensityAt is not being differentiated; without Const, Enzyme errors
+    # because it cannot prove the function argument is readonly.
+    # autodiff(ReverseWithPrimal, ...) returns ((), val); dx is mutated in-place.
+    _, val = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal,
+        Enzyme.Const(f),
+        Enzyme.Active,
+        Enzyme.Duplicated(params, dx),
+    )
+    return val, copy(dx)
+end
+
+end # module
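A standalone sketch of the autodiff(ReverseWithPrimal, ...) pattern used above, on a toy function (illustrative names, not DynamicPPL code):

using Enzyme

g(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
dx = zero(x)  # shadow buffer; Enzyme accumulates the gradient here in place
_, val = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal,
    Enzyme.Const(g),           # g itself holds no differentiable state
    Enzyme.Active,             # the scalar return is the active output
    Enzyme.Duplicated(x, dx),
)
# val == 14.0 and dx == 2 .* x

Because Enzyme accumulates (+=) into the shadow, the extension zeroes dx before each call and returns copy(dx), so the reused buffer never aliases a caller-visible gradient.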

ext/DynamicPPLForwardDiffExt.jl

Lines changed: 43 additions & 4 deletions
@@ -12,10 +12,6 @@ function DynamicPPL.tweak_adtype(
 ) where {chunk_size}
     # Use DynamicPPL tag to improve stack traces
     # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
-    # NOTE: DifferentiationInterface disables tag checking if the
-    # tag inside the AutoForwardDiff type is not nothing. See
-    # https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350.
-    # So we don't currently need to override ForwardDiff.checktag as well.
     tag = if use_dynamicppl_tag(ad)
         ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
     else
@@ -32,4 +28,47 @@
     return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag)
 end

+function DynamicPPL._prepare_gradient(
+    adtype::ADTypes.AutoForwardDiff{chunk_size},
+    x::AbstractVector{<:Real},
+    model::DynamicPPL.Model,
+    getlogdensity::Any,
+    varname_ranges::DynamicPPL.VarNamedTuple,
+    transform_strategy::DynamicPPL.AbstractTransformStrategy,
+    accs::DynamicPPL.AccumulatorTuple,
+) where {chunk_size}
+    f = DynamicPPL.LogDensityAt(
+        model, getlogdensity, varname_ranges, transform_strategy, accs
+    )
+    chunk = if chunk_size == 0 || chunk_size === nothing
+        ForwardDiff.Chunk(x)
+    else
+        ForwardDiff.Chunk(length(x), chunk_size)
+    end
+    cfg = ForwardDiff.GradientConfig(f, x, chunk, adtype.tag)
+    grad = similar(x)
+    return (; cfg, grad)
+end
+
+function DynamicPPL._value_and_gradient(
+    ::ADTypes.AutoForwardDiff,
+    prep,
+    params::AbstractVector{<:Real},
+    model::DynamicPPL.Model,
+    getlogdensity::Any,
+    varname_ranges::DynamicPPL.VarNamedTuple,
+    transform_strategy::DynamicPPL.AbstractTransformStrategy,
+    accs::DynamicPPL.AccumulatorTuple,
+)
+    f = DynamicPPL.LogDensityAt(
+        model, getlogdensity, varname_ranges, transform_strategy, accs
+    )
+    # Val{false}() skips tag checking, since our DynamicPPLTag is reused across calls
+    # with different LogDensityAt instances.
+    ForwardDiff.gradient!(prep.grad, f, params, prep.cfg, Val{false}())
+    # gradient!(::AbstractArray, ...) doesn't return the value, so evaluate separately.
+    value = f(params)
+    return value, copy(prep.grad)
+end
+
 end # module
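A minimal standalone illustration of the GradientConfig-with-custom-tag pattern above; MyTag and h are hypothetical stand-ins for DynamicPPLTag and LogDensityAt:

using ForwardDiff

struct MyTag end  # stand-in for DynamicPPL.DynamicPPLTag
h(x) = sum(abs2, x)
x = randn(4)
tag = ForwardDiff.Tag(MyTag(), eltype(x))
cfg = ForwardDiff.GradientConfig(h, x, ForwardDiff.Chunk(x), tag)
grad = similar(x)
ForwardDiff.gradient!(grad, h, x, cfg, Val{false}())  # Val{false}() skips the tag check
# grad ≈ 2 .* x

Building the GradientConfig once at preparation time and reusing it across calls avoids reallocating the dual-number workspace on every gradient evaluation.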

ext/DynamicPPLMooncakeExt.jl

Lines changed: 16 additions & 2 deletions
@@ -3,9 +3,13 @@ module DynamicPPLMooncakeExt
 using DynamicPPL: DynamicPPL, is_transformed
 using Mooncake:
     Mooncake,
+    Dual,
     NoTangent,
     prepare_derivative_cache,
     prepare_gradient_cache,
+    primal,
+    tangent,
+    value_and_derivative!!,
     value_and_gradient!!

 # These are purely optimisations (although quite significant ones sometimes, especially for
@@ -61,7 +65,8 @@ function DynamicPPL._prepare_gradient(
     accs::DynamicPPL.AccumulatorTuple,
 )
     f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
-    return prepare_derivative_cache(f, x; config=_cache_config(adtype))
+    cache = prepare_derivative_cache(f, x; config=_cache_config(adtype))
+    return (; cache, dx=similar(x), grad=similar(x))
 end

 function DynamicPPL._value_and_gradient(
@@ -90,7 +95,16 @@ function DynamicPPL._value_and_gradient(
     accs::DynamicPPL.AccumulatorTuple,
 )
     f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
-    value, grad = value_and_gradient!!(prep, f, params)
+    (; cache, dx, grad) = prep
+    value = zero(eltype(grad))
+    fill!(dx, zero(eltype(dx)))
+    @inbounds for i in eachindex(grad, dx)
+        dx[i] = one(eltype(dx))
+        result = value_and_derivative!!(cache, Dual(f, NoTangent()), Dual(params, dx))
+        value = primal(result)
+        grad[i] = tangent(result)
+        dx[i] = zero(eltype(dx))
+    end
     return value, copy(grad)
 end
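The loop above assembles the gradient from forward-mode directional derivatives: seeding dx with the i-th basis vector and pushing it through value_and_derivative!! yields the i-th partial, one forward pass per coordinate. A generic finite-difference sketch of the same idea (plain Julia, not Mooncake's API; names are illustrative):

# The i-th gradient entry is the directional derivative of f along basis vector e_i.
function gradient_by_coordinates(f, x; h=1e-8)
    dx = zero(x)
    grad = similar(x)
    for i in eachindex(x)
        dx[i] = one(eltype(x))                   # seed the i-th direction
        grad[i] = (f(x .+ h .* dx) - f(x)) / h   # stand-in for a true JVP
        dx[i] = zero(eltype(x))                  # reset before the next coordinate
    end
    return grad
end

This costs one pass per input dimension, so it mainly pays off for low-dimensional models or when a forward pass is much cheaper than a reverse one.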

src/logdensityfunction.jl

Lines changed: 3 additions & 49 deletions
@@ -23,7 +23,6 @@ using ADTypes: ADTypes
 using BangBang: BangBang
 using AbstractPPL: AbstractPPL, VarName
 using LogDensityProblems: LogDensityProblems
-import DifferentiationInterface as DI
 using Random: Random

 """
@@ -404,54 +403,9 @@
     )
 end

-function _prepare_gradient(
-    adtype::ADTypes.AbstractADType,
-    x::AbstractVector{<:Real},
-    model::Model,
-    getlogdensity::Any,
-    varname_ranges::VarNamedTuple,
-    transform_strategy::AbstractTransformStrategy,
-    accs::AccumulatorTuple,
-)
-    args = (model, getlogdensity, varname_ranges, transform_strategy, accs)
-    return if _use_closure(adtype)
-        DI.prepare_gradient(LogDensityAt(args...), adtype, x)
-    else
-        DI.prepare_gradient(logdensity_at, adtype, x, map(DI.Constant, args)...)
-    end
-end
-
-function _value_and_gradient(
-    adtype::ADTypes.AbstractADType,
-    prep,
-    params::AbstractVector{<:Real},
-    model::Model,
-    getlogdensity::Any,
-    varname_ranges::VarNamedTuple,
-    transform_strategy::AbstractTransformStrategy,
-    accs::AccumulatorTuple,
-)
-    return if _use_closure(adtype)
-        DI.value_and_gradient(
-            LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs),
-            prep,
-            adtype,
-            params,
-        )
-    else
-        DI.value_and_gradient(
-            logdensity_at,
-            prep,
-            adtype,
-            params,
-            DI.Constant(model),
-            DI.Constant(getlogdensity),
-            DI.Constant(varname_ranges),
-            DI.Constant(transform_strategy),
-            DI.Constant(accs),
-        )
-    end
-end
+# Extensible hooks: backends provide methods via package extensions.
+function _prepare_gradient end
+function _value_and_gradient end

 function LogDensityProblems.logdensity(
     ldf::LogDensityFunction, params::AbstractVector{<:Real}
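The two empty generic functions are the whole interface: the core package owns the names, and each AD extension adds methods that dispatch selects by adtype. A toy sketch of the pattern, with hypothetical module and backend names:

module CorePkg
# Declared with zero methods; extensions own all implementations.
function _prepare_gradient end
function _value_and_gradient end
end

module SomeBackendExt  # stands in for a Pkg extension module
using ..CorePkg
struct AutoToy end     # stands in for an ADTypes backend type
CorePkg._prepare_gradient(::AutoToy, x) = (; buf=similar(x))
CorePkg._value_and_gradient(::AutoToy, prep, x) = (sum(abs2, x), 2 .* x)
end

Calling one of these hooks without any backend extension loaded raises a MethodError, which is the intended failure mode when neither a native extension nor the DI fallback has been imported.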

src/test_utils/ad.jl

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@ module AD

 using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
-import DifferentiationInterface as DI
 using DocStringExtensions
 using DynamicPPL:
     DynamicPPL,

test/Project.toml

Lines changed: 4 additions & 4 deletions
@@ -9,11 +9,12 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
-DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -43,19 +44,18 @@ BangBang = "0.4"
 Bijectors = "0.15.17"
 Chairmarks = "1"
 Combinatorics = "1"
-DifferentiationInterface = "0.6.41, 0.7"
 DimensionalData = "0.30"
 Distributions = "0.25"
 Documenter = "1"
+Enzyme = "0.13"
 ForwardDiff = "0.10.12, 1"
 InvertedIndices = "1"
 LogDensityProblems = "2"
 MCMCChains = "7.2.1"
 MacroTools = "0.5.6"
 MarginalLogDensities = "0.4"
-Mooncake = "0.4, 0.5"
-OrderedCollections = "1"
 OffsetArrays = "1"
+OrderedCollections = "1"
 ReverseDiff = "1"
 SpecialFunctions = "2.6.1"
 StableRNGs = "1"

test/logdensityfunction.jl

Lines changed: 8 additions & 7 deletions
@@ -13,9 +13,10 @@ using LogDensityProblems: LogDensityProblems
 using Random: Xoshiro
 using StableRNGs: StableRNG

+using Enzyme: Enzyme
 using ForwardDiff: ForwardDiff
-using ReverseDiff: ReverseDiff
 using Mooncake: Mooncake
+using ReverseDiff: ReverseDiff

 @testset "LogDensityFunction: constructors" begin
     dist = Beta(2, 2)
@@ -177,12 +178,12 @@
 struct ErrorAccumulatorException <: Exception end
 struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end
 DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR
-DynamicPPL.accumulate_assume!!(
-    ::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any
-) = throw(ErrorAccumulatorException())
-DynamicPPL.accumulate_observe!!(
-    ::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any
-) = throw(ErrorAccumulatorException())
+DynamicPPL.accumulate_assume!!(::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any) = throw(
+    ErrorAccumulatorException()
+)
+DynamicPPL.accumulate_observe!!(::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any) = throw(
+    ErrorAccumulatorException()
+)
 DynamicPPL.reset(ea::ErrorAccumulator) = ea
 Base.copy(ea::ErrorAccumulator) = ea
 # Construct an LDF
