Skip to content

Commit 157e5f6

Browse files
committed
Add faster ParamsWithStats constructor
1 parent a4c1cf6 commit 157e5f6

4 files changed

Lines changed: 105 additions & 1 deletion

File tree

src/chains.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,27 @@ via `unflatten!!` plus re-evaluation. It is faster for two reasons:
120120
otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent
121121
MCMC iterations).
122122
2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`.
123+
124+
Furthermore, if the `LogDensityFunction` has all fixed transforms (i.e., was constructed
125+
with `fix_transforms=true`), and neither `include_log_probs` nor `include_colon_eq` is
126+
set, then model re-evaluation is skipped entirely and the raw parameter values are
127+
extracted directly from the parameter vector using the cached transforms.
123128
"""
124129
function ParamsWithStats(
125130
param_vector::AbstractVector,
126131
ldf::DynamicPPL.LogDensityFunction,
127132
stats::NamedTuple=NamedTuple();
128133
include_colon_eq::Bool=true,
129134
include_log_probs::Bool=true,
135+
)
136+
return pws_with_eval(param_vector, ldf, stats; include_colon_eq, include_log_probs)
137+
end
138+
function pws_with_eval(
139+
param_vector::AbstractVector,
140+
ldf::DynamicPPL.LogDensityFunction,
141+
stats::NamedTuple=NamedTuple();
142+
include_colon_eq::Bool=true,
143+
include_log_probs::Bool=true,
130144
)
131145
strategy = InitFromVector(param_vector, ldf)
132146
accs = if include_log_probs
@@ -158,6 +172,30 @@ function ParamsWithStats(
158172
return ParamsWithStats(params, stats)
159173
end
160174

175+
# Specialisation for when the LDF is known to have all fixed transforms
176+
function ParamsWithStats(
177+
param_vector::AbstractVector,
178+
ldf::LogDensityFunction{M,A,L,F,V,D,X,C,true},
179+
stats::NamedTuple=NamedTuple();
180+
include_colon_eq::Bool=true,
181+
include_log_probs::Bool=true,
182+
) where {M,A,L,F,V,D,X,C}
183+
if include_log_probs || include_colon_eq
184+
return pws_with_eval(param_vector, ldf, stats; include_colon_eq, include_log_probs)
185+
end
186+
# Fast path: extract raw values directly from the parameter vector using the fixed
187+
# transforms, without re-evaluating the model.
188+
params = VarNamedTuple()
189+
for (vn, rat) in pairs(ldf._varname_ranges)
190+
top_sym = AbstractPPL.getsym(vn)
191+
template = get(ldf._varname_ranges.data, top_sym, DynamicPPL.NoTemplate())
192+
raw_val = rat.transform.transform(param_vector[rat.range])
193+
params = DynamicPPL.templated_setindex!!(params, raw_val, vn, template)
194+
end
195+
params = densify!!(params)
196+
return ParamsWithStats(params, stats)
197+
end
198+
161199
function Base.show(io::IO, ::MIME"text/plain", pws::ParamsWithStats)
162200
printstyled(io, "ParamsWithStats"; bold=true)
163201
print(io, "\n ├─ ")

src/logdensityfunction.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ struct LogDensityFunction{
189189
# type of the vector passed to logdensity functions
190190
X<:AbstractVector,
191191
AC<:AccumulatorTuple,
192+
# whether all transforms are FixedTransforms, enabling fast parameter extraction
193+
AllFixed,
192194
}
193195
model::M
194196
adtype::AD
@@ -231,6 +233,12 @@ struct LogDensityFunction{
231233
end
232234
ranges_and_transforms = get_rangeandtransforms(vnt)
233235

236+
# Determine whether all transforms are fixed. This enables fast parameter
237+
# extraction in ParamsWithStats without model re-evaluation.
238+
all_fixed = all(
239+
rat -> rat.transform isa FixedTransform, values(ranges_and_transforms)
240+
)
241+
234242
# Get vectorised parameters. Note that `internal_values_as_vector` just concatenates
235243
# all the vectors inside in iteration order of the VNT's keys. *In principle*, the
236244
# result of that should always be consistent with the ranges extracted above via
@@ -268,6 +276,7 @@ struct LogDensityFunction{
268276
typeof(prep),
269277
typeof(x),
270278
typeof(accs),
279+
all_fixed,
271280
}(
272281
model,
273282
adtype,

test/chains.jl

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ end
7373
@testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS
7474
@testset "$transform_strategy" for transform_strategy in (UnlinkAll(), LinkAll())
7575
# Get the ParamsWithStats using LogDensityFunction
76-
ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, transform_strategy)
76+
ldf = LogDensityFunction(m, getlogjoint, transform_strategy)
7777
param_vector = rand(ldf)
7878
# This will give us a VNT of values.params`.
7979
actual_vnt = ParamsWithStats(param_vector, ldf).params
@@ -93,6 +93,54 @@ end
9393
end
9494
end
9595

