Use native gradient API for ForwardDiff, Enzyme, Mooncake#1354
Conversation
- Refactor `_prepare_gradient` and `_value_and_gradient` into overridable dispatch methods so backends can bypass DI entirely
- Implement `AutoMooncakeForward` in the Mooncake extension using Mooncake's native derivative cache and a column-by-column sweep
- Force `friendly_tangents=false` in `_cache_config` for both `AutoMooncake` and `AutoMooncakeForward` to keep caches valid across calls
- Declare `tangent_type(LogDensityAt) = NoTangent` so Mooncake treats the function object as a constant
- Relax the `ADP` type parameter on `LogDensityFunction` from `Union{Nothing,DI.GradientPrep}` to unconstrained, to accommodate custom prep objects (e.g. the `NamedTuple` used by `AutoMooncakeForward`)
- Add `AutoMooncakeForward` to the precompile workload and test suite, including a test that a `friendly_tangents=true` config is handled correctly
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
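The column-by-column sweep mentioned above computes a gradient from one forward-mode pass per coordinate, seeding a one-hot direction each time. A minimal standalone illustration of the idea (this is not Mooncake's actual cache API; `ForwardDiff.derivative` stands in for the per-direction derivative, and `gradient_by_sweep` is a hypothetical name):

```julia
using ForwardDiff

# Gradient of f at x via one directional derivative per coordinate:
# grad[i] = d/dt f(x + t * e_i) evaluated at t = 0.
function gradient_by_sweep(f, x::AbstractVector)
    grad = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = one(eltype(x))  # one-hot seed for coordinate i
        grad[i] = ForwardDiff.derivative(t -> f(x .+ t .* e), zero(eltype(x)))
    end
    return grad
end

gradient_by_sweep(x -> sum(abs2, x), [1.0, 2.0, 3.0])  # ≈ [2.0, 4.0, 6.0]
```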
Benchmark Report
Computer Information | Benchmark Results
DynamicPPL.jl documentation for PR #1354 is available at:
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@yebai can you explain the rationale behind getting rid of DifferentiationInterface?
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Codecov Report
❌ Patch coverage report. Additional details and impacted files:

```
@@ Coverage Diff @@
##             main    #1354      +/-   ##
==========================================
- Coverage   78.62%   78.30%   -0.33%
==========================================
  Files          50       52       +2
  Lines        3631     3697      +66
==========================================
+ Hits         2855     2895      +40
- Misses        776      802      +26
==========================================
```
0fe3386 to
e70bf72
Compare
Remaining comments which cannot be posted as a review comment to avoid GitHub Rate Limit
[JuliaFormatter v1.0.62] reported by reviewdog 🐶
DynamicPPL.jl/docs/src/ldf/models.md
Line 128 in e44a2d4
…ional

- Move DifferentiationInterface to [weakdeps]; add DynamicPPLDifferentiationInterfaceExt as fallback for backends without native implementations
- Add native ForwardDiff gradient via GradientConfig (DynamicPPLForwardDiffExt)
- Add native Enzyme gradient via autodiff(ReverseWithPrimal, ...) (new DynamicPPLEnzymeExt)
- Keep native Mooncake reverse/forward gradient (DynamicPPLMooncakeExt)
- Add Enzyme to test env; drop DI from test env

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
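The native Enzyme path is built on `autodiff(ReverseWithPrimal, ...)`, which returns the primal value while accumulating the gradient into a shadow buffer. A hedged standalone sketch of that call pattern (a simplified toy example, not the extension's actual code):

```julia
using Enzyme

f(x) = sum(abs2, x)

x = [1.0, 2.0, 3.0]
dx = zero(x)  # shadow buffer; must be zeroed first, since Enzyme accumulates into it

# ReverseWithPrimal returns (derivatives, primal); the gradient lands in dx.
_, val = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(x, dx)
)

val  # 14.0
dx   # [2.0, 4.0, 6.0]
```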
@yebai can you please open issues on the DI repository to describe the "robustness issues" you're referring to above?
```julia
f = DynamicPPL.LogDensityAt(
    model, getlogdensity, varname_ranges, transform_strategy, accs
)
dx = prep.dx
```
@penelopeysm @yebai you should use the version without a closure here [otherwise it's breaking]
- ForwardDiff: use DiffResults (via ForwardDiff.DiffResults) for single-pass value+gradient, removing the double primal evaluation
- ForwardDiff: remove redundant chunk_size guard in _prepare_gradient (tweak_adtype already normalises it to a concrete positive integer)
- AutoMooncakeForward: handle empty params edge case (loop doesn't execute)
- Mooncake _cache_config: use Accessors.@set to preserve all Config fields when overriding friendly_tangents=false, instead of forwarding only two known fields
- Mooncake @compile_workload: remove redundant single-element for-loop
- EnzymeExt: document that adtype.mode is intentionally ignored (always reverse)
- src/logdensityfunction.jl: add fallback error for _value_and_gradient with unknown AD backends, pointing users to ForwardDiff (the default) or DI
- test/logdensityfunction.jl: revert formatter noise (accumulate_assume!!, accumulate_observe!!, ::Type{T}=... syntax)
- test/Project.toml: remove accidentally-added DynamicPPL dep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
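The single-pass value+gradient avoids evaluating the primal twice: `ForwardDiff.gradient!` writes both the value and the gradient into one DiffResults object in a single sweep. A minimal sketch of the pattern (accessing DiffResults through ForwardDiff's namespace as the commit describes; `value_and_gradient_singlepass` is a hypothetical name):

```julia
using ForwardDiff
const DiffResults = ForwardDiff.DiffResults  # accessed via ForwardDiff's namespace

function value_and_gradient_singlepass(f, x::AbstractVector)
    result = DiffResults.GradientResult(x)        # holds value and gradient together
    result = ForwardDiff.gradient!(result, f, x)  # one pass fills both
    return DiffResults.value(result), DiffResults.gradient(result)
end

value_and_gradient_singlepass(x -> sum(abs2, x), [1.0, 2.0])  # (5.0, [2.0, 4.0])
```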
```julia
    transform_strategy::DynamicPPL.AbstractTransformStrategy,
    accs::DynamicPPL.AccumulatorTuple,
)
# Pass the plain function plus Const arguments; Enzyme is brittle with closure-like callables.
```
This is not true: Enzyme (like Julia generally) is higher performance without closures, but it is not brittle with them.
```julia
        Enzyme.Const(accs),
    ),
)
return val, copy(dx)
```
@wsmoses, feel free to suggest concrete code changes for Enzyme. Coding agents don't always give optimal solutions.
Please don't do this. We want more interoperability, not less. This will hurt the ecosystem |
This seems like a somewhat strange decision. What were the robustness issues? |
```diff
@@ -0,0 +1,65 @@
+module DynamicPPLDifferentiationInterfaceExt
```
Is there a reason for making this a package extension instead of keeping it as a dependency? DI only depends on ADTypes and LinearAlgebra, it doesn't get much more lightweight
```julia
fill!(dx, zero(eltype(dx)))
_, val = Enzyme.autodiff(
    _enzyme_gradient_mode(adtype),
    logdensity_at,
```
Is the function annotation in the AutoEnzyme backend taken into account?
```julia
# DiffResults is a direct dependency of ForwardDiff; access it through ForwardDiff's namespace
# rather than listing it as a separate (weak)dep of DynamicPPL.
```
This is bad practice because you cannot version-bound DiffResults separately, and nothing guarantees that ForwardDiff will keep it as a dep
```julia
@inbounds for i in eachindex(grad, dx)
    dx[i] = one(eltype(dx))
    result = value_and_derivative!!(cache, Dual(f, NoTangent()), Dual(params, dx))
    value = primal(result)
    grad[i] = tangent(result)
    dx[i] = zero(eltype(dx))
end
```
I thought there was a chunked forward mode for this kind of stuff?
Related:
```julia
function _value_and_gradient(adtype::ADTypes.AbstractADType, args...)
    throw(
        ArgumentError(
            "No gradient implementation found for AD backend $adtype. " *
            "If you intended to use the default (ForwardDiff), ensure that ForwardDiff is " *
            "loaded (e.g. `using ForwardDiff`). For other backends, load the corresponding " *
            "package (e.g. `using Mooncake`, `using Enzyme`) or load " *
            "DifferentiationInterface as a fallback.",
        ),
    )
end
```
And it makes ReverseDiff no longer supported by Turing, if I understand properly?
Technically it is still supported but you need to import DI separately to trigger the extension
I think this PR also removes the relevant tests anyway?
It's still tested, but even though DI is removed as a test dep, it still gets pulled into the test env via Bijectors and MarginalLogDensities, which is why the tests pass. So I guess this is dangerous since it implicitly relies on them still having DI as a dep.
```
pkg> why DifferentiationInterface
  Bijectors → DifferentiationInterface
  MarginalLogDensities → DifferentiationInterface
  MarginalLogDensities → Optimization → OptimizationBase → DifferentiationInterface
  MarginalLogDensities → OptimizationOptimJL → Optim → LineSearches → NLSolversBase → DifferentiationInterface
  MarginalLogDensities → OptimizationOptimJL → Optim → NLSolversBase → DifferentiationInterface
  MarginalLogDensities → OptimizationOptimJL → OptimizationBase → DifferentiationInterface
```
Especially since @yebai is also removing DI from Bijectors as part of the great spring cleaning 😅
Calls the ForwardDiff, Enzyme, and Mooncake APIs directly, as DI has robustness issues with both Enzyme and Mooncake. This PR improves the performance of all benchmarking cases.

To support backend-specific prep and evaluation, `_prepare_gradient` and `_value_and_gradient` are extracted into overridable dispatch methods. The `ADP` type parameter on `LogDensityFunction` is unconstrained accordingly, since `AutoMooncakeForward`'s prep is a `NamedTuple` rather than a `GradientPrep`. A `tangent_type(LogDensityAt) = NoTangent` declaration tells Mooncake to treat the function object as a constant.
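The prepare/evaluate split described above can be illustrated with plain ForwardDiff: build a `GradientConfig` once, then reuse it on every call so the dual-number buffers are not reallocated. A sketch under hypothetical names (`prepare_grad`, `value_and_grad` are illustrative, not DynamicPPL's actual methods):

```julia
using ForwardDiff

# One-time prep: allocates the dual-number work buffers for f at inputs shaped like x.
prepare_grad(f, x) = ForwardDiff.GradientConfig(f, x)

# Hot path: reuse the prepared config on every evaluation.
value_and_grad(f, x, cfg) = (f(x), ForwardDiff.gradient(f, x, cfg))

f(x) = sum(abs2, x)
x = [1.0, 2.0]
cfg = prepare_grad(f, x)
value_and_grad(f, x, cfg)  # (5.0, [2.0, 4.0])
```

Note that this sketch still evaluates the primal separately; the PR's ForwardDiff extension instead uses DiffResults to get value and gradient in a single pass.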