22export pf_rejuvenate!, pf_move_accept!, pf_move_reweight!
33export move_reweight
44
5+ using Gen: check_observations
6+
57"""
68 pf_rejuvenate!(state::ParticleFilterState, kern, kern_args::Tuple=(),
7- n_iters::Int=1; method=:move)
9+ n_iters::Int=1; method=:move, kwargs... )
810
911Rejuvenates particles by repeated application of a kernel `kern`. `kern`
1012should be a callable which takes a trace as its first argument, and returns
1113a tuple with a trace as the first return value. `method` specifies the
1214rejuvenation method: `:move` for MCMC moves without a reweighting step,
13- and `:reweight` for rejuvenation with a reweighting step.
15+ and `:reweight` for rejuvenation with a reweighting step. Additional keyword
16+ arguments are passed to the kernel.
1417"""
1518function pf_rejuvenate! (state:: ParticleFilterView , kern, kern_args:: Tuple = (),
16- n_iters:: Int = 1 ; method:: Symbol = :move )
19+ n_iters:: Int = 1 ; method:: Symbol = :move , kwargs ... )
1720 if method == :move
18- return pf_move_accept! (state, kern, kern_args, n_iters)
21+ return pf_move_accept! (state, kern, kern_args, n_iters; kwargs ... )
1922 elseif method == :reweight
20- return pf_move_reweight! (state, kern, kern_args, n_iters)
23+ return pf_move_reweight! (state, kern, kern_args, n_iters; kwargs ... )
2124 else
2225 error (" Method $method not recognized." )
2326 end
2427end
2528
2629"""
2730 pf_move_accept!(state::ParticleFilterState, kern,
28- kern_args::Tuple=(), n_iters::Int=1)
31+ kern_args::Tuple=(), n_iters::Int=1; kwargs... )
2932
3033Rejuvenates particles by repeated application of a MCMC kernel `kern`. `kern`
3134should be a callable which takes a trace as its first argument, and returns
3235a tuple `(trace, accept)`, where `trace` is the (potentially) new trace, and
3336`accept` is true if the MCMC move was accepted. Subsequent arguments to `kern`
34- can be supplied with `kern_args`. The kernel is repeatedly applied to each trace
35- for `n_iters`.
37+ can be supplied with `kern_args` or `kwargs` . The kernel is repeatedly applied
38+ to each trace for `n_iters`.
3639"""
3740function pf_move_accept! (state:: ParticleFilterView ,
38- kern, kern_args:: Tuple = (), n_iters:: Int = 1 )
41+ kern, kern_args:: Tuple = (), n_iters:: Int = 1 ;
42+ kwargs... )
3943 # Potentially rejuvenate each trace
4044 for (i, trace) in enumerate (state. traces)
4145 for k = 1 : n_iters
42- trace, accept = kern (trace, kern_args... )
46+ trace, accept = kern (trace, kern_args... ; kwargs ... )
4347 @debug " Accepted: $accept "
4448 end
4549 state. new_traces[i] = trace
5054
5155"""
5256 pf_move_reweight!(state::ParticleFilterState, kern,
53- kern_args::Tuple=(), n_iters::Int=1)
57+ kern_args::Tuple=(), n_iters::Int=1; kwargs... )
5458
5559Rejuvenates and reweights particles by repeated application of a reweighting
5660kernel `kern`, as described in [1]. `kern` should be a callable which takes a
5761trace as its first argument, and returns a tuple `(trace, rel_weight)`,
5862where `trace` is the new trace, and `rel_weight` is the relative log-importance
59- weight. Subsequent arguments to `kern` can be supplied with `kern_args`.
60- The kernel is repeatedly applied to each trace for `n_iters`, and the weights
61- accumulated accordingly. Both the [`move_reweight`](@ref) function and
63+ weight. Subsequent arguments to `kern` can be supplied with `kern_args` or
64+ `kwargs`. The kernel is repeatedly applied to each trace for `n_iters`, and the
65+ weights accumulated accordingly.
66+
67+ Both the [`move_reweight`](@ref) function and
6268[symmetric trace translators](https://www.gen.dev/stable/ref/trace_translators/)
6369can serve as reweighting kernels.
6470
6571[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
6672online inference," Preprint series. Statistical Research Report, 2013.
6773"""
6874function pf_move_reweight! (state:: ParticleFilterView ,
69- kern, kern_args:: Tuple = (), n_iters:: Int = 1 )
75+ kern, kern_args:: Tuple = (), n_iters:: Int = 1 ;
76+ kwargs... )
7077 # Move and reweight each trace
7178 for (i, trace) in enumerate (state. traces)
7279 weight = 0
7380 for k = 1 : n_iters
74- trace, rel_weight = kern (trace, kern_args... )
81+ trace, rel_weight = kern (trace, kern_args... ; kwargs ... )
7582 weight += rel_weight
7683 @debug " Rel. Weight: $rel_weight "
7784 end
@@ -83,15 +90,15 @@ function pf_move_reweight!(state::ParticleFilterView,
8390end
8491
8592"""
86- move_reweight(trace, selection)
87- move_reweight(trace, proposal, proposal_args)
88- move_reweight(trace, proposal, proposal_args, involution)
89- move_reweight(trace, proposal_fwd, args_fwd,
90- proposal_bwd, args_bwd, involution )
93+ move_reweight(trace, selection; kwargs... )
94+ move_reweight(trace, proposal, proposal_args; kwargs... )
95+ move_reweight(trace, proposal, proposal_args, involution; kwargs... )
96+ move_reweight(trace, proposal_fwd, args_fwd, proposal_bwd, args_bwd,
97+ involution ; kwargs... )
9198
92- Move-reweight MCMC kernel, which takes in a `trace` and returns a new trace
93- along with a relative importance weight. This can be used for rejuvenation
94- within a particle filter, as described in [1].
99+ Move-reweight rejuvenation kernel, which takes in a `trace` and returns a
100+ new trace along with a relative importance weight. This can be used for
101+ rejuvenation within a particle filter, as described in [1].
95102
96103Several variants of `move_reweight` exist, differing in the complexity
97104involved in proposing and re-weighting random choices:
@@ -108,18 +115,25 @@ involved in proposing and re-weighting random choices:
108115 adjusts the computation of the relative importance weight by scoring
109116 the backward choices under the backward proposal.
110117
118+ Similar to `metropolis_hastings`, a `check` flag and `observations` choicemap
119+ can be provided as keyword arguments to ensure that observed choices are
120+ preserved in the new trace.
121+
111122[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
112123online inference," Preprint series. Statistical Research Report, 2013.
113124"""
114- function move_reweight (trace:: Trace , selection:: Selection )
125+ function move_reweight (trace:: Trace , selection:: Selection ;
126+ check= false , observations= EmptyChoiceMap ())
115127 args = get_args (trace)
116128 argdiffs = map ((_) -> NoChange (), args)
117129 new_trace, rel_weight = regenerate (trace, args, argdiffs, selection)
130+ check && check_observations (get_choices (new_trace), observations)
118131 return new_trace, rel_weight
119132end
120133
121134function move_reweight (trace:: Trace , proposal:: GenerativeFunction ,
122- proposal_args:: Tuple )
135+ proposal_args:: Tuple ;
136+ check= false , observations= EmptyChoiceMap ())
123137 model_args = Gen. get_args (trace)
124138 argdiffs = map ((_) -> NoChange (), model_args)
125139 fwd_choices, fwd_score, fwd_ret =
@@ -128,6 +142,7 @@ function move_reweight(trace::Trace, proposal::GenerativeFunction,
128142 update (trace, model_args, argdiffs, fwd_choices)
129143 bwd_score, bwd_ret =
130144 assess (proposal, (new_trace, proposal_args... ), discard)
145+ check && check_observations (get_choices (new_trace), observations)
131146 rel_weight = weight - fwd_score + bwd_score
132147 return new_trace, rel_weight
133148end
@@ -140,19 +155,22 @@ function move_reweight(trace::Trace, proposal::GenerativeFunction,
140155 involution (trace, fwd_choices, fwd_ret, proposal_args)
141156 bwd_score, bwd_ret =
142157 assess (proposal, (new_trace, proposal_args... ), bwd_choices)
158+ check && check_observations (get_choices (new_trace), observations)
143159 rel_weight = weight - fwd_score + bwd_score
144160 return new_trace, rel_weight
145161end
146162
147163function move_reweight (trace:: Trace , proposal_fwd:: GenerativeFunction ,
148164 args_fwd:: Tuple , proposal_bwd:: GenerativeFunction ,
149- args_bwd:: Tuple , involution)
165+ args_bwd:: Tuple , involution;
166+ check= false , observations= EmptyChoiceMap ())
150167 fwd_choices, fwd_score, fwd_ret =
151168 propose (proposal_fwd, (trace, args_fwd... ,))
152169 new_trace, bwd_choices, weight =
153170 involution (trace, fwd_choices, fwd_ret, args_fwd)
154171 bwd_score, bwd_ret =
155172 assess (proposal_bwd, (new_trace, args_bwd... ), bwd_choices)
173+ check && check_observations (get_choices (new_trace), observations)
156174 rel_weight = weight - fwd_score + bwd_score
157175 return new_trace, rel_weight
158176end
0 commit comments