Skip to content

Commit f485602

Browse files
committed
format
1 parent 4cd782a commit f485602

2 files changed

Lines changed: 22 additions & 16 deletions

File tree

papers/tutorial/calibration/post_cal.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def post_process_cal(
1212
data_x,
1313
data_y,
1414
sigma,
15-
args
15+
args,
1616
):
1717
# Post process
1818
ranges = []
@@ -45,8 +45,8 @@ def post_process_cal(
4545
else:
4646
filename += "_opt"
4747
filename += f"_{args.alpha}_{args.beta}_corner"
48-
plt.savefig(filename+".png")
49-
plt.savefig(filename+".eps")
48+
plt.savefig(filename + ".png")
49+
plt.savefig(filename + ".eps")
5050

5151
for ax in fig.get_axes():
5252
ax.tick_params(
@@ -113,7 +113,7 @@ def post_process_cal(
113113
label="95% Model confidence interval",
114114
)
115115
plt.plot(rangex, std2_5_real, "--", color="k", linewidth=3)
116-
pretty_labels("", "", 20, title=f"Missing phys. unc. = {sigma:.2g}")
116+
pretty_labels("", "", 20, title=f"Missing phys. unc. = {sigma:.2g}")
117117
filename = ""
118118
if args.useNN:
119119
filename += "Surr"
@@ -124,6 +124,6 @@ def post_process_cal(
124124
else:
125125
filename += "_opt"
126126
filename += f"_{args.alpha}_{args.beta}_prop"
127-
plt.savefig(filename+".png")
128-
plt.savefig(filename+".eps")
127+
plt.savefig(filename + ".png")
128+
plt.savefig(filename + ".eps")
129129
plt.show()

papers/tutorial/calibration/tut_calibration.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def bayes_step_opt_err(y=None, y_err=0.1):
155155
std_obs = jnp.ones(y_model.shape[0]) * y_err
156156
numpyro.sample("obs", dist.Normal(y_model, std_obs), obs=y)
157157

158+
158159
def bayes_step_cal_err(y=None):
159160
mu_d = numpyro.sample("mu_d", dist.Uniform(0.01, 0.9))
160161
sigma_d = numpyro.sample("sigma_d", dist.Uniform(0.01, 0.9))
@@ -163,6 +164,7 @@ def bayes_step_cal_err(y=None):
163164
std_obs = jnp.ones(y_model.shape[0]) * err
164165
numpyro.sample("obs", dist.Normal(y_model, std_obs), obs=y)
165166

167+
166168
def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
167169
rng_key = random.PRNGKey(0)
168170
rng_key, rng_key_ = random.split(rng_key)
@@ -176,7 +178,6 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
176178
bayes_step = bayes_step_cal_err
177179
else:
178180
bayes_step = bayes_step_opt_err
179-
180181

181182
# Hamiltonian Monte Carlo (HMC) with no u turn sampling (NUTS)
182183
if mcmc_method.lower() == "hmc":
@@ -195,9 +196,9 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
195196
jit_model_args=True,
196197
)
197198
if cal_err:
198-
mcmc.run(rng_key_, y=data_y)
199+
mcmc.run(rng_key_, y=data_y)
199200
else:
200-
mcmc.run(rng_key_, y=data_y, y_err=y_err)
201+
mcmc.run(rng_key_, y=data_y, y_err=y_err)
201202
mcmc.print_summary()
202203

203204
# Draw samples
@@ -246,7 +247,7 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
246247
}
247248

248249
if cal_err:
249-
err = np.mean(np_mcmc_samples[:,-1])
250+
err = np.mean(np_mcmc_samples[:, -1])
250251
else:
251252
err = y_err
252253
true_m95 = data_y - 2 * y_err
@@ -260,24 +261,29 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
260261
else:
261262
return True, results
262263

264+
263265
if not args.cal_err:
264266
min_sigma = 1e-4
265267
max_sigma = 1
266268
guess_sigma = 0.07
267269
sigma = guess_sigma
268-
270+
269271
for iteration_sigma in range(10):
270272
print(f"Doing sigma = {sigma:.3g}")
271-
reduce_sigma, results = mcmc_iter(y_err=sigma, mcmc_method="hmc", cal_err=args.cal_err)
273+
reduce_sigma, results = mcmc_iter(
274+
y_err=sigma, mcmc_method="hmc", cal_err=args.cal_err
275+
)
272276
if reduce_sigma:
273277
max_sigma = sigma
274278
sigma = sigma - (sigma - min_sigma) / 2
275279
else:
276280
min_sigma = sigma
277281
sigma = sigma + (max_sigma - sigma) / 2
278-
282+
279283
if not reduce_sigma:
280-
reduce_sigma, results = mcmc_iter(y_err=max_sigma, mcmc_method="hmc", cal_err=args.cal_err)
284+
reduce_sigma, results = mcmc_iter(
285+
y_err=max_sigma, mcmc_method="hmc", cal_err=args.cal_err
286+
)
281287
else:
282288
_, results = mcmc_iter(mcmc_method="hmc", cal_err=args.cal_err)
283289

@@ -293,7 +299,7 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
293299
)
294300

295301
if args.cal_err:
296-
sigma = np.mean(np_mcmc_samples[:,-1])
302+
sigma = np.mean(np_mcmc_samples[:, -1])
297303

298304

299305
post_process_cal(
@@ -305,5 +311,5 @@ def mcmc_iter(y_err=0.1, mcmc_method="HMC", cal_err=False):
305311
data_x,
306312
data_y,
307313
sigma,
308-
args
314+
args,
309315
)

0 commit comments

Comments
 (0)