Skip to content

Commit a695690

Browse files
Generalise Container Types (#151)
* Add handling for PD errors * Generalise types for containers and add ReferenceTrajectory type * Formatter suggestions Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 6c283f7 commit a695690

12 files changed

Lines changed: 532 additions & 247 deletions

File tree

GeneralisedFilters/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1616
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
1717
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1818
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
19-
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2019
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
2120
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2221
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
@@ -56,7 +55,6 @@ LogDensityProblemsAD = "1.13.1"
5655
LogExpFunctions = "0.3"
5756
MCMCChains = "7.7.0"
5857
Mooncake = "0.5.13"
59-
OffsetArrays = "1.14.1"
6058
PDMats = "0.11.35"
6159
SSMProblems = "0.6"
6260
StaticArrays = "1.9.17"

GeneralisedFilters/ext/CUDAExt.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using GeneralisedFilters:
1919
PostUpdateCallback,
2020
num_particles
2121

22-
using GeneralisedFilters.OffsetArrays: OffsetVector
22+
using GeneralisedFilters: ReferenceTrajectory
2323

2424
using AcceleratedKernels: searchsortedfirst, foreachindex
2525
using CUDA
@@ -226,29 +226,37 @@ function expand!(tree::ParallelParticleTree{T}) where {T}
226226
return tree
227227
end
228228

229-
# Get ancestry of all particles
229+
# Get ancestry of all particles. The parallel tree stores initial and subsequent states
230+
# in a single `states` buffer (homogeneous type), so the returned trajectories have
231+
# T0 == T.
230232
function get_ancestry(tree::ParallelParticleTree{ST}, T::Integer) where {ST}
231-
paths = OffsetVector(Vector{ST}(undef, T + 1), -1)
233+
buf = Vector{Vector{ST}}(undef, T + 1)
232234
parents = copy(tree.leaves)
233-
for t in T:-1:1
234-
paths[t] = tree.states[parents]
235+
for t in (T + 1):-1:2
236+
buf[t] = Vector(tree.states[parents])
235237
gather!(parents, tree.parents, parents)
236238
end
237-
return paths
239+
buf[1] = Vector(tree.states[parents])
240+
# Each leaf's trajectory: x0 = buf[1][k], xs = [buf[2][k], ..., buf[T+1][k]]
241+
return [
242+
ReferenceTrajectory(buf[1][k], [buf[t + 1][k] for t in 1:T]) for
243+
k in eachindex(tree.leaves)
244+
]
238245
end
239246

240247
# Get ancestry of a single particle
241248
function get_ancestry(
242249
container::ParallelParticleTree{ST}, i::Integer, T::Integer
243250
) where {ST}
244-
path = OffsetVector(Vector{ST}(undef, T + 1), -1)
251+
xs = Vector{ST}(undef, T)
245252
CUDA.@allowscalar begin
246253
ancestor_index = container.leaves[i]
247-
for t in T:-1:0
248-
path[t] = container.states[ancestor_index]
254+
for t in T:-1:1
255+
xs[t] = container.states[ancestor_index]
249256
ancestor_index = container.parents[ancestor_index]
250257
end
251-
return path
258+
x0 = container.states[ancestor_index]
259+
return ReferenceTrajectory(x0, xs)
252260
end
253261
end
254262

GeneralisedFilters/ext/TuringExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ using DynamicPPL: DynamicPPL
2020
using LinearAlgebra: PosDefException
2121
using LogDensityProblems: LogDensityProblems
2222
using MCMCChains: MCMCChains
23-
using OffsetArrays
2423
using Random: AbstractRNG
2524
using SSMProblems
2625
using Turing: Turing

GeneralisedFilters/src/GeneralisedFilters.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@ using ADTypes: ADTypes
55
import Distributions: MvNormal, params
66
import Random: AbstractRNG, default_rng, rand
77
import SSMProblems: prior, dyn, obs
8-
using OffsetArrays
98
using SSMProblems
109
using StatsBase
1110
using DifferentiationInterface
1211

1312
const DI = DifferentiationInterface
1413

1514
# Filtering utilities
16-
include("callbacks.jl")
1715
include("containers.jl")
16+
include("callbacks.jl")
1817
include("resamplers.jl")
1918

2019
## FILTERING BASE ##########################################################################

GeneralisedFilters/src/algorithms/csmc.jl

Lines changed: 75 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ ConditionalSMC(pf) = ConditionalSMC(pf, NoRefreshment())
7272
7373
State of a conditional SMC sampler, containing the current reference trajectory.
7474
75-
The trajectory is an `OffsetVector` indexed from 0 (matching the prior at time 0).
76-
For RBPF, the trajectory contains `RBState` objects (outer state + inner filtering
75+
The trajectory is a [`ReferenceTrajectory`](@ref) indexed from 0 (matching the prior at
76+
time 0). For RBPF, the trajectory contains `RBState` objects (outer state + inner filtering
7777
distribution).
7878
"""
7979
struct CSMCState{TT}
@@ -100,10 +100,11 @@ end
100100
# CSMCState stores full trajectories (RBState for RBPF). The filter/initialise/move
101101
# functions expect ref_state to contain only outer states for RBPF. _make_ref_state
102102
# handles this conversion.
103-
# TODO: is this actually needed?
104103
_make_ref_state(::Nothing) = nothing
105-
_make_ref_state(traj) = traj
106-
_make_ref_state(traj::OffsetVector{<:RBState}) = map(s -> s.x, traj)
104+
_make_ref_state(traj::ReferenceTrajectory) = traj
105+
function _make_ref_state(traj::ReferenceTrajectory{<:RBState})
106+
return map(s -> s.x, traj)
107+
end
107108

108109
## TRAJECTORY SAMPLING #####################################################################
109110

@@ -119,16 +120,34 @@ function _sample_trajectory(
119120
rng::AbstractRNG, tree::ParticleTree, state::ParticleDistribution
120121
)
121122
ws = get_weights(state)
122-
path = rand(rng, tree, ws)
123-
return OffsetVector(path, -1)
123+
return rand(rng, tree, ws)
124124
end
125125

126-
## PARTICLE TREE HELPERS ###################################################################
126+
## PARTICLE TREE / CONTAINER HELPERS ######################################################
127+
128+
# Capacity heuristic from Jacob, Murray & Rubenthaler (2015)
129+
_tree_capacity(N::Integer) = max(N, floor(Int64, N * log(N)))
130+
131+
# Construct a ParticleTree using both the time-0 and time-1 particle distributions so
132+
# that the subsequent-state type `T` is inferred from the time-1 states (which may
133+
# differ from the type of the initial states in Rao-Blackwellised settings).
134+
function _init_tree(init_state::ParticleDistribution, state::ParticleDistribution)
135+
initial_states = map(p -> p.state, init_state.particles)
136+
states_t1 = map(p -> p.state, state.particles)
137+
ancestors_t1 = map(p -> p.ancestor, state.particles)
138+
return ParticleTree(
139+
initial_states, states_t1, ancestors_t1, _tree_capacity(length(initial_states))
140+
)
141+
end
127142

128-
function _init_tree(state::ParticleDistribution)
129-
states = map(p -> p.state, state.particles)
130-
N = length(states)
131-
return ParticleTree(states, max(N, floor(Int64, N * log(N))))
143+
function _init_container(init_state::ParticleDistribution, state::ParticleDistribution)
144+
initial_states = map(p -> p.state, init_state.particles)
145+
return DenseParticleContainer(
146+
initial_states,
147+
map(p -> p.state, state.particles),
148+
Float64.(log_weights(state)),
149+
map(p -> p.ancestor, state.particles),
150+
)
132151
end
133152

134153
function _update_tree!(tree::ParticleTree, state::ParticleDistribution)
@@ -140,6 +159,17 @@ function _update_tree!(tree::ParticleTree, state::ParticleDistribution)
140159
return tree
141160
end
142161

162+
function _update_container!(c::DenseParticleContainer, state::ParticleDistribution)
163+
particles = state.particles
164+
push!(
165+
c,
166+
map(p -> p.state, particles),
167+
Float64.(log_weights(state)),
168+
map(p -> p.ancestor, particles),
169+
)
170+
return c
171+
end
172+
143173
## BACKWARD PREDICTIVE LIKELIHOODS #########################################################
144174

145175
# Default: no backward likelihoods needed (regular PF, or first iteration)
@@ -217,11 +247,9 @@ function _csmc_sample(
217247
K = length(observations)
218248
ref_state = _make_ref_state(ref_traj)
219249

220-
state = initialise(rng, prior(model), pf; ref_state)
221-
tree = _init_tree(state)
222-
223-
state, ll = step(rng, model, pf, 1, state, observations[1]; ref_state)
224-
_update_tree!(tree, state)
250+
init_state = initialise(rng, prior(model), pf; ref_state)
251+
state, ll = step(rng, model, pf, 1, init_state, observations[1]; ref_state)
252+
tree = _init_tree(init_state, state)
225253

226254
for t in 2:K
227255
state, ll_inc = step(rng, model, pf, t, state, observations[t]; ref_state)
@@ -247,19 +275,17 @@ function _csmc_sample(
247275
# Backward predictive likelihoods (only non-nothing for RBPF)
248276
back_liks = _compute_backward_likelihoods(rng, model, pf, observations, ref_state)
249277

250-
state = initialise(rng, prior(model), pf; ref_state)
251-
tree = _init_tree(state)
278+
init_state = initialise(rng, prior(model), pf; ref_state)
252279

253-
ll = zero(eltype(log_weights(state)))
254-
for t in 1:K
255-
# Ancestor sampling for the reference particle
280+
# Perform one CSMC-AS step on the current state
281+
function _csmc_as_step(state, t)
282+
ancestor_idx = 0
256283
if !isnothing(ref_state)
257284
ref_as = _build_ancestor_ref(ref_state, back_liks, t)
258285
as_weights = map(state.particles) do particle
259286
ancestor_weight(particle, dyn(model), pf, t, ref_as)
260287
end
261288
ancestor_idx = StatsBase.sample(rng, StatsBase.Weights(softmax(as_weights)))
262-
263289
end
264290

265291
state = resample(rng, resampler(pf), state; ref_state)
@@ -270,9 +296,15 @@ function _csmc_sample(
270296
)
271297
end
272298

273-
state, ll_inc = move(rng, model, pf, t, state, observations[t]; ref_state)
274-
ll += ll_inc
299+
return move(rng, model, pf, t, state, observations[t]; ref_state)
300+
end
301+
302+
state, ll = _csmc_as_step(init_state, 1)
303+
tree = _init_tree(init_state, state)
275304

305+
for t in 2:K
306+
state, ll_inc = _csmc_as_step(state, t)
307+
ll += ll_inc
276308
_update_tree!(tree, state)
277309
end
278310

@@ -366,55 +398,51 @@ function _csmc_sample(
366398
)
367399
pf = csmc.pf
368400
K = length(observations)
401+
N = num_particles(pf)
369402
ref_state = _make_ref_state(ref_traj)
370403

371-
# Forward filtering pass, storing particles at each timestep
404+
# Forward filtering pass: store full history in a DenseParticleContainer.
372405
init_state = initialise(rng, prior(model), pf; ref_state)
373-
init_particles = copy(init_state.particles)
374-
375406
state, ll = step(rng, model, pf, 1, init_state, observations[1]; ref_state)
376-
particle_history = Vector{typeof(state.particles)}(undef, K)
377-
particle_history[1] = copy(state.particles)
407+
container = _init_container(init_state, state)
378408

379409
for t in 2:K
380410
state, ll_inc = step(rng, model, pf, t, state, observations[t]; ref_state)
381411
ll += ll_inc
382-
particle_history[t] = copy(state.particles)
412+
_update_container!(container, state)
383413
end
384414

385415
# Backward simulation pass
386-
ws = get_weights(state)
387-
idx = StatsBase.sample(rng, StatsBase.Weights(ws))
388-
sampled_state = particle_history[K][idx].state
416+
idx = StatsBase.sample(rng, StatsBase.Weights(get_weights(state)))
417+
sampled_state = container.states[K][idx]
389418

390419
back_lik = _bs_init_back_lik(rng, model, pf, observations, K, sampled_state)
391420

392-
ST = typeof(sampled_state)
393-
trajectory = OffsetVector(Vector{ST}(undef, K + 1), -1)
394-
trajectory[K] = sampled_state
421+
xs = Vector{typeof(sampled_state)}(undef, K)
422+
xs[K] = sampled_state
395423

396424
for t in (K - 1):-1:1
397-
ref_next = _build_bs_ref(trajectory[t + 1], back_lik)
398-
backward_ws = map(particle_history[t]) do particle
399-
ancestor_weight(particle, dyn(model), pf, t + 1, ref_next)
425+
ref_next = _build_bs_ref(xs[t + 1], back_lik)
426+
backward_ws = map(1:N) do i
427+
ancestor_weight(Particle(container, t, i), dyn(model), pf, t + 1, ref_next)
400428
end
401429
idx = StatsBase.sample(rng, StatsBase.Weights(softmax(backward_ws)))
402-
trajectory[t] = particle_history[t][idx].state
430+
xs[t] = container.states[t][idx]
403431

404432
back_lik = _bs_step_back_lik(
405-
rng, model, pf, t, back_lik, observations, trajectory[t], trajectory[t + 1]
433+
rng, model, pf, t, back_lik, observations, xs[t], xs[t + 1]
406434
)
407435
end
408436

409-
# Time 0: backward step from t=1 to initial particles
410-
ref_at_1 = _build_bs_ref(trajectory[1], back_lik)
411-
backward_ws = map(init_particles) do particle
437+
# Time 0: backward step from t=1 to initial particles.
438+
ref_at_1 = _build_bs_ref(xs[1], back_lik)
439+
backward_ws = map(init_state.particles) do particle
412440
ancestor_weight(particle, dyn(model), pf, 1, ref_at_1)
413441
end
414442
idx = StatsBase.sample(rng, StatsBase.Weights(softmax(backward_ws)))
415-
trajectory[0] = init_particles[idx].state
443+
x0 = container.initial_states[idx]
416444

417-
return trajectory, ll
445+
return ReferenceTrajectory(x0, xs), ll
418446
end
419447

420448
## ABSTRACTMCMC INTERFACE ##################################################################

0 commit comments

Comments
 (0)