@@ -155,6 +155,7 @@ def bayes_step_opt_err(y=None, y_err=0.1):
155155 std_obs = jnp .ones (y_model .shape [0 ]) * y_err
156156 numpyro .sample ("obs" , dist .Normal (y_model , std_obs ), obs = y )
157157
158+
158159def bayes_step_cal_err (y = None ):
159160 mu_d = numpyro .sample ("mu_d" , dist .Uniform (0.01 , 0.9 ))
160161 sigma_d = numpyro .sample ("sigma_d" , dist .Uniform (0.01 , 0.9 ))
@@ -163,6 +164,7 @@ def bayes_step_cal_err(y=None):
163164 std_obs = jnp .ones (y_model .shape [0 ]) * err
164165 numpyro .sample ("obs" , dist .Normal (y_model , std_obs ), obs = y )
165166
167+
166168def mcmc_iter (y_err = 0.1 , mcmc_method = "HMC" , cal_err = False ):
167169 rng_key = random .PRNGKey (0 )
168170 rng_key , rng_key_ = random .split (rng_key )
@@ -176,7 +178,6 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
176178 bayes_step = bayes_step_cal_err
177179 else :
178180 bayes_step = bayes_step_opt_err
179-
180181
181182 # Hamiltonian Monte Carlo (HMC) with no u turn sampling (NUTS)
182183 if mcmc_method .lower () == "hmc" :
@@ -195,9 +196,9 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
195196 jit_model_args = True ,
196197 )
197198 if cal_err :
198- mcmc .run (rng_key_ , y = data_y )
199+ mcmc .run (rng_key_ , y = data_y )
199200 else :
200- mcmc .run (rng_key_ , y = data_y , y_err = y_err )
201+ mcmc .run (rng_key_ , y = data_y , y_err = y_err )
201202 mcmc .print_summary ()
202203
203204 # Draw samples
@@ -246,7 +247,7 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
246247 }
247248
248249 if cal_err :
249- err = np .mean (np_mcmc_samples [:,- 1 ])
250+ err = np .mean (np_mcmc_samples [:, - 1 ])
250251 else :
251252 err = y_err
252253 true_m95 = data_y - 2 * y_err
@@ -260,24 +261,29 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
260261 else :
261262 return True , results
262263
264+
263265if not args .cal_err :
264266 min_sigma = 1e-4
265267 max_sigma = 1
266268 guess_sigma = 0.07
267269 sigma = guess_sigma
268-
270+
269271 for iteration_sigma in range (10 ):
270272 print (f"Doing sigma = { sigma :.3g} " )
271- reduce_sigma , results = mcmc_iter (y_err = sigma , mcmc_method = "hmc" , cal_err = args .cal_err )
273+ reduce_sigma , results = mcmc_iter (
274+ y_err = sigma , mcmc_method = "hmc" , cal_err = args .cal_err
275+ )
272276 if reduce_sigma :
273277 max_sigma = sigma
274278 sigma = sigma - (sigma - min_sigma ) / 2
275279 else :
276280 min_sigma = sigma
277281 sigma = sigma + (max_sigma - sigma ) / 2
278-
282+
279283 if not reduce_sigma :
280- reduce_sigma , results = mcmc_iter (y_err = max_sigma , mcmc_method = "hmc" , cal_err = args .cal_err )
284+ reduce_sigma , results = mcmc_iter (
285+ y_err = max_sigma , mcmc_method = "hmc" , cal_err = args .cal_err
286+ )
281287else :
282288 _ , results = mcmc_iter (mcmc_method = "hmc" , cal_err = args .cal_err )
283289
@@ -293,7 +299,7 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
293299)
294300
295301if args .cal_err :
296- sigma = np .mean (np_mcmc_samples [:,- 1 ])
302+ sigma = np .mean (np_mcmc_samples [:, - 1 ])
297303
298304
299305post_process_cal (
@@ -305,5 +311,5 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
305311 data_x ,
306312 data_y ,
307313 sigma ,
308- args
314+ args ,
309315)
0 commit comments