Skip to content

Commit 00bda0b

Browse files
committed
Support weight validity checking when resampling.
1 parent b786db4 commit 00bda0b

5 files changed

Lines changed: 182 additions & 60 deletions

File tree

src/resample.jl

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ The resampling method can optionally be specified: `:multinomial` (default),
1212
`:residual`, or `:stratified`. See [1] for a survey of resampling methods
1313
and their variance properties.
1414
15-
A `priority_fn` can also be specified as a keyword argument, which maps log
16-
particle weights to custom log priority scores for the purpose of resampling
17-
(e.g. `w -> w/2` for less aggressive pruning).
18-
1915
[1] R. Douc and O. Cappé, "Comparison of resampling schemes
2016
for particle filtering," in ISPA 2005. Proceedings of the 4th International
2117
Symposium on Image and Signal Processing and Analysis, 2005., 2005, pp. 64–69.
@@ -40,19 +36,26 @@ Performs multinomial resampling (i.e. simple random resampling) of the
4036
particles in the filter. Each trace (i.e. particle) is resampled with
4137
probability proportional to its weight.
4238
43-
A `priority_fn` can be specified as a keyword argument, which maps log particle
44-
weights to custom log priority scores for the purpose of resampling
45-
(e.g. `w -> w/2` for less aggressive pruning).
39+
# Keyword Arguments
40+
41+
- `priority_fn = nothing`: An optional function that maps log particle weights to
42+
custom log priority scores (e.g. `w -> w/2` for less aggressive pruning).
43+
- `check = :warn`: Set to `true` to throw an error for invalid normalized
44+
weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
45+
suppress checks. In the latter two cases, zero weights will be renormalized
46+
to uniform weights for resampling.
4647
"""
4748
function pf_multinomial_resample!(state::ParticleFilterView;
48-
priority_fn=nothing)
49-
# Update estimate of log marginal likelihood
50-
update_lml_est!(state)
49+
priority_fn=nothing, check=:warn)
5150
# Compute priority scores if priority function is provided
5251
log_priorities = priority_fn === nothing ?
5352
state.log_weights : priority_fn.(state.log_weights)
53+
# Normalize weights and check their validity
54+
weights, invalid = safe_softmax(log_priorities, warn = (check != false))
55+
check == true && invalid && error("Invalid weights.")
56+
# Update estimate of log marginal likelihood
57+
update_lml_est!(state)
5458
# Resample new traces according to current normalized weights
55-
weights = softmax(log_priorities)
5659
rand!(Categorical(weights), state.parents)
5760
state.new_traces .= view(state.traces, state.parents)
5861
# Reweight particles and update trace references
@@ -70,21 +73,28 @@ normalized weight ``w_i``, ``⌊n w_i⌋`` copies are resampled, where ``n`` is
7073
total number of particles. The remainder are sampled with probability
7174
proportional to ``n w_i - ⌊n w_i⌋`` for each particle ``i``.
7275
73-
A `priority_fn` can be specified as a keyword argument, which maps log particle
74-
weights to custom log priority scores for the purpose of resampling
75-
(e.g. `w -> w/2` for less aggressive pruning).
76+
# Keyword Arguments
77+
78+
- `priority_fn = nothing`: An optional function that maps log particle weights to
79+
custom log priority scores (e.g. `w -> w/2` for less aggressive pruning).
80+
- `check = :warn`: Set to `true` to throw an error for invalid normalized
81+
weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
82+
suppress checks. In the latter two cases, zero weights will be renormalized
83+
to uniform weights for resampling.
7684
"""
7785
function pf_residual_resample!(state::ParticleFilterView;
78-
priority_fn=nothing)
79-
# Update estimate of log marginal likelihood
80-
update_lml_est!(state)
86+
priority_fn=nothing, check=:warn)
8187
# Compute priority scores if priority function is provided
8288
log_priorities = priority_fn === nothing ?
8389
state.log_weights : priority_fn.(state.log_weights)
90+
# Normalize weights and check their validity
91+
weights, invalid = safe_softmax(log_priorities, warn = (check != false))
92+
check == true && invalid && error("Invalid weights.")
93+
# Update estimate of log marginal likelihood
94+
update_lml_est!(state)
8495
# Deterministically copy previous particles according to their weights
8596
n_resampled = 0
8697
n_particles = length(state.traces)
87-
weights = softmax(log_priorities)
8898
for (i, w) in enumerate(weights)
8999
n_copies = floor(Int, n_particles * w)
90100
if n_copies == 0 continue end
@@ -119,20 +129,28 @@ where ``n`` is the number of particles. Then, given the cumulative normalized
119129
weights ``W_k = Σ_{j=1}^{k} w_j ``, sample the ``k``th particle for each ``u_i``
120130
where ``W_{k-1} ≤ u_i < W_k``.
121131
122-
A `priority_fn` can be specified as a keyword argument, which maps log particle
123-
weights to custom log priority scores for the purpose of resampling
124-
(e.g. `w -> w/2` for less aggressive pruning). The `sort_particles` keyword
125-
argument controls whether particles are sorted by weight before stratification
126-
(default: true).
132+
# Keyword Arguments
133+
134+
- `priority_fn = nothing`: An optional function that maps log particle weights to
135+
custom log priority scores (e.g. `w -> w/2` for less aggressive pruning).
136+
- `check = :warn`: Set to `true` to throw an error for invalid normalized
137+
weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
138+
suppress checks. In the latter two cases, zero weights will be renormalized
139+
to uniform weights for resampling.
140+
- `sort_particles = true`: Set to `true` to sort particles by weight before
141+
stratification.
127142
"""
128143
function pf_stratified_resample!(state::ParticleFilterView;
129-
priority_fn=nothing, sort_particles::Bool=true)
130-
# Update estimate of log marginal likelihood
131-
update_lml_est!(state)
144+
priority_fn=nothing, check=:warn,
145+
sort_particles::Bool=true)
132146
# Compute priority scores if priority function is provided
133147
log_priorities = priority_fn === nothing ?
134148
state.log_weights : priority_fn.(state.log_weights)
135-
weights = softmax(log_priorities)
149+
# Normalize weights and check their validity
150+
weights, invalid = safe_softmax(log_priorities, warn = (check != false))
151+
check == true && invalid && error("Invalid weights.")
152+
# Update estimate of log marginal likelihood
153+
update_lml_est!(state)
136154
# Optionally sort particles by weight before resampling
137155
n_particles = length(state.traces)
138156
order = sort_particles ?

src/resize.jl

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ export pf_introduce!
1212
Resizes a particle filter by resampling existing particles until a total of
1313
`n_particles` have been sampled. The resampling method can optionally be
1414
specified: `:multinomial` (default), `:residual` or `:optimal`.
15-
16-
A `priority_fn` can also be specified as a keyword argument, which maps log
17-
particle weights to custom log priority scores for the purpose of resampling
18-
(e.g. `w -> w/2` for less aggressive pruning).
1915
"""
2016
function pf_resize!(state::ParticleFilterState, n_particles::Int,
2117
method::Symbol=:multinomial; kwargs...)
@@ -38,22 +34,30 @@ Resizes a particle filter through multinomial resampling (i.e. simple random
3834
resampling) of existing particles until `n_particles` are sampled. Each trace
3935
(i.e. particle) is resampled with probability proportional to its weight.
4036
41-
A `priority_fn` can be specified as a keyword argument, which maps log particle
42-
weights to custom log priority scores for the purpose of resampling
43-
(e.g. `w -> w/2` for less aggressive pruning).
37+
# Keyword Arguments
38+
39+
- `priority_fn = nothing`: An optional function that maps log particle weights to
40+
custom log priority scores (e.g. `w -> w/2` for less aggressive pruning).
41+
- `check = :warn`: Set to `true` to throw an error for invalid normalized
42+
weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
43+
suppress checks. In the latter two cases, zero weights will be renormalized
44+
to uniform weights for resampling.
4445
"""
4546
function pf_multinomial_resize!(state::ParticleFilterState, n_particles::Int;
46-
priority_fn=nothing)
47-
# Update estimate of log marginal likelihood
48-
update_lml_est!(state)
47+
priority_fn=nothing, check=:warn)
4948
# Compute priority scores if priority function is provided
5049
log_priorities = priority_fn === nothing ?
5150
state.log_weights : priority_fn.(state.log_weights)
51+
# Normalize weights and check their validity
52+
weights = softmax(log_priorities)
53+
weights, invalid = safe_softmax(log_priorities, warn = (check != false))
54+
check == true && invalid && error("Invalid weights.")
55+
# Update estimate of log marginal likelihood
56+
update_lml_est!(state)
5257
# Resize arrays
5358
resize!(state.parents, n_particles)
5459
resize!(state.new_traces, n_particles)
5560
# Resample new traces according to current normalized weights
56-
weights = softmax(log_priorities)
5761
rand!(Categorical(weights), state.parents)
5862
state.new_traces .= view(state.traces, state.parents)
5963
# Reweight particles and update trace references
@@ -63,30 +67,38 @@ function pf_multinomial_resize!(state::ParticleFilterState, n_particles::Int;
6367
end
6468

6569
"""
66-
pf_residual_resize!(state::ParticleFilterState, n_particles::Int; kwargs...)
70+
pf_residual_resize!(state::ParticleFilterState, n_particles::Int;
71+
kwargs...)
6772
6873
Resizes a particle filter through residual resampling of existing particles.
6974
For each particle with normalized weight ``w_i``, ``⌊n w_i⌋`` copies are
7075
resampled, where ``n`` is `n_particles`. The remainder are sampled with
7176
probability proportional to ``n w_i - ⌊n w_i⌋`` for each particle ``i``.
7277
73-
A `priority_fn` can be specified as a keyword argument, which maps log particle
74-
weights to custom log priority scores for the purpose of resampling
75-
(e.g. `w -> w/2` for less aggressive pruning).
78+
# Keyword Arguments
79+
80+
- `priority_fn = nothing`: An optional function that maps log particle weights to
81+
custom log priority scores (e.g. `w -> w/2` for less aggressive pruning).
82+
- `check = :warn`: Set to `true` to throw an error for invalid normalized
83+
weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
84+
suppress checks. In the latter two cases, zero weights will be renormalized
85+
to uniform weights for resampling.
7686
"""
7787
function pf_residual_resize!(state::ParticleFilterState, n_particles::Int;
78-
priority_fn=nothing)
79-
# Update estimate of log marginal likelihood
80-
update_lml_est!(state)
88+
priority_fn=nothing, check=:warn)
8189
# Compute priority scores if priority function is provided
8290
log_priorities = priority_fn === nothing ?
8391
state.log_weights : priority_fn.(state.log_weights)
92+
# Normalize weights and check their validity
93+
weights, invalid = safe_softmax(log_priorities, warn = (check != false))
94+
check == true && invalid && error("Invalid weights.")
95+
# Update estimate of log marginal likelihood
96+
update_lml_est!(state)
8497
# Resize arrays
8598
resize!(state.parents, n_particles)
8699
resize!(state.new_traces, n_particles)
87100
# Deterministically copy previous particles according to their weights
88101
n_resampled = 0
89-
weights = softmax(log_priorities)
90102
for (i, w) in enumerate(weights)
91103
n_copies = floor(Int, n_particles * w)
92104
if n_copies == 0 continue end
@@ -122,34 +134,38 @@ the variance of the resulting weight distribution with respect to the original
122134
weight distribution. Note that `n_particles` should not be greater than the
123135
current number of particles.
124136
137+
# Keyword Arguments
138+
139+
- `check = :warn`: Set to `true` to throw an error for invalid normalized
140+
weights (all NaNs or zeros), `:warn` to issue warnings, or `false` to
141+
suppress checks. In the latter two cases, zero weights will be renormalized
142+
to uniform weights for resampling.
143+
125144
[1] Paul Fearnhead, Peter Clifford, On-Line Inference for Hidden Markov Models
126145
via Particle Filters, Journal of the Royal Statistical Society Series B:
127146
Statistical Methodology, Volume 65, Issue 4, November 2003, Pages 887–899,
128147
https://doi.org/10.1111/1467-9868.00421
129148
"""
130149
function pf_optimal_resize!(state::ParticleFilterState, n_particles::Int;
131-
kwargs...)
132-
# Resize arrays
133-
n_old = length(state.traces)
134-
@assert n_particles <= n_old
135-
resize!(state.parents, n_particles)
136-
resize!(state.new_traces, n_particles)
137-
# Normalize weights and compute inverse weight threshold
138-
weights = softmax(state.log_weights)
139-
inv_w_thresh = find_inv_w_threshold(weights, n_particles)
150+
check = :warn, kwargs...)
151+
# Normalize weights and check their validity
152+
weights, invalid = safe_softmax(state.log_weights, warn = (check != false))
153+
check == true && invalid && error("Invalid weights.")
140154
# Find particles to keep deterministically vs. resample with stratification
155+
inv_w_thresh = find_inv_w_threshold(weights, n_particles)
141156
keep_idxs = (inv_w_thresh .* weights .>= 1)
142157
strat_idxs = .!(keep_idxs)
143158
# Keep selected indices
144159
keep_idxs = findall(keep_idxs)
145160
n_keep = length(keep_idxs)
146-
state.parents[1:n_keep] .= keep_idxs
147-
# Perform stratified resampling on remaining indices
161+
# Perform stratified resampling on unselected indices
148162
n_resample = n_particles - n_keep
149163
resample_idxs = Int[]
150164
strat_idxs = findall(strat_idxs)
151165
n_strat = length(strat_idxs)
152-
norm_strat_weights = softmax(state.log_weights[strat_idxs])
166+
norm_strat_weights, invalid =
167+
safe_softmax(state.log_weights[strat_idxs], warn = (check != false))
168+
check == true && invalid && error("Invalid weights.")
153169
# Compute resampled indices
154170
step_size = 1 / n_resample
155171
u = rand() * step_size
@@ -160,9 +176,15 @@ function pf_optimal_resize!(state::ParticleFilterState, n_particles::Int;
160176
u += step_size
161177
end
162178
end
163-
# Keep resampled indices
179+
# Assign parent indices
180+
state.parents[1:n_keep] .= keep_idxs
164181
@assert length(resample_idxs) == n_resample
165182
state.parents[n_keep+1:n_particles] .= resample_idxs
183+
# Resize arrays
184+
n_old = length(state.traces)
185+
@assert n_particles <= n_old
186+
resize!(state.parents, n_particles)
187+
resize!(state.new_traces, n_particles)
166188
# Update weights
167189
log_n_ratio = log(n_particles) - log(n_old)
168190
log_tot_weight = logsumexp(state.log_weights)

src/utils.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,46 @@ end
9999

100100
lognorm(vs::AbstractVector) = vs .- logsumexp(vs)
101101

"Computes the softmax of a vector of (unnormalized) log probabilities."
function softmax(vs::AbstractVector{T}) where {T <: Real}
    if isempty(vs)
        return T[]
    end
    # Subtract the maximum before exponentiating, so the largest exponent is 0
    shifted = vs .- maximum(vs)
    unnormalized = exp.(shifted)
    return unnormalized ./ sum(unnormalized)
end
107108

"""
    probs, invalid = safe_softmax(vs; warn::Bool=true)

Returns the softmax of a vector of (unnormalized) log probabilities, together
with a boolean flag indicating whether the input was invalid.

The input is flagged as invalid in the following cases:

- `vs` contains any `NaN` values, or the total unnormalized weight is `NaN`.
  A vector of `NaN` weights is returned.
- All values in `vs` are `-Inf`, or the unnormalized weights sum to zero.
  Uniform weights are returned.

An empty input yields an empty weight vector and is not flagged as invalid.
Warning messages are printed for invalid inputs if `warn` is `true`.
"""
function safe_softmax(vs::AbstractVector{T}; warn::Bool=true) where {T <: Real}
    F = float(T)
    n = length(vs)
    # Always return a `(probs, invalid)` tuple, so that callers which
    # destructure the result also work when there is nothing to normalize
    isempty(vs) && return (F[], false)
    if any(isnan, vs)
        warn && @warn("NaN found in input values. Returning NaN weights.")
        return (fill(convert(F, NaN), n), true)
    elseif all(==(-Inf), vs)
        warn && @warn("All input values are -Inf. Returning uniform weights.")
        return (ones(F, n) ./ n, true)
    end
    # Subtract the maximum before exponentiating, so the largest exponent is 0
    ws = exp.(vs .- maximum(vs))
    total_w = sum(ws)
    if iszero(total_w)
        warn && @warn("All weights are zero. Returning uniform weights.")
        return (ones(F, n) ./ n, true)
    elseif isnan(total_w)
        warn && @warn("Total weight is NaN. Returning NaN weights.")
        return (fill(convert(F, NaN), n), true)
    end
    return (ws ./ total_w, false)
end
141+
108142
"""
109143
get_log_norm_weights(state::ParticleFilterState)
110144

test/resample.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
new_lml_est = get_lml_est(state)
2222
@test new_traces == old_traces[state.parents]
2323
@test new_lml_est ≈ old_lml_est
24+
25+
# Test resampling with invalid weights
26+
with_logger(Logging.SimpleLogger(Logging.Error)) do
27+
state = pf_initialize(line_model, (0,), slope_choicemap(-3), 100)
28+
@test_throws ErrorException pf_multinomial_resample!(state, check=true)
29+
state = pf_multinomial_resample!(state, check=false)
30+
@test all(iszero, get_log_weights(state))
31+
end
2432
end
2533

2634
@testset "Residual resampling" begin
@@ -60,6 +68,14 @@ end
6068
@test all(copies .>= min_copies)
6169
new_lml_est = get_lml_est(state)
6270
@test new_lml_est ≈ old_lml_est
71+
72+
# Test resampling with invalid weights
73+
with_logger(Logging.SimpleLogger(Logging.Error)) do
74+
state = pf_initialize(line_model, (0,), slope_choicemap(-3), 100)
75+
@test_throws ErrorException pf_residual_resample!(state, check=true)
76+
state = pf_residual_resample!(state, check=false)
77+
@test all(iszero, get_log_weights(state))
78+
end
6379
end
6480

6581
@testset "Stratified resampling" begin
@@ -101,6 +117,14 @@ end
101117
@test copies >= min_copies
102118
new_lml_est = get_lml_est(state)
103119
@test new_lml_est ≈ old_lml_est
120+
121+
# Test resampling with invalid weights
122+
with_logger(Logging.SimpleLogger(Logging.Error)) do
123+
state = pf_initialize(line_model, (0,), slope_choicemap(-3), 100)
124+
@test_throws ErrorException pf_stratified_resample!(state, check=true)
125+
state = pf_stratified_resample!(state, check=false)
126+
@test all(iszero, get_log_weights(state))
127+
end
104128
end
105129

106130
@testset "Blockwise resampling of separate views" begin

0 commit comments

Comments
 (0)