@@ -12,10 +12,6 @@ export pf_introduce!
1212Resizes a particle filter by resampling existing particles until a total of
1313`n_particles` have been sampled. The resampling method can optionally be
1414specified: `:multinomial` (default), `:residual` or `:optimal`.
15-
16- A `priority_fn` can also be specified as a keyword argument, which maps log
17- particle weights to custom log priority scores for the purpose of resampling
18- (e.g. `w -> w/2` for less aggressive pruning).
1915"""
2016function pf_resize! (state:: ParticleFilterState , n_particles:: Int ,
2117 method:: Symbol = :multinomial ; kwargs... )
@@ -38,22 +34,30 @@ Resizes a particle filter through multinomial resampling (i.e. simple random
3834resampling) of existing particles until `n_particles` are sampled. Each trace
3935(i.e. particle) is resampled with probability proportional to its weight.
4036
41- A `priority_fn` can be specified as a keyword argument, which maps log particle
42- weights to custom log priority scores for the purpose of resampling
43- (e.g. `w -> w/2` for less aggressive pruning).
37+ # Keyword Arguments
38+
39+ - `priority_fn = nothing`: An optional function that maps particle weights to
40+ custom log priority scores (e.g. `w -> w/2` for less aggressive pruning).
41+ - `check = :warn`: Set to `true` to throw an error for invalid normalized
42+ weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
43+ suppress checks. In the latter two cases, zero weights will be renormalized
44+ to uniform weights for resampling.
4445"""
4546function pf_multinomial_resize! (state:: ParticleFilterState , n_particles:: Int ;
46- priority_fn= nothing )
47- # Update estimate of log marginal likelihood
48- update_lml_est! (state)
47+ priority_fn= nothing , check= :warn )
4948 # Compute priority scores if priority function is provided
5049 log_priorities = priority_fn === nothing ?
5150 state. log_weights : priority_fn .(state. log_weights)
51+ # Normalize weights and check their validity
52+ weights = softmax (log_priorities)
53+ weights, invalid = safe_softmax (log_priorities, warn = (check != false ))
54+ check == true && invalid && error (" Invalid weights." )
55+ # Update estimate of log marginal likelihood
56+ update_lml_est! (state)
5257 # Resize arrays
5358 resize! (state. parents, n_particles)
5459 resize! (state. new_traces, n_particles)
5560 # Resample new traces according to current normalized weights
56- weights = softmax (log_priorities)
5761 rand! (Categorical (weights), state. parents)
5862 state. new_traces .= view (state. traces, state. parents)
5963 # Reweight particles and update trace references
@@ -63,30 +67,38 @@ function pf_multinomial_resize!(state::ParticleFilterState, n_particles::Int;
6367end
6468
6569"""
66- pf_residual_resize!(state::ParticleFilterState, n_particles::Int; kwargs...)
70+ pf_residual_resize!(state::ParticleFilterState, n_particles::Int;
71+ kwargs...)
6772
6873Resizes a particle filter through residual resampling of existing particles.
6974For each particle with normalized weight ``w_i``, ``⌊n w_i⌋`` copies are
7075resampled, where ``n`` is `n_particles`. The remainder are sampled with
7176probability proportional to ``n w_i - ⌊n w_i⌋`` for each particle ``i``.
7277
73- A `priority_fn` can be specified as a keyword argument, which maps log particle
74- weights to custom log priority scores for the purpose of resampling
75- (e.g. `w -> w/2` for less aggressive pruning).
78+ # Keyword Arguments
79+
80+ - `priority_fn = nothing`: An optional function that maps particle weights to
81+ custom log priority scores (e.g. `w -> w/2` for less aggressive pruning).
82+ - `check = :warn`: Set to `true` to throw an error for invalid normalized
83+ weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
84+ suppress checks. In the latter two cases, zero weights will be renormalized
85+ to uniform weights for resampling.
7686"""
7787function pf_residual_resize! (state:: ParticleFilterState , n_particles:: Int ;
78- priority_fn= nothing )
79- # Update estimate of log marginal likelihood
80- update_lml_est! (state)
88+ priority_fn= nothing , check= :warn )
8189 # Compute priority scores if priority function is provided
8290 log_priorities = priority_fn === nothing ?
8391 state. log_weights : priority_fn .(state. log_weights)
92+ # Normalize weights and check their validity
93+ weights, invalid = safe_softmax (log_priorities, warn = (check != false ))
94+ check == true && invalid && error (" Invalid weights." )
95+ # Update estimate of log marginal likelihood
96+ update_lml_est! (state)
8497 # Resize arrays
8598 resize! (state. parents, n_particles)
8699 resize! (state. new_traces, n_particles)
87100 # Deterministically copy previous particles according to their weights
88101 n_resampled = 0
89- weights = softmax (log_priorities)
90102 for (i, w) in enumerate (weights)
91103 n_copies = floor (Int, n_particles * w)
92104 if n_copies == 0 continue end
@@ -122,34 +134,38 @@ the variance of the resulting weight distribution with respect to the original
122134weight distribution. Note that `n_particles` should not be greater than the
123135current number of particles.
124136
137+ # Keyword Arguments
138+
139+ - `check = :warn`: Set to `true` to throw an error for invalid normalized
140+ weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
141+ suppress checks. In the latter two cases, zero weights will be renormalized
142+ to uniform weights for resampling.
143+
125144[1] Paul Fearnhead , Peter Clifford, On-Line Inference for Hidden Markov Models
126145via Particle Filters, Journal of the Royal Statistical Society Series B:
127146Statistical Methodology, Volume 65, Issue 4, November 2003, Pages 887–899,
128147https://doi.org/10.1111/1467-9868.00421
129148"""
130149function pf_optimal_resize! (state:: ParticleFilterState , n_particles:: Int ;
131- kwargs... )
132- # Resize arrays
133- n_old = length (state. traces)
134- @assert n_particles <= n_old
135- resize! (state. parents, n_particles)
136- resize! (state. new_traces, n_particles)
137- # Normalize weights and compute inverse weight threshold
138- weights = softmax (state. log_weights)
139- inv_w_thresh = find_inv_w_threshold (weights, n_particles)
150+ check = :warn , kwargs... )
151+ # Normalize weights and check their validity
152+ weights, invalid = safe_softmax (state. log_weights, warn = (check != false ))
153+ check == true && invalid && error (" Invalid weights." )
140154 # Find particles to keep deterministically vs. resample with stratification
155+ inv_w_thresh = find_inv_w_threshold (weights, n_particles)
141156 keep_idxs = (inv_w_thresh .* weights .>= 1 )
142157 strat_idxs = .! (keep_idxs)
143158 # Keep selected indices
144159 keep_idxs = findall (keep_idxs)
145160 n_keep = length (keep_idxs)
146- state. parents[1 : n_keep] .= keep_idxs
147- # Perform stratified resampling on remaining indices
161+ # Perform stratified resampling on unselected indices
148162 n_resample = n_particles - n_keep
149163 resample_idxs = Int[]
150164 strat_idxs = findall (strat_idxs)
151165 n_strat = length (strat_idxs)
152- norm_strat_weights = softmax (state. log_weights[strat_idxs])
166+ norm_strat_weights, invalid =
167+ safe_softmax (state. log_weights[strat_idxs], warn = (check != false ))
168+ check == true && invalid && error (" Invalid weights." )
153169 # Compute resampled indices
154170 step_size = 1 / n_resample
155171 u = rand () * step_size
@@ -160,9 +176,15 @@ function pf_optimal_resize!(state::ParticleFilterState, n_particles::Int;
160176 u += step_size
161177 end
162178 end
163- # Keep resampled indices
179+ # Assign parent indices
180+ state. parents[1 : n_keep] .= keep_idxs
164181 @assert length (resample_idxs) == n_resample
165182 state. parents[n_keep+ 1 : n_particles] .= resample_idxs
183+ # Resize arrays
184+ n_old = length (state. traces)
185+ @assert n_particles <= n_old
186+ resize! (state. parents, n_particles)
187+ resize! (state. new_traces, n_particles)
166188 # Update weights
167189 log_n_ratio = log (n_particles) - log (n_old)
168190 log_tot_weight = logsumexp (state. log_weights)
0 commit comments