@@ -3,13 +3,9 @@ module DynamicPPLMooncakeExt
 using DynamicPPL: DynamicPPL, is_transformed
 using Mooncake:
     Mooncake,
-    Dual,
     NoTangent,
-    primal,
     prepare_derivative_cache,
     prepare_gradient_cache,
-    tangent,
-    value_and_derivative!!,
     value_and_gradient!!

 # These are purely optimisations (although quite significant ones sometimes, especially for
@@ -27,23 +23,19 @@ using ADTypes: AutoMooncake, AutoMooncakeForward
 using Distributions: Normal, InverseGamma, Beta
 using PrecompileTools: @setup_workload, @compile_workload

-_config(::Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}}) = Mooncake.Config()
-_config(adtype::Union{AutoMooncake,AutoMooncakeForward}) = adtype.config
+function _cache_config(::Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}})
+    return Mooncake.Config(; friendly_tangents=false)
+end
 function _cache_config(adtype::Union{AutoMooncake,AutoMooncakeForward})
-    config = _config(adtype)
-    # `friendly_tangents=true` rewrites tangent types into named structs at tape-build time,
-    # which is incompatible with a reusable cache (the cached tape would be tied to the
-    # original tangent struct layout). Force it off so the cache stays valid across calls.
+    config = adtype.config
     return Mooncake.Config(;
         debug_mode=config.debug_mode,
         silence_debug_messages=config.silence_debug_messages,
         friendly_tangents=false,
     )
 end

-# LogDensityAt is the function being differentiated through, not a quantity being
-# differentiated with respect to. Declaring NoTangent here tells Mooncake to treat it as
-# a constant, which is correct and avoids unnecessary tangent allocation.
+# LogDensityAt is a constant w.r.t. differentiation; NoTangent avoids tangent allocation.
 Mooncake.tangent_type(::Type{<:DynamicPPL.LogDensityAt}) = NoTangent

 function DynamicPPL._prepare_gradient(
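A quick sketch of what the new `_cache_config` guarantees, as a hypothetical check inside
this extension module (not part of the diff; it assumes the `using` lines above are in
scope, that `Mooncake.Config`'s keywords default when omitted, and that `Config` exposes
its keyword arguments as fields):

    # Hypothetical: user supplies a config with friendly tangents turned on.
    adtype = AutoMooncake(; config=Mooncake.Config(; debug_mode=true, friendly_tangents=true))
    cfg = _cache_config(adtype)
    cfg.debug_mode         # true  -- carried over from the user's config
    cfg.friendly_tangents  # false -- always forced off so the cached tape layout stays valid
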
@@ -69,11 +61,7 @@ function DynamicPPL._prepare_gradient(
     accs::DynamicPPL.AccumulatorTuple,
 )
     f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
-    return (;
-        cache=prepare_derivative_cache(f, x; config=_cache_config(adtype)),
-        dx=similar(x),
-        grad=similar(x),
-    )
+    return prepare_derivative_cache(f, x; config=_cache_config(adtype))
 end

 function DynamicPPL._value_and_gradient(
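The prep object for the forward-mode path is now just the Mooncake cache itself; the
`dx`/`grad` scratch buffers disappear along with the manual sweep in the next hunk. For
reference, a toy single-direction sketch of the underlying API (toy `f`; the call shapes
are taken from the removed code below, with the `config` keyword left at its default):

    using Mooncake: Dual, NoTangent, prepare_derivative_cache, primal, tangent, value_and_derivative!!

    f(x) = sum(abs2, x)
    x = randn(3)
    cache = prepare_derivative_cache(f, x)  # built once, reusable across directions
    dx = [1.0, 0.0, 0.0]                    # seed: directional derivative along e_1
    dual = value_and_derivative!!(cache, Dual(f, NoTangent()), Dual(x, dx))
    primal(dual)   # f(x)
    tangent(dual)  # directional derivative along dx, here 2 * x[1]
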
@@ -102,32 +90,7 @@ function DynamicPPL._value_and_gradient(
     accs::DynamicPPL.AccumulatorTuple,
 )
     f = LogDensityAt(model, getlogdensity, varname_ranges, transform_strategy, accs)
-    dx = prep.dx
-    grad = prep.grad
-
-    if isempty(grad)
-        # Zero-dimensional parameter vector: evaluate primal only. Use a zero tangent so
-        # value_and_derivative!! returns the function value without computing any derivative.
-        fill!(dx, zero(eltype(dx)))
-        value = primal(
-            value_and_derivative!!(prep.cache, Dual(f, NoTangent()), Dual(params, dx))
-        )
-        return value, copy(grad)
-    end
-
-    # Standard column-by-column forward-mode sweep: set dx to each unit vector in turn,
-    # compute the directional derivative, and accumulate into grad.
-    # Each iteration resets dx[i] to zero after use, so dx is all-zeros at loop exit.
-    value = zero(eltype(grad))
-    @inbounds for i in eachindex(grad, dx)
-        dx[i] = oneunit(eltype(dx))
-        dual_value = value_and_derivative!!(
-            prep.cache, Dual(f, NoTangent()), Dual(params, dx)
-        )
-        value = primal(dual_value)
-        grad[i] = tangent(dual_value)
-        dx[i] = zero(eltype(dx))
-    end
+    value, grad = value_and_gradient!!(prep, f, params)
    return value, copy(grad)
 end

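The column-by-column seeding loop now lives inside Mooncake: `value_and_gradient!!` takes
the prepared cache directly and returns the value and gradient in one call. A toy sketch
of the same call shape (toy `g`; shown with the reverse-mode `prepare_gradient_cache`
pairing, and the exact container holding the gradient depends on the Mooncake version in
use):

    using Mooncake: prepare_gradient_cache, value_and_gradient!!

    g(x) = sum(abs2, x)
    x = randn(3)
    cache = prepare_gradient_cache(g, x)  # one-time cache build, reused across calls
    # Returns the value plus the gradient data in a single pass.
    value, grad = value_and_gradient!!(cache, g, x)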