diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 599622904..9f61e1674 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -16,7 +16,7 @@ ) from sbi.sbi_types import Array, Shape, TorchTransform from sbi.utils.sbiutils import gradient_ascent -from sbi.utils.torchutils import ensure_theta_batched, process_device +from sbi.utils.torchutils import assert_all_finite, ensure_theta_batched, process_device from sbi.utils.user_input_checks import process_x @@ -35,6 +35,7 @@ def __init__( theta_transform: Optional[TorchTransform] = None, device: Optional[Union[str, torch.device]] = None, x_shape: Optional[torch.Size] = None, + check_finite_x: bool = True, ): """ Args: @@ -61,6 +62,7 @@ def __init__( ) self._device = process_device(potential_fn.device if device is None else device) + self._check_finite_x = check_finite_x self.potential_fn = potential_fn @@ -180,15 +182,27 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior": Returns: `NeuralPosterior` that will use a default `x` when not explicitly passed. """ - self._x = process_x(x, x_event_shape=None).to(self._device) + x = process_x(x, x_event_shape=None) + + if self._check_finite_x: + assert_all_finite(x, "Observed data x_o contains Nans or Infs.") + + self._x = x.to(self._device) + self._map = None return self - + def _x_else_default_x(self, x: Optional[Array]) -> Tensor: if x is not None: # New x, reset posterior sampler. self._posterior_sampler = None - return process_x(x, x_event_shape=None) + x = process_x(x, x_event_shape=None) + + if self._check_finite_x: + assert_all_finite(x, "Observed data x_o contains Nans or Infs.") + + return x + elif self.default_x is None: raise ValueError( "Context `x` needed when a default has not been set." 
diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 2c1e4b35e..6ea1d2456 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -49,6 +49,7 @@ def __init__( device: Optional[Union[str, torch.device]] = None, x_shape: Optional[torch.Size] = None, enable_transform: bool = True, + check_finite_x: bool = True, ): """ Args: @@ -81,6 +82,7 @@ def __init__( theta_transform=theta_transform, device=device, x_shape=x_shape, + check_finite_x=check_finite_x, ) self.device = device @@ -127,6 +129,7 @@ def to(self, device: Union[str, torch.device]) -> None: theta_transform=theta_transform, device=device, x_shape=self.x_shape, + check_finite_x=self._check_finite_x, ) # super().__init__ erases the self._x, so we need to set it again if x_o is not None: diff --git a/sbi/inference/posteriors/posterior_parameters.py b/sbi/inference/posteriors/posterior_parameters.py index 6f10d4246..1e5d8df03 100644 --- a/sbi/inference/posteriors/posterior_parameters.py +++ b/sbi/inference/posteriors/posterior_parameters.py @@ -27,6 +27,7 @@ @dataclass(frozen=True) class PosteriorParameters(ABC): + check_finite_x: bool = True @abstractmethod def validate(self): """ diff --git a/sbi/inference/posteriors/vector_field_posterior.py b/sbi/inference/posteriors/vector_field_posterior.py index ef0ae6acd..330166713 100644 --- a/sbi/inference/posteriors/vector_field_posterior.py +++ b/sbi/inference/posteriors/vector_field_posterior.py @@ -59,6 +59,7 @@ def __init__( device: Optional[Union[str, torch.device]] = None, enable_transform: bool = True, sample_with: Literal["ode", "sde"] = "sde", + check_finite_x: bool = True, **kwargs, ): """ @@ -90,6 +91,7 @@ def __init__( potential_fn=potential_fn, theta_transform=theta_transform, device=device, + check_finite_x=check_finite_x, ) # Set the potential function type. 
self.potential_fn: VectorFieldBasedPotential = potential_fn @@ -138,6 +140,7 @@ def to(self, device: Union[str, torch.device]) -> None: potential_fn=potential_fn, theta_transform=theta_transform, device=device, + check_finite_x=self._check_finite_x, ) # super().__init__ erases the self._x, so we need to set it again if x_o is not None: diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 56660fe73..e5d2e5d97 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -13,7 +13,7 @@ from sbi.sbi_types import Array from sbi.utils.sbiutils import within_support -from sbi.utils.torchutils import BoxUniform, assert_all_finite, atleast_2d +from sbi.utils.torchutils import BoxUniform, atleast_2d from sbi.utils.user_input_checks_utils import ( CustomPriorWrapper, MultipleIndependent, @@ -593,7 +593,9 @@ def batch_loop_simulator(theta: Tensor) -> Tensor: return batch_loop_simulator -def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor: +def process_x( + x: Array, x_event_shape: Optional[torch.Size] = None +) -> Tensor: """Return observed data adapted to match sbi's shape and type requirements. This means that `x` is returned with a `batch_dim`. 
@@ -611,8 +613,6 @@ def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor: """ x = atleast_2d(torch.as_tensor(x, dtype=float32)) - assert_all_finite(x, "Observed data x_o contains Nans or Infs.") - if x_event_shape is not None and len(x_event_shape) > len(x.shape): raise ValueError( f"You passed an `x` of shape {x.shape} but the `x_event_shape` (inferred " diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py index 4be610c0b..9417f815c 100644 --- a/tests/embedding_net_test.py +++ b/tests/embedding_net_test.py @@ -13,7 +13,7 @@ from sbi import utils from sbi.inference import NLE, NPE, NRE, simulate_for_sbi -from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters +from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters, DirectPosteriorParameters from sbi.neural_nets import classifier_nn, likelihood_nn, posterior_nn from sbi.neural_nets.embedding_nets import ( CNNEmbedding, @@ -401,11 +401,6 @@ def simulator1d(theta): @pytest.mark.slow -@pytest.mark.xfail( - raises=ValueError, - reason="Padding with NaNs causes error in new NaN check on x_o, see #1701, #1717", - strict=True, -) def test_npe_with_with_iid_embedding_varying_num_trials(trial_factor=50): """Test inference accuracy with embeddings for varying number of trials.
@@ -455,7 +450,9 @@ def test_npe_with_with_iid_embedding_varying_num_trials(trial_factor=50): _ = inference.append_simulations(theta, x, exclude_invalid_x=False).train( training_batch_size=100 ) - posterior = inference.build_posterior() + posterior = inference.build_posterior( + posterior_parameters=DirectPosteriorParameters(check_finite_x=False) + ) num_samples = 1000 # test different number of trials diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py index 0103472a1..44bace22a 100644 --- a/tests/user_input_checks_test.py +++ b/tests/user_input_checks_test.py @@ -19,8 +19,11 @@ Uniform, ) -from sbi.inference import NPE_A, NPE_C, simulate_for_sbi +from sbi.inference import NPE, NPE_A, NPE_C, simulate_for_sbi from sbi.inference.posteriors.direct_posterior import DirectPosterior +from sbi.inference.posteriors.posterior_parameters import ( + DirectPosteriorParameters, +) from sbi.simulators import linear_gaussian from sbi.simulators.linear_gaussian import diagonal_linear_gaussian from sbi.utils import mcmc_transform, within_support @@ -210,25 +213,57 @@ def test_process_prior(prior): (ones(10, 3), torch.Size([10, 3])), # 2D data / iid NPE pytest.param(ones(10, 3), None), # 2D data / iid NPE without x_shape (ones(10, 10), torch.Size([10])), # iid likelihood based - pytest.param( + ( torch.cat([ones(3), torch.tensor([float("nan")]), ones(3)]), # contains nan torch.Size([7]), - marks=pytest.mark.xfail( - reason="process_x must raise error if x contains NaNs or Infs." - ), + ), - pytest.param( + ( # contains inf torch.cat([ones(3), torch.tensor([float("inf")]), ones(3)]).expand(10, -1), torch.Size([7]), - marks=pytest.mark.xfail( - reason="process_x must raise error if x contains NaNs or Infs." 
- ), + ), ), ) def test_process_x(x, x_shape): process_x(x, x_shape) + +def test_set_default_x_check_finite(): + prior = BoxUniform(zeros(2), ones(2)) + + inference = NPE_C( + prior=prior, + density_estimator="maf", + show_progress_bars=False, + ) + + theta = prior.sample((100,)) + x = torch.randn(100, 2) + + posterior_estimator = ( + inference.append_simulations(theta, x) + .train(max_num_epochs=1) + ) + + x_with_nan = torch.tensor([0.0, float("nan")]) + + posterior = DirectPosterior( + posterior_estimator=posterior_estimator, + prior=prior, + ) + + with pytest.raises(ValueError): + posterior.set_default_x(x_with_nan) + + posterior = inference.build_posterior( + posterior_parameters=DirectPosteriorParameters( + check_finite_x=False + ), +) + + posterior.set_default_x(x_with_nan) + @pytest.mark.parametrize( @@ -247,6 +282,7 @@ def test_process_x(x, x_shape): (lambda _: torch.randn(10, 2), BoxUniform(zeros(2), ones(2)), (10, 2)), ), ) + def test_process_simulator(simulator: Callable, prior: Distribution, x_shape: Tuple): prior, theta_dim, prior_returns_numpy = process_prior(prior) simulator = process_simulator(simulator, prior, prior_returns_numpy)