Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from sbi.sbi_types import Array, Shape, TorchTransform
from sbi.utils.sbiutils import gradient_ascent
from sbi.utils.torchutils import ensure_theta_batched, process_device
from sbi.utils.torchutils import assert_all_finite, ensure_theta_batched, process_device
from sbi.utils.user_input_checks import process_x


Expand All @@ -35,6 +35,7 @@ def __init__(
theta_transform: Optional[TorchTransform] = None,
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
check_finite_x: bool = True,
):
"""
Args:
Expand All @@ -61,6 +62,7 @@ def __init__(
)

self._device = process_device(potential_fn.device if device is None else device)
self._check_finite_x = check_finite_x

self.potential_fn = potential_fn

Expand Down Expand Up @@ -180,15 +182,27 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
Returns:
`NeuralPosterior` that will use a default `x` when not explicitly passed.
"""
self._x = process_x(x, x_event_shape=None).to(self._device)
x = process_x(x, x_event_shape=None)

if self._check_finite_x:
assert_all_finite(x, "Observed data x_o contains Nans or Infs.")

self._x = x.to(self._device)

self._map = None
return self

def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
if x is not None:
# New x, reset posterior sampler.
self._posterior_sampler = None
return process_x(x, x_event_shape=None)
x = process_x(x, x_event_shape=None)

if self._check_finite_x:
assert_all_finite(x, "Observed data x_o contains Nans or Infs.")

return x

elif self.default_x is None:
raise ValueError(
"Context `x` needed when a default has not been set."
Expand Down
3 changes: 3 additions & 0 deletions sbi/inference/posteriors/direct_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
enable_transform: bool = True,
check_finite_x: bool = True,
):
"""
Args:
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(
theta_transform=theta_transform,
device=device,
x_shape=x_shape,
check_finite_x=check_finite_x,
)

self.device = device
Expand Down Expand Up @@ -127,6 +129,7 @@ def to(self, device: Union[str, torch.device]) -> None:
theta_transform=theta_transform,
device=device,
x_shape=self.x_shape,
check_finite_x=self._check_finite_x,
)
# super().__init__ erases the self._x, so we need to set it again
if x_o is not None:
Expand Down
1 change: 1 addition & 0 deletions sbi/inference/posteriors/posterior_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

@dataclass(frozen=True)
class PosteriorParameters(ABC):
check_finite_x: bool = True
@abstractmethod
def validate(self):
"""
Expand Down
3 changes: 3 additions & 0 deletions sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
device: Optional[Union[str, torch.device]] = None,
enable_transform: bool = True,
sample_with: Literal["ode", "sde"] = "sde",
check_finite_x: bool = True,
**kwargs,
):
"""
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
potential_fn=potential_fn,
theta_transform=theta_transform,
device=device,
check_finite_x=check_finite_x,
)
# Set the potential function type.
self.potential_fn: VectorFieldBasedPotential = potential_fn
Expand Down Expand Up @@ -138,6 +140,7 @@ def to(self, device: Union[str, torch.device]) -> None:
potential_fn=potential_fn,
theta_transform=theta_transform,
device=device,
check_finite_x=self._check_finite_x,
)
# super().__init__ erases the self._x, so we need to set it again
if x_o is not None:
Expand Down
8 changes: 4 additions & 4 deletions sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from sbi.sbi_types import Array
from sbi.utils.sbiutils import within_support
from sbi.utils.torchutils import BoxUniform, assert_all_finite, atleast_2d
from sbi.utils.torchutils import BoxUniform, atleast_2d
from sbi.utils.user_input_checks_utils import (
CustomPriorWrapper,
MultipleIndependent,
Expand Down Expand Up @@ -593,7 +593,9 @@ def batch_loop_simulator(theta: Tensor) -> Tensor:
return batch_loop_simulator


def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor:
def process_x(
x: Array, x_event_shape: Optional[torch.Size] = None
) -> Tensor:
"""Return observed data adapted to match sbi's shape and type requirements.

