Skip to content

Commit 8facb74

Browse files
fix: remove NPE rejection prior guard, extend tests for sample_with across all methods (#1839)
* Enhance parameterized tests for various sampling methods * remove prior validation for rejection sampling in PosteriorEstimatorTrainer * Remove xfail test for rejection sampling
1 parent c18a585 commit 8facb74

3 files changed

Lines changed: 53 additions & 76 deletions

File tree

sbi/inference/trainers/npe/npe_base.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -437,10 +437,6 @@ def build_posterior(
437437
(the returned log-probability is unnormalized).
438438
"""
439439

440-
self._check_prior_for_rejection_sampling(
441-
prior, sample_with, posterior_parameters
442-
)
443-
444440
return super().build_posterior(
445441
density_estimator,
446442
prior,
@@ -484,44 +480,6 @@ def _get_potential_function(
484480
)
485481
return potential_fn, theta_transform
486482

487-
def _check_prior_for_rejection_sampling(
488-
self,
489-
prior: Optional[Distribution],
490-
sample_with: Literal["mcmc", "rejection", "vi", "importance", "direct"],
491-
posterior_parameters: Optional[
492-
Union[
493-
DirectPosteriorParameters,
494-
MCMCPosteriorParameters,
495-
VIPosteriorParameters,
496-
RejectionPosteriorParameters,
497-
ImportanceSamplingPosteriorParameters,
498-
]
499-
],
500-
) -> None:
501-
"""
502-
Validates that when using rejection sampling, a prior distribution
503-
is explicitly provided.
504-
505-
Args:
506-
prior: Prior distribution.
507-
sample_with: The sampling method used. Must be one of
508-
"mcmc", "rejection", "vi", "importance", or "direct".
509-
posterior_parameters: Configuration for building the posterior.
510-
"""
511-
512-
if (
513-
sample_with == "rejection"
514-
or isinstance(posterior_parameters, RejectionPosteriorParameters)
515-
) and prior is None:
516-
raise ValueError(
517-
"You indicated sampling via rejection sampling but "
518-
"haven't passed a prior. As of sbi v0.23.0, you either have"
519-
" to pass a prior to perform rejection sampling using the prior"
520-
" as proposal, or to use the posterior as proposal, you have to"
521-
" use a DirectPosterior via `sample_with='direct' or"
522-
" `posterior_parameters=DirectPosteriorParameters`."
523-
)
524-
525483
def _loss(
526484
self,
527485
theta: Tensor,

tests/posterior_nn_test.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from sbi.inference import (
1212
FMPE,
1313
NLE_A,
14-
NPE,
1514
NPE_A,
1615
NPE_C,
1716
NPSE,
@@ -23,7 +22,6 @@
2322
)
2423
from sbi.inference.posteriors.posterior_parameters import (
2524
MCMCPosteriorParameters,
26-
RejectionPosteriorParameters,
2725
)
2826
from sbi.inference.potentials.posterior_based_potential import (
2927
posterior_estimator_based_potential,
@@ -459,27 +457,6 @@ def simulator(theta):
459457
inference.build_posterior(density_estimator=nn.Module())
460458

461459

462-
@pytest.mark.xfail(
463-
raises=ValueError,
464-
reason="Prior must be passed through build_posterior method for rejection"
465-
" sampling in NPE",
466-
)
467-
def test_build_posterior_raises_error_for_rejection_sampling():
468-
def simulator(theta):
469-
return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1
470-
471-
num_dim = 3
472-
prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
473-
theta = prior.sample((300,))
474-
x = simulator(theta)
475-
476-
inference = NPE(prior=prior)
477-
inference.append_simulations(theta, x)
478-
479-
inference.train(max_num_epochs=1)
480-
inference.build_posterior(posterior_parameters=RejectionPosteriorParameters())
481-
482-
483460
def get_multidim_simulator_embedding(num_dim: int = 5, embedding_dim: int = 10):
484461
"""Returns a simulator producing 2D matrix observations and a CNN embedding net."""
485462

tests/posterior_parameters_test.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
import torch
99

10-
from sbi.inference import NPE, NRE
10+
from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
1111
from sbi.inference.posteriors.direct_posterior import DirectPosterior
1212
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
1313
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
@@ -416,6 +416,7 @@ def test_if_warning_raised_for_deprecated_build_posterior_parameters(
416416
get_inference.build_posterior(**params)
417417

418418

419+
@pytest.mark.parametrize("inference_class", [NPE, NRE, NLE])
419420
@pytest.mark.parametrize(
420421
"sample_with",
421422
["mcmc", "vi", "rejection", "importance"],
@@ -424,27 +425,68 @@ def test_if_warning_raised_for_deprecated_build_posterior_parameters(
424425
raises=ValueError,
425426
reason="Prior required for non-direct sampling methods",
426427
)
427-
def test_resolve_prior_missing_prior_npe(sample_with):
428-
inference = NPE() # no prior
428+
def test_resolve_prior_missing_prior(inference_class, sample_with):
429+
"""
430+
Test that an error is raised when trying to build a posterior with a sampling method
431+
that requires a prior, but no prior is provided to the inference method.
432+
"""
433+
inference = inference_class() # no prior
429434

430-
theta = torch.randn(10, 2)
435+
theta = torch.randn(50, 2)
431436
x = theta + 0.1 * torch.randn_like(theta)
432437

433438
inference.append_simulations(theta, x).train(max_num_epochs=1)
434439

435440
inference.build_posterior(sample_with=sample_with)
436441

437442

438-
@pytest.mark.parametrize("sample_with", ["direct"])
439-
def test_resolve_prior_missing_prior_npe_direct_ok(sample_with):
440-
inference = NPE() # no prior
443+
VALID_SAMPLE_WITH_AND_INFERENCE_METHOD = [
444+
(NPE, "direct"),
445+
(NPE, "mcmc"),
446+
(NPE, "vi"),
447+
(NPE, "rejection"),
448+
(NPE, "importance"),
449+
(NLE, "mcmc"),
450+
(NLE, "vi"),
451+
(NLE, "rejection"),
452+
(NLE, "importance"),
453+
(NRE, "mcmc"),
454+
(NRE, "vi"),
455+
(NRE, "rejection"),
456+
(NRE, "importance"),
457+
(NPSE, "sde"),
458+
(NPSE, "ode"),
459+
(FMPE, "sde"),
460+
(FMPE, "ode"),
461+
]
441462

442-
theta = torch.randn(10, 2)
443-
x = theta + 0.1 * torch.randn_like(theta)
463+
464+
@pytest.mark.parametrize(
465+
"inference_method, sample_with", VALID_SAMPLE_WITH_AND_INFERENCE_METHOD
466+
)
467+
def test_inference_method_with_valid_sample_with(inference_method, sample_with):
468+
"""
469+
Test that the inference method works correctly with valid sample_with options.
470+
"""
471+
num_samples = 2
472+
dim_theta = 2
473+
num_simulations = 50
474+
475+
prior = BoxUniform(low=-2 * torch.ones(dim_theta), high=2 * torch.ones(dim_theta))
476+
inference = inference_method(prior=prior)
477+
478+
theta = prior.sample((num_simulations,))
479+
x = theta + torch.randn_like(theta) * 0.1
444480

445481
inference.append_simulations(theta, x).train(max_num_epochs=1)
446482

447483
posterior = inference.build_posterior(sample_with=sample_with)
448484

449-
samples = posterior.sample((5,), x=x[:1])
450-
assert samples.shape[0] == 5
485+
if sample_with == "vi":
486+
posterior.set_default_x(x[0])
487+
posterior.train(max_num_iters=5, show_progress_bar=False, quality_control=False)
488+
samples = posterior.sample((num_samples,))
489+
else:
490+
samples = posterior.sample((num_samples,), x=x[0])
491+
492+
assert samples.shape == torch.Size([num_samples, dim_theta])

0 commit comments

Comments
 (0)