@@ -11,7 +11,7 @@ using Gen: run_first_pass, jacobian_correction, check_round_trip, run_transform
1111 new_observations::ChoiceMap = EmptyChoiceMap(),
1212 q_forward::GenerativeFunction,
1313 q_forward_args::Tuple = (),
14- f ::Union{TraceTransformDSLProgram,Nothing} = nothing)
14+ transform ::Union{TraceTransformDSLProgram,Nothing} = nothing)
1515Constructor for a extending trace translator.
1616Run the translator with:
1717 (output_trace, log_weight) = translator(input_trace)
@@ -22,7 +22,7 @@ Run the translator with:
2222 new_observations:: ChoiceMap = EmptyChoiceMap ()
2323 q_forward:: GenerativeFunction
2424 q_forward_args:: Tuple = ()
25- f :: Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
25+ transform :: Union{TraceTransformDSLProgram,Nothing} = nothing # a bijection
2626end
2727
2828function (translator:: ExtendingTraceTranslator )(prev_model_trace:: Trace )
@@ -33,14 +33,14 @@ function (translator::ExtendingTraceTranslator)(prev_model_trace::Trace)
3333 forward_proposal_score = get_score (forward_proposal_trace)
3434
3535 # transform forward proposal
36- if translator. f === nothing
36+ if translator. transform === nothing
3737 constraints = get_choices (forward_proposal_trace)
3838 log_abs_determinant = 0.0
3939 else
4040 first_pass_results =
41- run_first_pass (translator. f , forward_proposal_trace, nothing )
41+ run_first_pass (translator. transform , forward_proposal_trace, nothing )
4242 log_abs_determinant =
43- jacobian_correction (translator. f , forward_proposal_trace,
43+ jacobian_correction (translator. transform , forward_proposal_trace,
4444 nothing , first_pass_results, nothing )
4545 constraints = first_pass_results. constraints
4646 end
@@ -97,7 +97,7 @@ the observed random choices in the previous trace.
9797 q_forward_args:: Tuple = ()
9898 q_backward:: GenerativeFunction
9999 q_backward_args:: Tuple = ()
100- f :: TraceTransformDSLProgram
100+ transform :: TraceTransformDSLProgram
101101end
102102
103103function Gen. inverse (translator:: UpdatingTraceTranslator , prev_model_trace:: Trace ,
@@ -106,23 +106,22 @@ function Gen.inverse(translator::UpdatingTraceTranslator, prev_model_trace::Trac
106106 get_args (prev_model_trace), map ((_)-> UnknownChange (), get_args (prev_model_trace)),
107107 prev_observations, translator. q_backward, translator. q_backward_args,
108108 translator. q_forward, translator. q_forward_args,
109- inverse (translator. f ))
109+ inverse (translator. transform ))
110110end
111111
112112function Gen. run_transform (translator:: UpdatingTraceTranslator ,
113- prev_model_trace:: Trace , forward_proposal_trace:: Trace ,
114- check:: Bool = false )
115- @unpack f, new_observations = translator
113+ prev_model_trace:: Trace , forward_proposal_trace:: Trace )
114+ @unpack transform, new_observations = translator
116115 @unpack p_new_args, p_argdiffs, q_backward, q_backward_args = translator
117- first_pass_results =
118- Gen . run_first_pass (f , prev_model_trace, forward_proposal_trace)
116+ first_pass_results = run_first_pass (
117+ transform , prev_model_trace, forward_proposal_trace)
119118 constraints = merge (first_pass_results. constraints, new_observations)
120- ( new_model_trace, _, _, discard) = update (
119+ new_model_trace, _, _, discard = update (
121120 prev_model_trace, p_new_args, p_argdiffs, constraints)
122- log_abs_determinant = jacobian_correction (f, prev_model_trace,
123- forward_proposal_trace, first_pass_results, discard)
124- backward_proposal_trace, = generate (q_backward,
125- (new_model_trace, q_backward_args... ), first_pass_results. u_back)
121+ log_abs_determinant = jacobian_correction (
122+ transform, prev_model_trace, forward_proposal_trace, first_pass_results, discard)
123+ backward_proposal_trace, _ = generate (
124+ q_backward, (new_model_trace, q_backward_args... ), first_pass_results. u_back)
126125 return (new_model_trace, backward_proposal_trace, log_abs_determinant)
127126end
128127
@@ -135,7 +134,7 @@ function (translator::UpdatingTraceTranslator)(
135134
136135 # apply trace transform
137136 (new_model_trace, backward_proposal_trace, log_abs_determinant) =
138- run_transform (translator, prev_model_trace, forward_proposal_trace, check )
137+ run_transform (translator, prev_model_trace, forward_proposal_trace)
139138
140139 # compute log weight
141140 prev_model_score = get_score (prev_model_trace)
@@ -149,7 +148,7 @@ function (translator::UpdatingTraceTranslator)(
149148 inverter = inverse (translator, prev_model_trace, prev_observations)
150149 argdiffs = map ((_) -> UnknownChange (), get_args (prev_model_trace))
151150 (prev_model_trace_rt, forward_proposal_trace_rt, _) =
152- run_transform (inverter, new_model_trace, backward_proposal_trace, check )
151+ run_transform (inverter, new_model_trace, backward_proposal_trace)
153152 check_round_trip (prev_model_trace, prev_model_trace_rt,
154153 forward_proposal_trace, forward_proposal_trace_rt)
155154 end
0 commit comments