Skip to content

Commit 99de26f

Browse files
yebai authored and claude committed
Simplify _cache_config, drop friendly_tangents test
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent f9d58e8 commit 99de26f

2 files changed

Lines changed: 13 additions & 87 deletions

File tree

ext/DynamicPPLMooncakeExt.jl

Lines changed: 7 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,9 @@ module DynamicPPLMooncakeExt
33
using DynamicPPL: DynamicPPL, is_transformed
44
using Mooncake:
55
Mooncake,
6-
Dual,
76
NoTangent,
8-
primal,
97
prepare_derivative_cache,
108
prepare_gradient_cache,
11-
tangent,
12-
value_and_derivative!!,
139
value_and_gradient!!
1410

1511
# These are purely optimisations (although quite significant ones sometimes, especially for
@@ -27,23 +23,19 @@ using ADTypes: AutoMooncake, AutoMooncakeForward
2723
using Distributions: Normal, InverseGamma, Beta
2824
using PrecompileTools: @setup_workload, @compile_workload
2925

30-
_config(::Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}}) = Mooncake.Config()
31-
_config(adtype::Union{AutoMooncake,AutoMooncakeForward}) = adtype.config
26+
function _cache_config(::Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}})
27+
Mooncake.Config(; friendly_tangents=false)
28+
end
3229
function _cache_config(adtype::Union{AutoMooncake,AutoMooncakeForward})
33-
config = _config(adtype)
34-
# `friendly_tangents=true` rewrites tangent types into named structs at tape-build time,
35-
# which is incompatible with a reusable cache (the cached tape would be tied to the
36-
# original tangent struct layout). Force it off so the cache stays valid across calls.
30+
config = adtype.config
3731
return Mooncake.Config(;
3832
debug_mode=config.debug_mode,
3933
silence_debug_messages=config.silence_debug_messages,
4034
friendly_tangents=false,
4135
)
4236
end
4337

44-
# LogDensityAt is the function being differentiated through, not a quantity being
45-
# differentiated with respect to. Declaring NoTangent here tells Mooncake to treat it as
46-
# a constant, which is correct and avoids unnecessary tangent allocation.
38+
# LogDensityAt is a constant w.r.t. differentiation; NoTangent avoids tangent allocation.
4739
Mooncake.tangent_type(::Type{<:DynamicPPL.LogDensityAt}) = NoTangent
4840

4941
function DynamicPPL._prepare_gradient(
@@ -69,15 +61,11 @@ function DynamicPPL._prepare_gradient(
6961
accs::DynamicPPL.AccumulatorTuple,
7062
)
7163
f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
72-
return (;
73-
cache=prepare_derivative_cache(f, x; config=_cache_config(adtype)),
74-
dx=similar(x),
75-
grad=similar(x),
76-
)
64+
return prepare_derivative_cache(f, x; config=_cache_config(adtype))
7765
end
7866

7967
function DynamicPPL._value_and_gradient(
80-
::AutoMooncake,
68+
::Union{AutoMooncake,AutoMooncakeForward},
8169
prep,
8270
params::AbstractVector{<:Real},
8371
model::DynamicPPL.Model,
@@ -91,46 +79,6 @@ function DynamicPPL._value_and_gradient(
9179
return value, copy(grad)
9280
end
9381

94-
function DynamicPPL._value_and_gradient(
95-
::AutoMooncakeForward,
96-
prep,
97-
params::AbstractVector{<:Real},
98-
model::DynamicPPL.Model,
99-
getlogdensity::Any,
100-
varname_ranges::DynamicPPL.VarNamedTuple,
101-
transform_strategy::DynamicPPL.AbstractTransformStrategy,
102-
accs::DynamicPPL.AccumulatorTuple,
103-
)
104-
f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
105-
dx = prep.dx
106-
grad = prep.grad
107-
108-
if isempty(grad)
109-
# Zero-dimensional parameter vector: evaluate primal only. Use a zero tangent so
110-
# value_and_derivative!! returns the function value without computing any derivative.
111-
fill!(dx, zero(eltype(dx)))
112-
value = primal(
113-
value_and_derivative!!(prep.cache, Dual(f, NoTangent()), Dual(params, dx))
114-
)
115-
return value, copy(grad)
116-
end
117-
118-
# Standard column-by-column forward-mode sweep: set dx to each unit vector in turn,
119-
# compute the directional derivative, and accumulate into grad.
120-
# Each iteration resets dx[i] to zero after use, so dx is all-zeros at loop exit.
121-
value = zero(eltype(grad))
122-
@inbounds for i in eachindex(grad, dx)
123-
dx[i] = oneunit(eltype(dx))
124-
dual_value = value_and_derivative!!(
125-
prep.cache, Dual(f, NoTangent()), Dual(params, dx)
126-
)
127-
value = primal(dual_value)
128-
grad[i] = tangent(dual_value)
129-
dx[i] = zero(eltype(dx))
130-
end
131-
return value, copy(grad)
132-
end
133-
13482
@setup_workload begin
13583
@compile_workload begin
13684
for adtype in (AutoMooncake(), AutoMooncakeForward())

test/logdensityfunction.jl

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,12 @@ end
177177
struct ErrorAccumulatorException <: Exception end
178178
struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end
179179
DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR
180-
DynamicPPL.accumulate_assume!!(
181-
::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any
182-
) = throw(ErrorAccumulatorException())
183-
DynamicPPL.accumulate_observe!!(
184-
::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any
185-
) = throw(ErrorAccumulatorException())
180+
DynamicPPL.accumulate_assume!!(::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any) = throw(
181+
ErrorAccumulatorException()
182+
)
183+
DynamicPPL.accumulate_observe!!(::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any) = throw(
184+
ErrorAccumulatorException()
185+
)
186186
DynamicPPL.reset(ea::ErrorAccumulator) = ea
187187
Base.copy(ea::ErrorAccumulator) = ea
188188
# Construct an LDF
@@ -547,28 +547,6 @@ end
547547
@test array_model_logp_and_grad[2] array_model_reference[2]
548548
end
549549
end
550-
551-
@testset "Mooncake friendly_tangents" begin
552-
@model function f()
553-
x ~ Normal()
554-
return y ~ Normal(x)
555-
end
556-
557-
params = randn(2)
558-
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
559-
LogDensityFunction(f(); adtype=ref_adtype), params
560-
)
561-
562-
for adtype in (
563-
AutoMooncake(; config=Mooncake.Config(; friendly_tangents=true)),
564-
AutoMooncakeForward(; config=Mooncake.Config(; friendly_tangents=true)),
565-
)
566-
ldf = LogDensityFunction(f(); adtype)
567-
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
568-
@test logp ref_logp
569-
@test grad ref_grad
570-
end
571-
end
572550
end
573551

574552
end

0 commit comments

Comments (0)