@@ -3,13 +3,9 @@ module DynamicPPLMooncakeExt
33using DynamicPPL: DynamicPPL, is_transformed
44using Mooncake:
55 Mooncake,
6- Dual,
76 NoTangent,
8- primal,
97 prepare_derivative_cache,
108 prepare_gradient_cache,
11- tangent,
12- value_and_derivative!!,
139 value_and_gradient!!
1410
1511# These are purely optimisations (although quite significant ones sometimes, especially for
@@ -27,23 +23,19 @@ using ADTypes: AutoMooncake, AutoMooncakeForward
2723using Distributions: Normal, InverseGamma, Beta
2824using PrecompileTools: @setup_workload , @compile_workload
2925
30- _config (:: Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}} ) = Mooncake. Config ()
31- _config (adtype:: Union{AutoMooncake,AutoMooncakeForward} ) = adtype. config
26+ function _cache_config (:: Union{AutoMooncake{Nothing},AutoMooncakeForward{Nothing}} )
27+ Mooncake. Config (; friendly_tangents= false )
28+ end
3229function _cache_config (adtype:: Union{AutoMooncake,AutoMooncakeForward} )
33- config = _config (adtype)
34- # `friendly_tangents=true` rewrites tangent types into named structs at tape-build time,
35- # which is incompatible with a reusable cache (the cached tape would be tied to the
36- # original tangent struct layout). Force it off so the cache stays valid across calls.
30+ config = adtype. config
3731 return Mooncake. Config (;
3832 debug_mode= config. debug_mode,
3933 silence_debug_messages= config. silence_debug_messages,
4034 friendly_tangents= false ,
4135 )
4236end
4337
44- # LogDensityAt is the function being differentiated through, not a quantity being
45- # differentiated with respect to. Declaring NoTangent here tells Mooncake to treat it as
46- # a constant, which is correct and avoids unnecessary tangent allocation.
38+ # LogDensityAt is a constant w.r.t. differentiation; NoTangent avoids tangent allocation.
4739Mooncake. tangent_type (:: Type{<:DynamicPPL.LogDensityAt} ) = NoTangent
4840
4941function DynamicPPL. _prepare_gradient (
@@ -69,15 +61,11 @@ function DynamicPPL._prepare_gradient(
6961 accs:: DynamicPPL.AccumulatorTuple ,
7062)
7163 f = LogDensityAt (model, getlogdensity, varname_ranges, transform_strategy, accs)
72- return (;
73- cache= prepare_derivative_cache (f, x; config= _cache_config (adtype)),
74- dx= similar (x),
75- grad= similar (x),
76- )
64+ return prepare_derivative_cache (f, x; config= _cache_config (adtype))
7765end
7866
7967function DynamicPPL. _value_and_gradient (
80- :: AutoMooncake ,
68+ :: Union{ AutoMooncake,AutoMooncakeForward} ,
8169 prep,
8270 params:: AbstractVector{<:Real} ,
8371 model:: DynamicPPL.Model ,
@@ -91,46 +79,6 @@ function DynamicPPL._value_and_gradient(
9179 return value, copy (grad)
9280end
9381
94- function DynamicPPL. _value_and_gradient (
95- :: AutoMooncakeForward ,
96- prep,
97- params:: AbstractVector{<:Real} ,
98- model:: DynamicPPL.Model ,
99- getlogdensity:: Any ,
100- varname_ranges:: DynamicPPL.VarNamedTuple ,
101- transform_strategy:: DynamicPPL.AbstractTransformStrategy ,
102- accs:: DynamicPPL.AccumulatorTuple ,
103- )
104- f = LogDensityAt (model, getlogdensity, varname_ranges, transform_strategy, accs)
105- dx = prep. dx
106- grad = prep. grad
107-
108- if isempty (grad)
109- # Zero-dimensional parameter vector: evaluate primal only. Use a zero tangent so
110- # value_and_derivative!! returns the function value without computing any derivative.
111- fill! (dx, zero (eltype (dx)))
112- value = primal (
113- value_and_derivative!! (prep. cache, Dual (f, NoTangent ()), Dual (params, dx))
114- )
115- return value, copy (grad)
116- end
117-
118- # Standard column-by-column forward-mode sweep: set dx to each unit vector in turn,
119- # compute the directional derivative, and accumulate into grad.
120- # Each iteration resets dx[i] to zero after use, so dx is all-zeros at loop exit.
121- value = zero (eltype (grad))
122- @inbounds for i in eachindex (grad, dx)
123- dx[i] = oneunit (eltype (dx))
124- dual_value = value_and_derivative!! (
125- prep. cache, Dual (f, NoTangent ()), Dual (params, dx)
126- )
127- value = primal (dual_value)
128- grad[i] = tangent (dual_value)
129- dx[i] = zero (eltype (dx))
130- end
131- return value, copy (grad)
132- end
133-
13482@setup_workload begin
13583 @compile_workload begin
13684 for adtype in (AutoMooncake (), AutoMooncakeForward ())
0 commit comments