Commit 9262d08

yebai and claude committed

Simplify _cache_config, drop friendly_tangents test

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

1 parent f9d58e8 · commit 9262d08

3 files changed

Lines changed: 14 additions & 80 deletions

ext/DynamicPPLMooncakeExt.jl

Lines changed: 7 additions & 44 deletions
@@ -3,13 +3,9 @@ module DynamicPPLMooncakeExt
 using DynamicPPL: DynamicPPL, is_transformed
 using Mooncake:
     Mooncake,
-    Dual,
     NoTangent,
-    primal,
     prepare_derivative_cache,
     prepare_gradient_cache,
-    tangent,
-    value_and_derivative!!,
     value_and_gradient!!
 
 # These are purely optimisations (although quite significant ones sometimes, especially for
@@ -27,23 +23,19 @@ using ADTypes: AutoMooncake, AutoMooncakeForward
 using Distributions: Normal, InverseGamma, Beta
 using PrecompileTools: @setup_workload, @compile_workload
 
-_config(::Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}}) = Mooncake.Config()
-_config(adtype::Union{AutoMooncake,AutoMooncakeForward}) = adtype.config
+function _cache_config(::Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}})
+    return Mooncake.Config(; friendly_tangents=false)
+end
 function _cache_config(adtype::Union{AutoMooncake,AutoMooncakeForward})
-    config = _config(adtype)
-    # `friendly_tangents=true` rewrites tangent types into named structs at tape-build time,
-    # which is incompatible with a reusable cache (the cached tape would be tied to the
-    # original tangent struct layout). Force it off so the cache stays valid across calls.
+    config = adtype.config
     return Mooncake.Config(;
         debug_mode=config.debug_mode,
         silence_debug_messages=config.silence_debug_messages,
         friendly_tangents=false,
     )
 end
 
-# LogDensityAt is the function being differentiated through, not a quantity being
-# differentiated with respect to. Declaring NoTangent here tells Mooncake to treat it as
-# a constant, which is correct and avoids unnecessary tangent allocation.
+# LogDensityAt is a constant w.r.t. differentiation; NoTangent avoids tangent allocation.
 Mooncake.tangent_type(::Type{<:DynamicPPL.LogDensityAt}) = NoTangent
 
 function DynamicPPL._prepare_gradient(
@@ -69,11 +61,7 @@ function DynamicPPL._prepare_gradient(
     accs::DynamicPPL.AccumulatorTuple,
 )
     f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
-    return (;
-        cache=prepare_derivative_cache(f, x; config=_cache_config(adtype)),
-        dx=similar(x),
-        grad=similar(x),
-    )
+    return prepare_derivative_cache(f, x; config=_cache_config(adtype))
 end
 
 function DynamicPPL._value_and_gradient(
@@ -102,32 +90,7 @@ function DynamicPPL._value_and_gradient(
     accs::DynamicPPL.AccumulatorTuple,
 )
     f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
-    dx = prep.dx
-    grad = prep.grad
-
-    if isempty(grad)
-        # Zero-dimensional parameter vector: evaluate primal only. Use a zero tangent so
-        # value_and_derivative!! returns the function value without computing any derivative.
-        fill!(dx, zero(eltype(dx)))
-        value = primal(
-            value_and_derivative!!(prep.cache, Dual(f, NoTangent()), Dual(params, dx))
-        )
-        return value, copy(grad)
-    end
-
-    # Standard column-by-column forward-mode sweep: set dx to each unit vector in turn,
-    # compute the directional derivative, and accumulate into grad.
-    # Each iteration resets dx[i] to zero after use, so dx is all-zeros at loop exit.
-    value = zero(eltype(grad))
-    @inbounds for i in eachindex(grad, dx)
-        dx[i] = oneunit(eltype(dx))
-        dual_value = value_and_derivative!!(
-            prep.cache, Dual(f, NoTangent()), Dual(params, dx)
-        )
-        value = primal(dual_value)
-        grad[i] = tangent(dual_value)
-        dx[i] = zero(eltype(dx))
-    end
+    value, grad = value_and_gradient!!(prep, f, params)
     return value, copy(grad)
 end
 
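The deleted buffer bookkeeping and unit-vector sweep drove value_and_derivative!! by hand; the replacement hands that loop to Mooncake and keeps only the cache. A minimal sketch of the cache-reuse pattern, using Mooncake's documented reverse-mode pair prepare_gradient_cache / value_and_gradient!! (the forward-mode call in the diff mirrors it); the toy function f is an assumption, as is the config keyword on prepare_gradient_cache, which the diff only shows on prepare_derivative_cache:

using Mooncake: Mooncake, prepare_gradient_cache, value_and_gradient!!

f(x) = sum(abs2, x)  # toy objective standing in for LogDensityAt
x = randn(4)

# Build the cache once, with friendly_tangents forced off so the cached tangent
# layout stays valid across calls (the rationale encoded in _cache_config).
cache = prepare_gradient_cache(f, x; config=Mooncake.Config(; friendly_tangents=false))

# Reuse the cache across repeated evaluations, with no per-call buffer setup.
value, grads = value_and_gradient!!(cache, f, x)
# grads pairs with the arguments: grads[1] is the tangent of f itself
# (NoTangent when declared constant), grads[2] is the gradient w.r.t. x.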

src/logdensityfunction.jl

Lines changed: 1 addition & 8 deletions
@@ -178,10 +178,7 @@ struct LogDensityFunction{
     L<:AbstractTransformStrategy,
     F,
     VNT<:VarNamedTuple,
-    # ADP is intentionally unconstrained: most backends store a DI.GradientPrep, but
-    # backends that override _prepare_gradient (e.g. AutoMooncakeForward) may store any
-    # prep object (e.g. a NamedTuple with cache + gradient buffers).
-    ADP,
+    ADP, # unconstrained: backends may store any prep object via _prepare_gradient
     # type of the vector passed to logdensity functions
     X<:AbstractVector,
     AC<:AccumulatorTuple,
@@ -541,10 +538,6 @@ By default, this function returns `false`, i.e. the constant approach will be us
 # closure (see link in the docstring).
 _use_closure(::ADTypes.AutoForwardDiff) = false
 _use_closure(::ADTypes.AutoMooncake) = false
-# AutoMooncakeForward overrides _prepare_gradient/_value_and_gradient in the Mooncake
-# extension and bypasses DI entirely, so this value is never reached when Mooncake is
-# loaded. It is a defensive fallback for the (unlikely) case where AutoMooncakeForward is
-# used without the extension.
 _use_closure(::ADTypes.AutoMooncakeForward) = false
 # For ReverseDiff, with the compiled tape, you _must_ use a closure because otherwise with
 # DI.Constant arguments the tape will always be recompiled upon each call to
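
The pruned comment described dispatch plumbing rather than behaviour. For illustration, the _use_closure convention reduces to one method per backend; AutoMyBackend below is hypothetical, and only the false default and the compiled-ReverseDiff rationale come from the hunk above:

using ADTypes

struct AutoMyBackend <: ADTypes.AbstractADType end  # made-up backend

# Default per the docstring above: pass arguments as DI.Constant.
_use_closure(::ADTypes.AbstractADType) = false
# A backend that, like compiled ReverseDiff, must capture its arguments in a
# closure opts in with a single method.
_use_closure(::AutoMyBackend) = true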

test/logdensityfunction.jl

Lines changed: 6 additions & 28 deletions
@@ -177,12 +177,12 @@ end
 struct ErrorAccumulatorException <: Exception end
 struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end
 DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR
-DynamicPPL.accumulate_assume!!(
-    ::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any
-) = throw(ErrorAccumulatorException())
-DynamicPPL.accumulate_observe!!(
-    ::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any
-) = throw(ErrorAccumulatorException())
+DynamicPPL.accumulate_assume!!(::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any) = throw(
+    ErrorAccumulatorException()
+)
+DynamicPPL.accumulate_observe!!(::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any) = throw(
+    ErrorAccumulatorException()
+)
 DynamicPPL.reset(ea::ErrorAccumulator) = ea
 Base.copy(ea::ErrorAccumulator) = ea
 # Construct an LDF
@@ -547,28 +547,6 @@ end
         @test array_model_logp_and_grad[2] ≈ array_model_reference[2]
     end
 end
-
-@testset "Mooncake friendly_tangents" begin
-    @model function f()
-        x ~ Normal()
-        return y ~ Normal(x)
-    end
-
-    params = randn(2)
-    ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
-        LogDensityFunction(f(); adtype=ref_adtype), params
-    )
-
-    for adtype in (
-        AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)),
-        AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)),
-    )
-        ldf = LogDensityFunction(f(); adtype)
-        logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
-        @test logp ≈ ref_logp
-        @test grad ≈ ref_grad
-    end
-end
 end
 
 end
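
With the @testset deleted, friendly_tangents no longer has dedicated coverage. A sketch of exercising it by hand, rebuilt from the removed test; the reference comparison against ref_adtype is omitted, and note that _cache_config now forces friendly_tangents off inside the cache regardless of the user's setting:

using DynamicPPL, Distributions, LogDensityProblems, Mooncake
using ADTypes: AutoMooncake, AutoMooncakeForward

@model function f()
    x ~ Normal()
    return y ~ Normal(x)
end

params = randn(2)
for adtype in (
    AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)),
    AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)),
)
    ldf = LogDensityFunction(f(); adtype)
    logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
end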
