Commit b3c9ed1

yebaiclaude committed
Address review feedback: correctness, clarity, and robustness fixes
- ForwardDiff: use DiffResults (via ForwardDiff.DiffResults) for single-pass value+gradient, removing the double primal evaluation
- ForwardDiff: remove redundant chunk_size guard in _prepare_gradient (tweak_adtype already normalises it to a concrete positive integer)
- AutoMooncakeForward: handle empty params edge case (loop doesn't execute)
- Mooncake _cache_config: use Accessors.@set to preserve all Config fields when overriding friendly_tangents=false, instead of forwarding only two known fields
- Mooncake @compile_workload: remove redundant single-element for-loop
- EnzymeExt: document that adtype.mode is intentionally ignored (always reverse)
- src/logdensityfunction.jl: add fallback error for _value_and_gradient with unknown AD backends, pointing users to ForwardDiff (the default) or DI
- test/logdensityfunction.jl: revert formatter noise (accumulate_assume!!, accumulate_observe!!, ::Type{T}=... syntax)
- test/Project.toml: remove accidentally-added DynamicPPL dep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 0cb79f8 commit b3c9ed1

7 files changed

Lines changed: 51 additions & 38 deletions


docs/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ FillArrays = "0.13, 1"
 ForwardDiff = "0.10, 1"
 LogDensityProblems = "2"
 MCMCChains = "5, 6, 7"
-MarginalLogDensities = "0.4"
+MarginalLogDensities = "0.4.3"
 OrderedCollections = "1"
 StableRNGs = "1"
 StatsFuns = "1"

ext/DynamicPPLEnzymeExt.jl

Lines changed: 9 additions & 2 deletions
@@ -3,6 +3,11 @@ module DynamicPPLEnzymeExt
 using DynamicPPL: ADTypes, DynamicPPL
 using Enzyme: Enzyme
 
+_enzyme_gradient_mode(::ADTypes.AutoEnzyme{Nothing}) = Enzyme.ReverseWithPrimal
+function _enzyme_gradient_mode(adtype::ADTypes.AutoEnzyme)
+    return Enzyme.EnzymeCore.set_runtime_activity(Enzyme.ReverseWithPrimal, adtype.mode)
+end
+
 function DynamicPPL._prepare_gradient(
     ::ADTypes.AutoEnzyme,
     x::AbstractVector{<:Real},
@@ -16,7 +21,7 @@ function DynamicPPL._prepare_gradient(
 end
 
 function DynamicPPL._value_and_gradient(
-    ::ADTypes.AutoEnzyme,
+    adtype::ADTypes.AutoEnzyme,
     prep,
     params::AbstractVector{<:Real},
     model::DynamicPPL.Model,
@@ -32,9 +37,11 @@ function DynamicPPL._value_and_gradient(
     fill!(dx, zero(eltype(dx)))
     # Const(f): LogDensityAt is not being differentiated; without Const, Enzyme errors
     # because it cannot prove the function argument is readonly.
+    # We always use reverse mode to obtain the full gradient in one pass, but preserve
+    # runtime-activity settings from `adtype.mode` when they were requested.
    # autodiff(ReverseWithPrimal, ...) returns ((), val); dx is mutated in-place.
     _, val = Enzyme.autodiff(
-        Enzyme.ReverseWithPrimal,
+        _enzyme_gradient_mode(adtype),
         Enzyme.Const(f),
         Enzyme.Active,
         Enzyme.Duplicated(params, dx),
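
For readers unfamiliar with Enzyme's calling convention, here is a minimal, self-contained sketch of the call shape used above, outside of DynamicPPL (the function f and the values are illustrative, not DynamicPPL code):

using Enzyme

f(x) = sum(abs2, x)   # toy stand-in for the log-density closure
x = [1.0, 2.0, 3.0]
dx = zero(x)          # shadow buffer; must be zeroed, gradient accumulates here

# ReverseWithPrimal returns (derivatives, primal value); for a Duplicated
# argument the gradient lands in dx rather than in the returned tuple.
_, val = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal,
    Enzyme.Const(f),          # f itself carries no differentiable state
    Enzyme.Active,            # the scalar return is active
    Enzyme.Duplicated(x, dx),
)
# val == 14.0, and dx == [2.0, 4.0, 6.0] (= 2x)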

ext/DynamicPPLForwardDiffExt.jl

Lines changed: 9 additions & 11 deletions
@@ -2,6 +2,9 @@ module DynamicPPLForwardDiffExt
 
 using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
 using ForwardDiff
+# DiffResults is a direct dependency of ForwardDiff; access it through ForwardDiff's namespace
+# rather than listing it as a separate (weak)dep of DynamicPPL.
+const DiffResults = ForwardDiff.DiffResults
 
 # check if the AD type already has a tag
 use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
@@ -40,14 +43,11 @@
     f = DynamicPPL.LogDensityAt(
         model, getlogdensity, varname_ranges, transform_strategy, accs
     )
-    chunk = if chunk_size == 0 || chunk_size === nothing
-        ForwardDiff.Chunk(x)
-    else
-        ForwardDiff.Chunk(length(x), chunk_size)
-    end
+    # chunk_size is already set to a concrete positive integer by tweak_adtype
+    chunk = ForwardDiff.Chunk(length(x), chunk_size)
     cfg = ForwardDiff.GradientConfig(f, x, chunk, adtype.tag)
-    grad = similar(x)
-    return (; cfg, grad)
+    result = DiffResults.GradientResult(similar(x))
+    return (; cfg, result)
 end
 
 function DynamicPPL._value_and_gradient(
@@ -65,10 +65,8 @@
     )
     # Val{false}() skips tag checking, since our DynamicPPLTag is reused across calls
     # with different LogDensityAt instances.
-    ForwardDiff.gradient!(prep.grad, f, params, prep.cfg, Val{false}())
-    # gradient!(::AbstractArray, ...) doesn't return the value, so evaluate separately.
-    value = f(params)
-    return value, copy(prep.grad)
+    ForwardDiff.gradient!(prep.result, f, params, prep.cfg, Val{false}())
+    return DiffResults.value(prep.result), copy(DiffResults.gradient(prep.result))
 end
 
 end # module
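
The DiffResults pattern above generalises beyond this extension. As a minimal sketch (toy function, no DynamicPPL involved), a preallocated GradientResult lets ForwardDiff.gradient! record the primal value and the gradient in a single pass, which is what removes the double primal evaluation:

using ForwardDiff
const DiffResults = ForwardDiff.DiffResults

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

result = DiffResults.GradientResult(similar(x))  # holds both value and gradient
ForwardDiff.gradient!(result, f, x)              # one pass fills both slots

DiffResults.value(result)     # 14.0, with no second call to f
DiffResults.gradient(result)  # [2.0, 4.0, 6.0]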

ext/DynamicPPLMooncakeExt.jl

Lines changed: 9 additions & 12 deletions
@@ -24,19 +24,16 @@ Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{
 
 using DynamicPPL: @model, LinkAll, LogDensityAt, getlogjoint_internal, LogDensityFunction
 using ADTypes: AutoMooncake, AutoMooncakeForward
+using Accessors: Accessors
 using Distributions: Normal, InverseGamma, Beta
 using PrecompileTools: @setup_workload, @compile_workload
 
 function _cache_config(::Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}})
     return Mooncake.Config(; friendly_tangents=false)
 end
 function _cache_config(adtype::Union{AutoMooncake,AutoMooncakeForward})
-    config = adtype.config
-    return Mooncake.Config(;
-        debug_mode=config.debug_mode,
-        silence_debug_messages=config.silence_debug_messages,
-        friendly_tangents=false,
-    )
+    # Use Accessors to set friendly_tangents=false while preserving all other config fields.
+    return Accessors.@set adtype.config.friendly_tangents = false
 end
 
 # LogDensityAt is a constant w.r.t. differentiation; NoTangent avoids tangent allocation.
@@ -96,6 +93,8 @@ function DynamicPPL._value_and_gradient(
     )
     f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
     (; cache, dx, grad) = prep
+    # Handle empty parameter vector: value_and_derivative!! loop won't execute.
+    isempty(params) && return f(params), copy(grad)
     value = zero(eltype(grad))
     fill!(dx, zero(eltype(dx)))
     @inbounds for i in eachindex(grad, dx)
@@ -110,12 +109,10 @@
 
 @setup_workload begin
     @compile_workload begin
-        for adtype in (AutoMooncake(),)
-            for dist in (Normal(), InverseGamma(2, 3), Beta(2, 2))
-                @model f() = x ~ dist
-                ldf = LogDensityFunction(f(), getlogjoint_internal, LinkAll(); adtype)
-                DynamicPPL.LogDensityProblems.logdensity_and_gradient(ldf, [0.5])
-            end
+        for dist in (Normal(), InverseGamma(2, 3), Beta(2, 2))
+            @model f() = x ~ dist
+            ldf = LogDensityFunction(f(), getlogjoint_internal, LinkAll(); adtype=AutoMooncake())
+            DynamicPPL.LogDensityProblems.logdensity_and_gradient(ldf, [0.5])
         end
     end
 end
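
The Accessors.@set change is worth a standalone illustration: @set returns a modified copy of an immutable value, so every field not mentioned is carried over automatically, and a new Config field added upstream can never be silently dropped. A minimal sketch with a hypothetical stand-in struct (not Mooncake's actual Config definition):

using Accessors

# Hypothetical config type with several fields.
struct ToyConfig
    debug_mode::Bool
    silence_debug_messages::Bool
    friendly_tangents::Bool
end

cfg = ToyConfig(true, true, true)
# Copy cfg with one field overridden; debug_mode and silence_debug_messages
# are preserved without being named.
cfg2 = Accessors.@set cfg.friendly_tangents = false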

src/logdensityfunction.jl

Lines changed: 12 additions & 0 deletions
@@ -407,6 +407,18 @@
 function _prepare_gradient end
 function _value_and_gradient end
 
+function _value_and_gradient(adtype::ADTypes.AbstractADType, args...)
+    throw(
+        ArgumentError(
+            "No gradient implementation found for AD backend $adtype. " *
+            "If you intended to use the default (ForwardDiff), ensure that ForwardDiff is " *
+            "loaded (e.g. `using ForwardDiff`). For other backends, load the corresponding " *
+            "package (e.g. `using Mooncake`, `using Enzyme`) or load " *
+            "DifferentiationInterface as a fallback.",
+        ),
+    )
+end
+
 function LogDensityProblems.logdensity(
     ldf::LogDensityFunction, params::AbstractVector{<:Real}
 )
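
The fallback works because Julia's dispatch prefers the most specific applicable method: each AD extension defines _value_and_gradient for its concrete ADTypes backend, so the AbstractADType method above is only reached when no extension has claimed the backend. A minimal sketch of the pattern with hypothetical types (not DynamicPPL's actual definitions):

# Hypothetical miniature of the extension-dispatch pattern.
abstract type AbstractADType end
struct AutoToyDiff <: AbstractADType end

# Catch-all: reached only when no more specific method exists.
value_and_gradient(adtype::AbstractADType, f, x) =
    throw(ArgumentError("No gradient implementation found for AD backend $adtype."))

# "Extension" method: loading the backend package would define this, and
# dispatch then routes AutoToyDiff calls here instead of to the fallback.
value_and_gradient(::AutoToyDiff, f, x) = (f(x), 2 .* x)  # toy gradient of sum(abs2, x)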

test/Project.toml

Lines changed: 1 addition & 2 deletions
@@ -13,7 +13,6 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
@@ -53,7 +52,7 @@ InvertedIndices = "1"
 LogDensityProblems = "2"
 MCMCChains = "7.2.1"
 MacroTools = "0.5.6"
-MarginalLogDensities = "0.4"
+MarginalLogDensities = "0.4.3"
 OffsetArrays = "1"
 OrderedCollections = "1"
 ReverseDiff = "1"

test/logdensityfunction.jl

Lines changed: 10 additions & 10 deletions
@@ -178,12 +178,12 @@
 struct ErrorAccumulatorException <: Exception end
 struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end
 DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR
-DynamicPPL.accumulate_assume!!(::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any) = throw(
-    ErrorAccumulatorException()
-)
-DynamicPPL.accumulate_observe!!(::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any) = throw(
-    ErrorAccumulatorException()
-)
+DynamicPPL.accumulate_assume!!(
+    ::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any
+) = throw(ErrorAccumulatorException())
+DynamicPPL.accumulate_observe!!(
+    ::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any
+) = throw(ErrorAccumulatorException())
 DynamicPPL.reset(ea::ErrorAccumulator) = ea
 Base.copy(ea::ErrorAccumulator) = ea
 # Construct an LDF
@@ -497,7 +497,7 @@
     return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
 end
 
-@model function scalar_matrix_model((::Type{T})=Float64) where {T<:Real}
+@model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
     m = Matrix{T}(undef, 2, 3)
     return m ~ filldist(MvNormal(zeros(2), I), 3)
 end
@@ -506,14 +506,14 @@
     scalar_matrix_model, test_m, ref_adtype
 )
 
-@model function matrix_model((::Type{T})=Matrix{Float64}) where {T}
+@model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
     m = T(undef, 2, 3)
     return m ~ filldist(MvNormal(zeros(2), I), 3)
 end
 
 matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
 
-@model function scalar_array_model((::Type{T})=Float64) where {T<:Real}
+@model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
     m = Array{T}(undef, 2, 3)
     return m ~ filldist(MvNormal(zeros(2), I), 3)
 end
@@ -522,7 +522,7 @@
     scalar_array_model, test_m, ref_adtype
 )
 
-@model function array_model((::Type{T})=Array{Float64}) where {T}
+@model function array_model(::Type{T}=Array{Float64}) where {T}
     m = T(undef, 2, 3)
     return m ~ filldist(MvNormal(zeros(2), I), 3)
 end
