@@ -72,8 +72,8 @@ ConditionalSMC(pf) = ConditionalSMC(pf, NoRefreshment())
7272
7373State of a conditional SMC sampler, containing the current reference trajectory.
7474
75- The trajectory is an `OffsetVector` indexed from 0 (matching the prior at time 0).
76- For RBPF, the trajectory contains `RBState` objects (outer state + inner filtering
75+ The trajectory is a [`ReferenceTrajectory`](@ref) indexed from 0 (matching the prior at
76+ time 0). For RBPF, the trajectory contains `RBState` objects (outer state + inner filtering
7777distribution).
7878"""
7979struct CSMCState{TT}
@@ -100,10 +100,11 @@ end
100100# CSMCState stores full trajectories (RBState for RBPF). The filter/initialise/move
101101# functions expect ref_state to contain only outer states for RBPF. _make_ref_state
102102# handles this conversion.
103- # TODO : is this actually needed?
104103_make_ref_state (:: Nothing ) = nothing
105- _make_ref_state (traj) = traj
106- _make_ref_state (traj:: OffsetVector{<:RBState} ) = map (s -> s. x, traj)
104+ _make_ref_state (traj:: ReferenceTrajectory ) = traj
105+ function _make_ref_state (traj:: ReferenceTrajectory{<:RBState} )
106+ return map (s -> s. x, traj)
107+ end
107108
108109# # TRAJECTORY SAMPLING #####################################################################
109110
@@ -119,16 +120,34 @@ function _sample_trajectory(
119120 rng:: AbstractRNG , tree:: ParticleTree , state:: ParticleDistribution
120121)
121122 ws = get_weights (state)
122- path = rand (rng, tree, ws)
123- return OffsetVector (path, - 1 )
123+ return rand (rng, tree, ws)
124124end
125125
126- # # PARTICLE TREE HELPERS ###################################################################
126+ # # PARTICLE TREE / CONTAINER HELPERS ######################################################
127+
128+ # Capacity heuristic from Jacob, Murray & Rubenthaler (2015)
129+ _tree_capacity (N:: Integer ) = max (N, floor (Int64, N * log (N)))
130+
131+ # Construct a ParticleTree using both the time-0 and time-1 particle distributions so
132+ # that the subsequent-state type `T` is inferred from the time-1 states (which may
133+ # differ from the type of the initial states in Rao-Blackwellised settings).
134+ function _init_tree (init_state:: ParticleDistribution , state:: ParticleDistribution )
135+ initial_states = map (p -> p. state, init_state. particles)
136+ states_t1 = map (p -> p. state, state. particles)
137+ ancestors_t1 = map (p -> p. ancestor, state. particles)
138+ return ParticleTree (
139+ initial_states, states_t1, ancestors_t1, _tree_capacity (length (initial_states))
140+ )
141+ end
127142
128- function _init_tree (state:: ParticleDistribution )
129- states = map (p -> p. state, state. particles)
130- N = length (states)
131- return ParticleTree (states, max (N, floor (Int64, N * log (N))))
143+ function _init_container (init_state:: ParticleDistribution , state:: ParticleDistribution )
144+ initial_states = map (p -> p. state, init_state. particles)
145+ return DenseParticleContainer (
146+ initial_states,
147+ map (p -> p. state, state. particles),
148+ Float64 .(log_weights (state)),
149+ map (p -> p. ancestor, state. particles),
150+ )
132151end
133152
134153function _update_tree! (tree:: ParticleTree , state:: ParticleDistribution )
@@ -140,6 +159,17 @@ function _update_tree!(tree::ParticleTree, state::ParticleDistribution)
140159 return tree
141160end
142161
162+ function _update_container! (c:: DenseParticleContainer , state:: ParticleDistribution )
163+ particles = state. particles
164+ push! (
165+ c,
166+ map (p -> p. state, particles),
167+ Float64 .(log_weights (state)),
168+ map (p -> p. ancestor, particles),
169+ )
170+ return c
171+ end
172+
143173# # BACKWARD PREDICTIVE LIKELIHOODS #########################################################
144174
145175# Default: no backward likelihoods needed (regular PF, or first iteration)
@@ -217,11 +247,9 @@ function _csmc_sample(
217247 K = length (observations)
218248 ref_state = _make_ref_state (ref_traj)
219249
220- state = initialise (rng, prior (model), pf; ref_state)
221- tree = _init_tree (state)
222-
223- state, ll = step (rng, model, pf, 1 , state, observations[1 ]; ref_state)
224- _update_tree! (tree, state)
250+ init_state = initialise (rng, prior (model), pf; ref_state)
251+ state, ll = step (rng, model, pf, 1 , init_state, observations[1 ]; ref_state)
252+ tree = _init_tree (init_state, state)
225253
226254 for t in 2 : K
227255 state, ll_inc = step (rng, model, pf, t, state, observations[t]; ref_state)
@@ -247,19 +275,17 @@ function _csmc_sample(
247275 # Backward predictive likelihoods (only non-nothing for RBPF)
248276 back_liks = _compute_backward_likelihoods (rng, model, pf, observations, ref_state)
249277
250- state = initialise (rng, prior (model), pf; ref_state)
251- tree = _init_tree (state)
278+ init_state = initialise (rng, prior (model), pf; ref_state)
252279
253- ll = zero ( eltype ( log_weights ( state)))
254- for t in 1 : K
255- # Ancestor sampling for the reference particle
280+ # Perform one CSMC-AS step on the current state
281+ function _csmc_as_step (state, t)
282+ ancestor_idx = 0
256283 if ! isnothing (ref_state)
257284 ref_as = _build_ancestor_ref (ref_state, back_liks, t)
258285 as_weights = map (state. particles) do particle
259286 ancestor_weight (particle, dyn (model), pf, t, ref_as)
260287 end
261288 ancestor_idx = StatsBase. sample (rng, StatsBase. Weights (softmax (as_weights)))
262-
263289 end
264290
265291 state = resample (rng, resampler (pf), state; ref_state)
@@ -270,9 +296,15 @@ function _csmc_sample(
270296 )
271297 end
272298
273- state, ll_inc = move (rng, model, pf, t, state, observations[t]; ref_state)
274- ll += ll_inc
299+ return move (rng, model, pf, t, state, observations[t]; ref_state)
300+ end
301+
302+ state, ll = _csmc_as_step (init_state, 1 )
303+ tree = _init_tree (init_state, state)
275304
305+ for t in 2 : K
306+ state, ll_inc = _csmc_as_step (state, t)
307+ ll += ll_inc
276308 _update_tree! (tree, state)
277309 end
278310
@@ -366,55 +398,51 @@ function _csmc_sample(
366398)
367399 pf = csmc. pf
368400 K = length (observations)
401+ N = num_particles (pf)
369402 ref_state = _make_ref_state (ref_traj)
370403
371- # Forward filtering pass, storing particles at each timestep
404+ # Forward filtering pass: store full history in a DenseParticleContainer.
372405 init_state = initialise (rng, prior (model), pf; ref_state)
373- init_particles = copy (init_state. particles)
374-
375406 state, ll = step (rng, model, pf, 1 , init_state, observations[1 ]; ref_state)
376- particle_history = Vector {typeof(state.particles)} (undef, K)
377- particle_history[1 ] = copy (state. particles)
407+ container = _init_container (init_state, state)
378408
379409 for t in 2 : K
380410 state, ll_inc = step (rng, model, pf, t, state, observations[t]; ref_state)
381411 ll += ll_inc
382- particle_history[t] = copy ( state. particles )
412+ _update_container! (container, state)
383413 end
384414
385415 # Backward simulation pass
386- ws = get_weights (state)
387- idx = StatsBase. sample (rng, StatsBase. Weights (ws))
388- sampled_state = particle_history[K][idx]. state
416+ idx = StatsBase. sample (rng, StatsBase. Weights (get_weights (state)))
417+ sampled_state = container. states[K][idx]
389418
390419 back_lik = _bs_init_back_lik (rng, model, pf, observations, K, sampled_state)
391420
392- ST = typeof (sampled_state)
393- trajectory = OffsetVector (Vector {ST} (undef, K + 1 ), - 1 )
394- trajectory[K] = sampled_state
421+ xs = Vector {typeof(sampled_state)} (undef, K)
422+ xs[K] = sampled_state
395423
396424 for t in (K - 1 ): - 1 : 1
397- ref_next = _build_bs_ref (trajectory [t + 1 ], back_lik)
398- backward_ws = map (particle_history[t] ) do particle
399- ancestor_weight (particle , dyn (model), pf, t + 1 , ref_next)
425+ ref_next = _build_bs_ref (xs [t + 1 ], back_lik)
426+ backward_ws = map (1 : N ) do i
427+ ancestor_weight (Particle (container, t, i) , dyn (model), pf, t + 1 , ref_next)
400428 end
401429 idx = StatsBase. sample (rng, StatsBase. Weights (softmax (backward_ws)))
402- trajectory [t] = particle_history [t][idx]. state
430+ xs [t] = container . states [t][idx]
403431
404432 back_lik = _bs_step_back_lik (
405- rng, model, pf, t, back_lik, observations, trajectory [t], trajectory [t + 1 ]
433+ rng, model, pf, t, back_lik, observations, xs [t], xs [t + 1 ]
406434 )
407435 end
408436
409- # Time 0: backward step from t=1 to initial particles
410- ref_at_1 = _build_bs_ref (trajectory [1 ], back_lik)
411- backward_ws = map (init_particles ) do particle
437+ # Time 0: backward step from t=1 to initial particles.
438+ ref_at_1 = _build_bs_ref (xs [1 ], back_lik)
439+ backward_ws = map (init_state . particles ) do particle
412440 ancestor_weight (particle, dyn (model), pf, 1 , ref_at_1)
413441 end
414442 idx = StatsBase. sample (rng, StatsBase. Weights (softmax (backward_ws)))
415- trajectory[ 0 ] = init_particles [idx]. state
443+ x0 = container . initial_states [idx]
416444
417- return trajectory , ll
445+ return ReferenceTrajectory (x0, xs) , ll
418446end
419447
420448# # ABSTRACTMCMC INTERFACE ##################################################################
0 commit comments