-
Notifications
You must be signed in to change notification settings - Fork 242
Expand file tree
/
Copy pathsbc_test.py
More file actions
373 lines (309 loc) · 11.2 KB
/
sbc_test.py
File metadata and controls
373 lines (309 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from __future__ import annotations
from typing import Callable, Dict, Optional
import pytest
import torch
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal, Uniform
from sbi.analysis import sbc_rank_plot
from sbi.diagnostics import check_sbc, get_nltp, run_sbc
from sbi.inference import NLE, NPE, NPSE
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.posterior_parameters import (
MCMCPosteriorParameters,
VIPosteriorParameters,
)
from sbi.simulators.linear_gaussian import linear_gaussian
from sbi.utils import BoxUniform, MultipleIndependent
from tests.test_utils import PosteriorPotential, TractablePosterior
@pytest.fixture
def gaussian_setup():
    """Provide a shared 2-D linear-Gaussian problem for the SBC tests.

    Returns:
        Dict with keys ``num_dim``, ``prior`` (a standard-normal
        ``MultivariateNormal``), ``simulator`` (``linear_gaussian`` with a
        shift of -1 per dimension and covariance ``0.3 * I``), and the
        ``likelihood_shift`` / ``likelihood_cov`` tensors it closes over.
    """
    dim = 2
    shift = -1.0 * ones(dim)
    cov = 0.3 * eye(dim)
    standard_normal = MultivariateNormal(loc=zeros(dim), covariance_matrix=eye(dim))

    def simulator(theta):
        # Closes over `shift` and `cov` so every consumer shares one likelihood.
        return linear_gaussian(theta, shift, cov)

    return dict(
        num_dim=dim,
        prior=standard_normal,
        simulator=simulator,
        likelihood_shift=shift,
        likelihood_cov=cov,
    )
def train_inference_method(
    method_cls: Callable,
    prior: torch.distributions.Distribution,
    simulator: Callable,
    num_simulations: int = 100,
    max_num_epochs: int = 1,
    **kwargs,
) -> NeuralPosterior:
    """Train ``method_cls`` on fresh prior simulations and build its posterior.

    Args:
        method_cls: Inference class, instantiated as
            ``method_cls(prior, show_progress_bars=False)``.
        prior: Distribution the training parameters are sampled from.
        simulator: Callable mapping a batch of parameters to observations.
        num_simulations: Number of (theta, x) training pairs to simulate.
        max_num_epochs: Upper bound on training epochs, forwarded to ``train``.
        **kwargs: Forwarded unchanged to ``build_posterior``.

    Returns:
        Whatever the trained inference object's ``build_posterior`` returns.
    """
    trainer = method_cls(prior, show_progress_bars=False)
    parameters = prior.sample((num_simulations,))
    observations = simulator(parameters)
    trainer.append_simulations(parameters, observations).train(
        max_num_epochs=max_num_epochs
    )
    return trainer.build_posterior(**kwargs)
@pytest.mark.parametrize("reduce_fn_str", ("marginals", "posterior_log_prob"))
@pytest.mark.parametrize("prior_type", ("boxuniform", "independent"))
@pytest.mark.parametrize(
    "method, sampler",
    (
        (NPE, None),
        pytest.param(NLE, "mcmc", marks=pytest.mark.mcmc),
        pytest.param(NLE, "vi", marks=pytest.mark.mcmc),
        (NPSE, None),
    ),
)
def test_running_sbc(
    method,
    prior_type: str,
    reduce_fn_str: str,
    sampler: Optional[str],
    mcmc_params_fast: MCMCPosteriorParameters,
):
    """Smoke-test training, ``run_sbc`` and ``get_nltp`` for several methods.

    Uses a tiny budget (one epoch on 100 simulations) under either a
    box-uniform or an independent-uniform prior, and only checks the shapes
    of the SBC ranks and, for NPE/NPSE, of the NLTP values.
    """
    num_dim = 2
    num_simulations = 100
    max_num_epochs = 1
    num_sbc_runs = 2
    num_posterior_samples = 20

    if prior_type != "boxuniform":
        # One 1-D uniform per dimension, combined into a joint prior.
        prior = MultipleIndependent(
            [Uniform(-torch.ones(1), torch.ones(1)) for _ in range(num_dim)]
        )
    else:
        prior = BoxUniform(-torch.ones(num_dim), torch.ones(num_dim))

    shift = -1.0 * ones(num_dim)
    cov = 0.3 * eye(num_dim)

    def simulator(theta):
        return linear_gaussian(theta, shift, cov)

    # Only NLE receives explicit posterior parameters, chosen by `sampler`.
    posterior_kwargs = {}
    if method == NLE:
        if sampler == "mcmc":
            posterior_kwargs["posterior_parameters"] = mcmc_params_fast
        else:
            posterior_kwargs["posterior_parameters"] = VIPosteriorParameters()

    posterior = train_inference_method(
        method,
        prior,
        simulator,
        num_simulations=num_simulations,
        max_num_epochs=max_num_epochs,
        **posterior_kwargs,
    )

    thetas = prior.sample((num_sbc_runs,))
    xs = simulator(thetas)

    if reduce_fn_str == "marginals":
        reduce_fn = "marginals"
        expected_rank_dim = num_dim
    else:
        # Rank by the posterior's potential: one rank per SBC run.
        reduce_fn = posterior.potential
        expected_rank_dim = 1

    ranks, _ = run_sbc(
        thetas,
        xs,
        posterior,
        num_posterior_samples=num_posterior_samples,
        reduce_fns=reduce_fn,
    )
    assert ranks.shape == (num_sbc_runs, expected_rank_dim), "Ranks shape is incorrect"

    # NLTP is only evaluated for the methods that yield a normalized density.
    if method in (NPE, NPSE):
        nltp = get_nltp(thetas, xs, posterior)
        assert nltp.shape == (num_sbc_runs,), "NLTP shape is incorrect"
@pytest.mark.slow
@pytest.mark.parametrize("density_estimator", ["mdn", "maf"])
@pytest.mark.parametrize("cov_method", ("sbc", "coverage"))
def test_consistent_sbc_results(
    density_estimator: str, cov_method: str, gaussian_setup: Dict
):
    """A well-trained NPE should pass the SBC / expected-coverage checks.

    ``cov_method="sbc"`` ranks each parameter marginal, while ``"coverage"``
    ranks by ``posterior.log_prob`` (expected coverage). Both must satisfy
    the KS and C2ST thresholds asserted below.
    """
    prior, simulator = gaussian_setup["prior"], gaussian_setup["simulator"]
    num_simulations, num_posterior_samples, num_sbc_runs = 4500, 1000, 100

    npe = NPE(prior=prior, density_estimator=density_estimator)
    train_theta = prior.sample((num_simulations,))
    train_x = simulator(train_theta)
    npe.append_simulations(train_theta, train_x).train()
    posterior = npe.build_posterior()

    sbc_thetas = prior.sample((num_sbc_runs,))
    sbc_xs = simulator(sbc_thetas)

    if cov_method == "sbc":
        reduce_fns = "marginals"
    else:
        reduce_fns = posterior.log_prob

    ranks, dap_samples = run_sbc(
        sbc_thetas,
        sbc_xs,
        posterior,
        num_workers=1,
        num_posterior_samples=num_posterior_samples,
        reduce_fns=reduce_fns,
    )

    # Fresh prior draws serve as the reference for the DAP C2ST.
    checks = check_sbc(
        ranks,
        prior.sample((num_sbc_runs,)),
        dap_samples,
        num_posterior_samples=num_posterior_samples,
    )

    assert (checks["ks_pvals"] > 0.05).all(), (
        f"KS p-values too small: {checks['ks_pvals']}"
    )
    assert (checks["c2st_ranks"] < 0.6).all(), (
        f"C2ST ranks too large: {checks['c2st_ranks']}"
    )
    assert (checks["c2st_dap"] < 0.6).all(), f"C2ST DAP too large: {checks['c2st_dap']}"
def test_sbc_accuracy():
    """Test SBC with prior as posterior (perfect calibration case).

    The "posterior" here is the prior itself. Its samples and the true
    parameters are therefore identically distributed, so SBC ranks should be
    uniform and both the C2ST-on-ranks and KS uniformity checks must pass.
    """
    num_dim = 2

    def simulator(theta):
        # Gaussian toy problem: additive standard-normal noise.
        return torch.randn_like(theta) + theta

    prior = BoxUniform(-ones(num_dim), ones(num_dim))
    posterior_dist = prior

    # Wrap the (prior-as-)posterior so run_sbc can sample from it.
    potential = PosteriorPotential(posterior=posterior_dist, prior=prior)
    posterior = TractablePosterior(potential_fn=potential)

    num_sbc_runs = num_posterior_samples = 1000
    thetas = prior.sample((num_sbc_runs,))
    xs = simulator(thetas)
    ranks, daps = run_sbc(
        thetas,
        xs,
        posterior,
        num_workers=1,
        num_posterior_samples=num_posterior_samples,
    )

    checks = check_sbc(
        ranks,
        prior.sample((num_sbc_runs,)),
        daps,
        num_posterior_samples=num_posterior_samples,
    )
    # Look statistics up by key: unpacking ``checks.values()`` positionally
    # depends on dict insertion order and raises ValueError if check_sbc ever
    # returns an additional entry.
    c2st_ranks = checks["c2st_ranks"]
    pvals = checks["ks_pvals"]

    # With perfect calibration, ranks should be uniform.
    assert (c2st_ranks <= 0.6).all(), "posterior ranks must be close to uniform."
    assert (pvals > 0.05).all(), "posterior ranks uniformity test p-values too small."
@pytest.mark.slow
def test_sbc_checks():
    """``check_sbc`` should accept ranks drawn from a uniform distribution.

    The ranks are random draws from ``Uniform(0, num_posterior_samples)``
    rather than the output of an SBC run. The data-averaged posterior (DAP)
    samples are drawn straight from the prior, so every check should pass.
    """
    num_dim = 2
    num_posterior_samples = 1500
    # The same count doubles as the number of synthetic SBC runs.
    num_sbc_runs = num_posterior_samples

    prior = MultivariateNormal(zeros(num_dim), eye(num_dim))

    # For a calibrated posterior, DAP samples are distributed as the prior.
    dap_samples = prior.sample((num_sbc_runs,))

    rank_upper = num_posterior_samples * ones(num_dim)
    ranks = Uniform(zeros(num_dim), rank_upper).sample((num_sbc_runs,))

    checks = check_sbc(
        ranks,
        prior.sample((num_sbc_runs,)),
        dap_samples,
        num_posterior_samples=num_posterior_samples,
    )

    assert (checks["ks_pvals"] > 0.05).all(), "KS test failed on uniform ranks"
    assert (checks["c2st_ranks"] < 0.55).all(), "C2ST failed on uniform ranks"
    assert (checks["c2st_dap"] < 0.55).all(), (
        "C2ST failed on prior-distributed DAP samples"
    )
@pytest.mark.parametrize("num_bins", (None, 30))
@pytest.mark.parametrize("plot_type", ("cdf", "hist", "cdf-diff"))
@pytest.mark.parametrize("legend_kwargs", (None, {"loc": "upper left"}))
@pytest.mark.parametrize("num_rank_sets", (1, 2))
def test_sbc_plotting(
    num_bins: Optional[int],
    plot_type: str,
    legend_kwargs: Optional[Dict],
    num_rank_sets: int,
):
    """Test SBC plotting functionality with various options.

    The ranks are random draws from ``Uniform(0, num_posterior_samples)``.
    With ``num_rank_sets=2`` the same rank tensor is passed twice, which is
    sufficient for a smoke test of plotting multiple rank sets.
    """
    num_dim = 2
    num_posterior_samples = 1000

    ranks = [
        torch.distributions.Uniform(
            zeros(num_dim), num_posterior_samples * ones(num_dim)
        ).sample((num_posterior_samples,))
    ] * num_rank_sets

    # sbc_rank_plot returns a ``(fig, ax)`` pair (see its API docs). An
    # assertion on the raw return value tests a non-empty tuple and can never
    # fail, so unpack it and check both parts.
    fig, ax = sbc_rank_plot(
        ranks,
        num_posterior_samples,
        num_bins=num_bins,
        plot_type=plot_type,
        legend_kwargs=legend_kwargs,
    )

    assert fig is not None, "Plot function should return a figure"
    assert ax is not None, "Plot function should return axes"
@pytest.mark.parametrize("num_workers", [1, 2])
def test_sbc_parallelization(num_workers: int, gaussian_setup: Dict):
    """``run_sbc`` should return correctly shaped ranks for each worker count.

    Only the output shape is asserted; ranks produced with different worker
    counts are not compared against each other.
    """
    prior = gaussian_setup["prior"]
    num_sbc_runs, num_posterior_samples = 10, 50

    # A one-epoch NPE is enough: this test is about run_sbc, not accuracy.
    posterior = train_inference_method(
        NPE,
        prior,
        gaussian_setup["simulator"],
        num_simulations=200,
        max_num_epochs=1,
    )

    thetas = prior.sample((num_sbc_runs,))
    xs = gaussian_setup["simulator"](thetas)

    ranks, _ = run_sbc(
        thetas,
        xs,
        posterior,
        num_posterior_samples=num_posterior_samples,
        num_workers=num_workers,
    )

    expected_shape = (num_sbc_runs, gaussian_setup["num_dim"])
    assert ranks.shape == expected_shape, (
        f"Ranks shape incorrect for {num_workers} workers"
    )
@pytest.mark.parametrize("batch_sampling", [True, False])
def test_sbc_batch_sampling(batch_sampling: bool, gaussian_setup: Dict):
    """``run_sbc`` should yield correctly shaped ranks with and without batching.

    ``batch_sampling`` is forwarded as ``use_batched_sampling``; only the
    shape of the returned ranks is asserted.
    """
    prior = gaussian_setup["prior"]
    num_sbc_runs, num_posterior_samples = 5, 50

    # A one-epoch NPE is enough: this test is about run_sbc, not accuracy.
    posterior = train_inference_method(
        NPE,
        prior,
        gaussian_setup["simulator"],
        num_simulations=200,
        max_num_epochs=1,
    )

    thetas = prior.sample((num_sbc_runs,))
    xs = gaussian_setup["simulator"](thetas)

    ranks, _ = run_sbc(
        thetas,
        xs,
        posterior,
        num_posterior_samples=num_posterior_samples,
        use_batched_sampling=batch_sampling,
    )

    expected_shape = (num_sbc_runs, gaussian_setup["num_dim"])
    assert ranks.shape == expected_shape, (
        f"Ranks shape incorrect with batched_sampling={batch_sampling}"
    )