This means that `x` is returned with a `batch_dim`.
Expand All @@ -611,8 +613,6 @@ def process_x(x: Array, x_event_shape: Optional[torch.Size] = None) -> Tensor:
"""

x = atleast_2d(torch.as_tensor(x, dtype=float32))
assert_all_finite(x, "Observed data x_o contains Nans or Infs.")

if x_event_shape is not None and len(x_event_shape) > len(x.shape):
raise ValueError(
f"You passed an `x` of shape {x.shape} but the `x_event_shape` (inferred "
Expand Down
11 changes: 4 additions & 7 deletions tests/embedding_net_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from sbi import utils
from sbi.inference import NLE, NPE, NRE, simulate_for_sbi
from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters
from sbi.inference.posteriors.posterior_parameters import MCMCPosteriorParameters , DirectPosteriorParameters
from sbi.neural_nets import classifier_nn, likelihood_nn, posterior_nn
from sbi.neural_nets.embedding_nets import (
CNNEmbedding,
Expand Down Expand Up @@ -401,11 +401,6 @@ def simulator1d(theta):


@pytest.mark.slow
@pytest.mark.xfail(
raises=ValueError,
reason="Padding with NaNs causes error in new NaN check on x_o, see #1701, #1717",
strict=True,
)
def test_npe_with_with_iid_embedding_varying_num_trials(trial_factor=50):
"""Test inference accuracy with embeddings for varying number of trials.

Expand Down Expand Up @@ -455,7 +450,9 @@ def test_npe_with_with_iid_embedding_varying_num_trials(trial_factor=50):
_ = inference.append_simulations(theta, x, exclude_invalid_x=False).train(
training_batch_size=100
)
posterior = inference.build_posterior()
posterior = inference.build_posterior(
posterior_parameters=DirectPosteriorParameters(check_finite_x=False)
)

num_samples = 1000
# test different number of trials
Expand Down
54 changes: 45 additions & 9 deletions tests/user_input_checks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
Uniform,
)

from sbi.inference import NPE_A, NPE_C, simulate_for_sbi
from sbi.inference import NPE, NPE_A, NPE_C, simulate_for_sbi
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.posteriors.posterior_parameters import (
DirectPosteriorParameters,
)
from sbi.simulators import linear_gaussian
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian
from sbi.utils import mcmc_transform, within_support
Expand Down Expand Up @@ -210,25 +213,57 @@ def test_process_prior(prior):
(ones(10, 3), torch.Size([10, 3])), # 2D data / iid NPE
pytest.param(ones(10, 3), None), # 2D data / iid NPE without x_shape
(ones(10, 10), torch.Size([10])), # iid likelihood based
pytest.param(
(
torch.cat([ones(3), torch.tensor([float("nan")]), ones(3)]), # contains nan
torch.Size([7]),
marks=pytest.mark.xfail(
reason="process_x must raise error if x contains NaNs or Infs."
),

),
pytest.param(
(
# contains inf
torch.cat([ones(3), torch.tensor([float("inf")]), ones(3)]).expand(10, -1),
torch.Size([7]),
marks=pytest.mark.xfail(
reason="process_x must raise error if x contains NaNs or Infs."
),

),
),
)
def test_process_x(x, x_shape):
process_x(x, x_shape)

def test_set_default_x_check_finite():
    """`set_default_x` rejects non-finite observations unless `check_finite_x=False`."""
    box_prior = BoxUniform(zeros(2), ones(2))

    trainer = NPE_C(
        prior=box_prior,
        density_estimator="maf",
        show_progress_bars=False,
    )

    # Train a throwaway estimator on random data; accuracy is irrelevant here.
    num_simulations = 100
    thetas = box_prior.sample((num_simulations,))
    xs = torch.randn(num_simulations, 2)
    estimator = trainer.append_simulations(thetas, xs).train(max_num_epochs=1)

    bad_x = torch.tensor([0.0, float("nan")])

    # Default behavior: the finite check is on, so a NaN observation must raise.
    checked_posterior = DirectPosterior(
        posterior_estimator=estimator,
        prior=box_prior,
    )

    with pytest.raises(ValueError):
        checked_posterior.set_default_x(bad_x)

    # With the check disabled, the same NaN observation is accepted silently.
    unchecked_posterior = trainer.build_posterior(
        posterior_parameters=DirectPosteriorParameters(check_finite_x=False),
    )

    unchecked_posterior.set_default_x(bad_x)



@pytest.mark.parametrize(
Expand All @@ -247,6 +282,7 @@ def test_process_x(x, x_shape):
(lambda _: torch.randn(10, 2), BoxUniform(zeros(2), ones(2)), (10, 2)),
),
)

def test_process_simulator(simulator: Callable, prior: Distribution, x_shape: Tuple):
prior, theta_dim, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator, prior, prior_returns_numpy)
Expand Down