96+
@testset "ParamsWithStats from LogDensityFunction with fixed transforms" begin
97+
# Note: can't use ALL_MODELS here because that contains a model with dynamic transforms,
98+
# which would yield incorrect results with fix_transforms.
99+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
100+
@testset "$transform_strategy" for transform_strategy in (UnlinkAll(), LinkAll())
101+
ldf_fixed = LogDensityFunction(
102+
m, getlogjoint_internal, transform_strategy; fix_transforms=true
103+
)
104+
ldf_dynamic = LogDensityFunction(m, getlogjoint_internal, transform_strategy)
105+
param_vector = rand(ldf_fixed)
106+
107+
# Fast path (no log probs, no colon eq): should match the model-evaluation path
108+
fast = ParamsWithStats(
109+
param_vector, ldf_fixed; include_log_probs=false, include_colon_eq=false
110+
)
111+
slow = ParamsWithStats(
112+
param_vector, ldf_dynamic; include_log_probs=false, include_colon_eq=false
113+
)
114+
@test fast == slow
115+
end
116+
end
117+
118+
@testset "check that model is actually not evaluated" begin
119+
should_error = false
120+
@model function prickly()
121+
x ~ Normal()
122+
return should_error && error("nope")
123+
end
124+
# need to construct LDF without erroring
125+
ldf = LogDensityFunction(
126+
prickly(), getlogjoint_internal, LinkAll(); fix_transforms=true
127+
)
128+
# now make the model error
129+
should_error = true
130+
@test_throws ErrorException prickly()()
131+
# check that ParamsWithStats doesn't error
132+
@test ParamsWithStats(
133+
[0.5], ldf; include_log_probs=false, include_colon_eq=false
134+
) isa Any
135+
# but it does if you set either of them to true
136+
for (ilp, ice) in ((true, false), (false, true), (true, true))
137+
@test_throws ErrorException ParamsWithStats(
138+
[0.5], ldf; include_log_probs=ilp, include_colon_eq=ice
139+
)
140+
end
141+
end
142+
end
143+
96144
@info "Completed $(@__FILE__) in $(now() - __now__)."
97145

98146
end # module

test/logdensityfunction.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,12 +549,21 @@ end
549549
end
550550

551551
@testset "LogDensityFunction: fix_transforms correctness" begin
552+
# Helper function to check whether the AllFixed type parameter on LogDensityFunction is
553+
# set correctly.
554+
function has_all_fixed_transforms(
555+
::LogDensityFunction{M,A,L,F,V,D,X,C,AF}
556+
) where {M,A,L,F,V,D,X,C,AF}
557+
return AF
558+
end
559+
552560
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
553561
@testset "$strategy" for strategy in (UnlinkAll(), LinkAll())
554562
ldf_dynamic = LogDensityFunction(m, getlogjoint_internal, strategy)
555563
ldf_fixed = LogDensityFunction(
556564
m, getlogjoint_internal, strategy; fix_transforms=true
557565
)
566+
@test has_all_fixed_transforms(ldf_fixed)
558567
# Check that the transform strategy does contain fixed transforms
559568
tfm_strategy = ldf_fixed.transform_strategy
560569
@test tfm_strategy isa WithTransforms

0 commit comments

Comments
 (0)