diff --git a/sbi/inference/trainers/npe/npe_a.py b/sbi/inference/trainers/npe/npe_a.py
index e263a5797..4cb080c30 100644
--- a/sbi/inference/trainers/npe/npe_a.py
+++ b/sbi/inference/trainers/npe/npe_a.py
@@ -462,28 +462,7 @@ def build_posterior(
return self._posterior
- def _log_prob_proposal_posterior(
- self,
- theta: Tensor,
- x: Tensor,
- masks: Tensor,
- proposal: Optional[Any],
- ) -> Tensor:
- """Return the log-probability of the proposal posterior.
-
- For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in
- `_loss()` to be found in `snpe_base.py`.
- Args:
- theta: Batch of parameters θ.
- x: Batch of data.
- masks: Mask that is True for prior samples in the batch in order to train
- them with prior loss.
- proposal: Proposal distribution.
-
- Returns: Log-probability of the proposal posterior.
- """
- return self._neural_net.log_prob(theta, x)
def _correct_for_proposal(
diff --git a/sbi/inference/trainers/npe/npe_b.py b/sbi/inference/trainers/npe/npe_b.py
index 2ba1f2461..de6c8069b 100644
--- a/sbi/inference/trainers/npe/npe_b.py
+++ b/sbi/inference/trainers/npe/npe_b.py
@@ -1,22 +1,22 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
-from typing import Any, Literal, Optional, Union
+from typing import Callable, Dict, Literal, Optional, Union
-import torch
-from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
-import sbi.utils as utils
from sbi.inference.trainers.npe.npe_base import (
PosteriorEstimatorTrainer,
)
+from sbi.inference.trainers.npe.npe_loss import (
+ ImportanceWeightedLoss,
+ NPELossStrategy,
+)
from sbi.neural_nets.estimators.base import (
ConditionalDensityEstimator,
ConditionalEstimatorBuilder,
)
-from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
@@ -107,72 +107,48 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)
- def _log_prob_proposal_posterior(
+ def train(
self,
- theta: Tensor,
- x: Tensor,
- masks: Tensor,
- proposal: Optional[Any],
- ) -> Tensor:
- """
- Return importance-weighted log probability (Lueckmann, Goncalves et al., 2017).
-
- Args:
- theta: Batch of parameters θ.
- x: Batch of data.
- masks: Mask that is True for prior samples in the batch in order to train
- them with prior loss.
- proposal: Proposal distribution.
-
- Returns:
- Importance-weighted log-probability of the proposal posterior.
- """
-
- # Evaluate prior
- # we accept prior log prob to be -Inf at theta
- # meaning that theta is out of the prior range (the weight is thus 0)
- utils.assert_not_nan_or_plus_inf(
- self._prior.log_prob(theta), "prior log probs of proposal samples"
+ training_batch_size: int = 200,
+ learning_rate: float = 5e-4,
+ validation_fraction: float = 0.1,
+ stop_after_epochs: int = 20,
+ max_num_epochs: int = 2**31 - 1,
+ clip_max_norm: Optional[float] = 5.0,
+ calibration_kernel: Optional[Callable] = None,
+ resume_training: bool = False,
+ force_first_round_loss: bool = False,
+ discard_prior_samples: bool = False,
+ retrain_from_scratch: bool = False,
+ show_train_summary: bool = False,
+ dataloader_kwargs: Optional[Dict] = None,
+ loss_strategy: Optional[NPELossStrategy] = None,
+ ) -> ConditionalDensityEstimator:
+
+ kwargs = del_entries(
+ locals(),
+ entries=("self", "__class__", "loss_strategy"),
)
- prior = torch.exp(self._prior.log_prob(theta))
-
- # Evaluate proposal
- # (as theta comes from prior and proposal from previous rounds,
- # the last proposal is actually a mixture of the prior
- # and of all the previous proposals with coefficients representing
- # the proportion of the new theta added at each round)
- prop = torch.zeros(self._round + 1, device=theta.device)
- nb_samples = 0 # total number of theta from all the rounds
-
- for k in range(self._round + 1):
- nb_samples += self._theta_roundwise[k].size(0)
- # the number of new theta sampled in the round k
- prop[k] = self._theta_roundwise[k].size(0)
-
- prop /= nb_samples
- log_prop = torch.log(prop).repeat(theta.size(0), 1)
-
- log_previous_proposals = torch.zeros(
- (theta.size(0), self._round + 1), device=theta.device
- )
- for k, density in enumerate(self._proposal_roundwise):
- # we accept the k th proposal log prob to be -Inf at theta
- # meaning that theta is out of the k th proposal range
- log_previous_proposals[:, k] = density.log_prob(theta)
- utils.assert_not_nan_or_plus_inf(
- log_previous_proposals[:, k], "proposal log probs of proposal samples"
- )
-
- log_proposal = torch.logsumexp(log_prop + log_previous_proposals, dim=1)
- proposal = torch.exp(log_proposal)
- # Construct the importance weights and normalize them
- importance_weights = prior / proposal
- importance_weights /= importance_weights.sum()
+ if len(self._data_round_index) == 0:
+ raise RuntimeError(
+ "No simulations found. You must call .append_simulations() "
+ "before calling .train()."
+ )
- theta = reshape_to_sample_batch_event(theta, theta.shape[1:])
- # Reshape the density estimator log probs
- # from (sample_shape, batch_shape) to (batch_shape)
- posterior_log_probs = self._neural_net.log_prob(theta, x).squeeze(dim=0)
+ self._round = max(self._data_round_index)
+
+ if loss_strategy is not None:
+ self._loss_strategy = loss_strategy
+ elif self._round > 0:
+ self._loss_strategy = ImportanceWeightedLoss(
+ neural_net=self._neural_net,
+ prior=self._prior,
+ round_idx=self._round,
+ theta_roundwise=self._theta_roundwise,
+ proposal_roundwise=self._proposal_roundwise,
+ )
+ else:
+ self._loss_strategy = None
- return importance_weights * posterior_log_probs
+ return super().train(**kwargs)
diff --git a/sbi/inference/trainers/npe/npe_base.py b/sbi/inference/trainers/npe/npe_base.py
index 4105bf72f..8f168500c 100644
--- a/sbi/inference/trainers/npe/npe_base.py
+++ b/sbi/inference/trainers/npe/npe_base.py
@@ -2,7 +2,7 @@
# under the Apache License Version 2.0, see
import warnings
-from abc import ABC, abstractmethod
+from abc import ABC
from dataclasses import asdict
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple, Union
@@ -32,6 +32,7 @@
NeuralInference,
check_if_proposal_has_default_x,
)
+from sbi.inference.trainers.npe.npe_loss import NPELossStrategy
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.estimators import ConditionalDensityEstimator
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
@@ -111,16 +112,9 @@ def __init__(
self._build_neural_net = density_estimator
self._proposal_roundwise = []
- self.use_non_atomic_loss = False
+ self._loss_strategy: Optional[NPELossStrategy] = None
+
- @abstractmethod
- def _log_prob_proposal_posterior(
- self,
- theta: Tensor,
- x: Tensor,
- masks: Tensor,
- proposal: Optional[Any],
- ) -> Tensor: ...
def append_simulations(
self,
@@ -194,10 +188,11 @@ def append_simulations(
# Check for problematic z-scoring
warn_if_invalid_for_zscoring(x)
+    is_atomic = self._loss_strategy is None or not getattr(self._loss_strategy, "uses_only_latest_round", False)
if (
- type(self).__name__ == "SNPE_C"
+ type(self).__name__ in ("SNPE_C", "NPE_C")
and current_round > 0
- and not self.use_non_atomic_loss
+ and is_atomic
):
nle_nre_apt_msg_on_invalid_x(
num_nans,
@@ -254,6 +249,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[dict] = None,
+ loss_strategy: Optional[NPELossStrategy] = None,
) -> ConditionalDensityEstimator:
r"""Return density estimator that approximates the distribution $p(\theta|x)$.
@@ -550,9 +546,11 @@ def _loss(
# Use posterior log prob (without proposal correction) for first round.
loss = self._neural_net.loss(theta, x)
else:
+ if self._loss_strategy is None:
+                raise ValueError("No loss strategy set for round > 0; train via a subclass (e.g. NPE_C) or pass `loss_strategy` to `train()`.")
# Currently only works for `DensityEstimator` objects.
# Must be extended ones other Estimators are implemented. See #966,
- loss = -self._log_prob_proposal_posterior(theta, x, masks, proposal)
+ loss = -self._loss_strategy(theta, x, masks, proposal)
assert_all_finite(loss, "NPE loss")
return calibration_kernel(x) * loss
@@ -650,7 +648,10 @@ def _get_start_index(self, context: StartIndexContext) -> int:
# SNPE-A can, by construction of the algorithm, only use samples from the last
# round. SNPE-A is the only algorithm that has an attribute `_ran_final_round`,
# so this is how we check for whether or not we are using SNPE-A.
- if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"):
+ if (
+ self._loss_strategy is not None
+ and self._loss_strategy.uses_only_latest_round
+ ) or hasattr(self, "_ran_final_round"):
start_idx = self._round
return start_idx
diff --git a/sbi/inference/trainers/npe/npe_c.py b/sbi/inference/trainers/npe/npe_c.py
index 9accdc2c6..bad251fcb 100644
--- a/sbi/inference/trainers/npe/npe_c.py
+++ b/sbi/inference/trainers/npe/npe_c.py
@@ -4,7 +4,6 @@
from typing import Callable, Dict, Literal, Optional, Union
import torch
-from torch import Tensor, eye, ones
from torch.distributions import Distribution, MultivariateNormal, Uniform
from torch.utils.tensorboard.writer import SummaryWriter
@@ -12,6 +11,11 @@
from sbi.inference.trainers.npe.npe_base import (
PosteriorEstimatorTrainer,
)
+from sbi.inference.trainers.npe.npe_loss import (
+ AtomicLoss,
+ NPELossStrategy,
+ NonAtomicGaussianLoss,
+)
from sbi.neural_nets.estimators.base import (
ConditionalDensityEstimator,
ConditionalEstimatorBuilder,
@@ -19,21 +23,12 @@
from sbi.neural_nets.estimators.mixture_density_estimator import (
MixtureDensityEstimator,
)
-from sbi.neural_nets.estimators.mog import MoG
-from sbi.neural_nets.estimators.shape_handling import (
- reshape_to_batch_event,
- reshape_to_sample_batch_event,
-)
from sbi.sbi_types import Tracker
from sbi.utils import (
- batched_mixture_mv,
- batched_mixture_vmv,
check_dist_class,
- clamp_and_warn,
del_entries,
- repeat_rows,
)
-from sbi.utils.torchutils import BoxUniform, assert_all_finite
+from sbi.utils.torchutils import BoxUniform
class NPE_C(PosteriorEstimatorTrainer):
@@ -141,6 +136,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
+ loss_strategy: Optional[NPELossStrategy] = None,
) -> ConditionalDensityEstimator:
r"""Return density estimator that approximates the distribution $p(\theta|x)$.
@@ -183,18 +179,12 @@ def train(
Returns:
Density estimator that approximates the distribution $p(\theta|x)$.
"""
-
if len(self._data_round_index) == 0:
raise RuntimeError(
"No simulations found. You must call .append_simulations() "
"before calling .train()."
)
- # WARNING: sneaky trick ahead. We proxy the parent's `train` here,
- # requiring the signature to have `num_atoms`, save it for use below, and
- # continue. It's sneaky because we are using the object (self) as a namespace
- # to pass arguments between functions, and that's implicit state management.
- self._num_atoms = num_atoms
- self._use_combined_loss = use_combined_loss
+
kwargs = del_entries(
locals(),
entries=("self", "__class__", "num_atoms", "use_combined_loss"),
@@ -202,13 +192,12 @@ def train(
self._round = max(self._data_round_index)
- if self._round > 0:
- # Set the proposal to the last proposal that was passed by the user. For
- # atomic SNPE, it does not matter what the proposal is. For non-atomic
- # SNPE, we only use the latest data that was passed, i.e. the one from the
- # last proposal.
+ if loss_strategy is not None:
+ self._loss_strategy = loss_strategy
+ elif self._round > 0:
+ # Set the proposal to the last proposal that was passed by the user.
proposal = self._proposal_roundwise[-1]
- self.use_non_atomic_loss = (
+ use_non_atomic_loss = (
isinstance(proposal, DirectPosterior)
and isinstance(proposal.posterior_estimator, MixtureDensityEstimator)
and isinstance(self._neural_net, MixtureDensityEstimator)
@@ -217,75 +206,69 @@ def train(
)[0]
)
- algorithm = "non-atomic" if self.use_non_atomic_loss else "atomic"
+ algorithm = "non-atomic" if use_non_atomic_loss else "atomic"
print(f"Using SNPE-C with {algorithm} loss")
- if self.use_non_atomic_loss:
+ if use_non_atomic_loss:
# Take care of z-scoring, pre-compute and store prior terms.
self._set_state_for_mog_proposal()
+ # Instantiate Non-Atomic Strategy
+ if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
+ prec_m_prod_prior = torch.mv(
+ self._maybe_z_scored_prior.precision_matrix,
+ self._maybe_z_scored_prior.loc,
+ )
+ else:
+ prec_m_prod_prior = None
+
+ self._loss_strategy = NonAtomicGaussianLoss(
+ neural_net=self._neural_net,
+ maybe_z_scored_prior=self._maybe_z_scored_prior,
+ prec_m_prod_prior=prec_m_prod_prior,
+ z_score_theta=self.z_score_theta,
+ )
+ else:
+ # Instantiate Atomic Strategy
+ self._loss_strategy = AtomicLoss(
+ neural_net=self._neural_net,
+ prior=self._prior,
+ num_atoms=num_atoms,
+ use_combined_loss=use_combined_loss,
+ )
+ else:
+ # Default to None for first round (equivalent to MLE)
+ self._loss_strategy = None
+
return super().train(**kwargs)
def _set_state_for_mog_proposal(self) -> None:
"""Set state variables that are used at each training step of non-atomic SNPE-C.
Three things are computed:
- 1) Check if z-scoring was requested. We check if the MixtureDensityEstimator
- has an input transform enabled via the has_input_transform property.
- 2) Define a (potentially standardized) prior. It's standardized if z-scoring
- had been requested.
- 3) Compute (Precision * mean) for the prior. This quantity is used at every
- training step if the prior is Gaussian.
+ 1) Check if z-scoring was requested.
+ 2) Define a (potentially standardized) prior.
+ 3) Compute (Precision * mean) for the prior.
"""
- # Check if z-scoring is enabled on the MixtureDensityEstimator
assert isinstance(self._neural_net, MixtureDensityEstimator)
self.z_score_theta = self._neural_net.has_input_transform
self._set_maybe_z_scored_prior()
- if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
- self.prec_m_prod_prior = torch.mv(
- self._maybe_z_scored_prior.precision_matrix, # type: ignore
- self._maybe_z_scored_prior.loc, # type: ignore
- )
-
def _set_maybe_z_scored_prior(self) -> None:
- r"""Compute and store potentially standardized prior (if z-scoring was done).
-
- The proposal posterior is:
- $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$
-
- Let's denote z-scored theta by `a`: a = (theta - mean) / std
- Then pp'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$
-
- The ' indicates that the evaluation occurs in standardized space. The constant
- scaling factor has been absorbed into Z_2.
- From the above equation, we see that we need to evaluate the prior **in
- standardized space**. We build the standardized prior in this function.
-
- The standardize transform that is applied to the samples theta does not use
- the exact prior mean and std (due to implementation issues). Hence, the z-scored
- prior will not be exactly have mean=0 and std=1.
- """
+ r"""Compute and store potentially standardized prior (if z-scoring was done)."""
if self.z_score_theta:
# Get z-score parameters from the MixtureDensityEstimator
- # The transform is: z = (theta - shift) / scale
- # where shift = mean (estimated from samples) and scale = std (estimated)
assert isinstance(self._neural_net, MixtureDensityEstimator)
shift = self._neural_net._transform_shift
scale = self._neural_net._transform_scale
- # The MixtureDensityEstimator uses: z = (theta - shift) / scale
- # where shift = mean and scale = std (estimated from training data)
estim_prior_mean = shift
estim_prior_std = scale
# Compute the discrepancy of the true prior mean and std and the mean and
# std that was empirically estimated from samples.
- # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e)
- # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean
- # and std (estimated from samples and used to build standardize transform).
almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std
almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std
@@ -301,415 +284,4 @@ def _set_maybe_z_scored_prior(self) -> None:
else:
self._maybe_z_scored_prior = self._prior
- def _log_prob_proposal_posterior(
- self,
- theta: Tensor,
- x: Tensor,
- masks: Tensor,
- proposal: DirectPosterior,
- ) -> Tensor:
- """Return the log-probability of the proposal posterior.
-
- If the proposal is a MoG, the density estimator is a MoG, and the prior is
- either Gaussian or uniform, we use non-atomic loss. Else, use atomic loss (which
- suffers from leakage).
-
- Args:
- theta: Batch of parameters θ.
- x: Batch of data.
- masks: Mask that is True for prior samples in the batch in order to train
- them with prior loss.
- proposal: Proposal distribution.
-
- Returns: Log-probability of the proposal posterior.
- """
-
- if self.use_non_atomic_loss:
- if not isinstance(self._neural_net, MixtureDensityEstimator):
- raise ValueError(
- "The density estimator must be a MixtureDensityEstimator "
- "for non-atomic loss."
- )
-
- return self._log_prob_proposal_posterior_mog(theta, x, proposal)
- else:
- if not hasattr(self._neural_net, "log_prob"):
- raise ValueError(
- "The neural estimator must have a log_prob method, for\
- atomic loss. It should at best follow the \
- sbi.neural_nets 'DensityEstiamtor' interface."
- )
- return self._log_prob_proposal_posterior_atomic(theta, x, masks)
-
- def _log_prob_proposal_posterior_atomic(
- self, theta: Tensor, x: Tensor, masks: Tensor
- ):
- """Return log probability of the proposal posterior for atomic proposals.
-
- We have two main options when evaluating the proposal posterior.
- (1) Generate atoms from the proposal prior.
- (2) Generate atoms from a more targeted distribution, such as the most
- recent posterior.
- If we choose the latter, it is likely beneficial not to do this in the first
- round, since we would be sampling from a randomly-initialized neural density
- estimator.
-
- Args:
- theta: Batch of parameters θ.
- x: Batch of data.
- masks: Mask that is True for prior samples in the batch in order to train
- them with prior loss.
-
- Returns:
- Log-probability of the proposal posterior.
- """
- batch_size = theta.shape[0]
-
- num_atoms = int(
- clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size)
- )
-
- # Each set of parameter atoms is evaluated using the same x,
- # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2]
- repeated_x = repeat_rows(x, num_atoms)
-
- # To generate the full set of atoms for a given item in the batch,
- # we sample without replacement num_atoms - 1 times from the rest
- # of the theta in the batch.
- probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
-
- choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
- contrasting_theta = theta[choices]
-
- # We can now create our sets of atoms from the contrasting parameter sets
- # we have generated.
- atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
- batch_size * num_atoms, -1
- )
-
- # Get (batch_size * num_atoms) log prob prior evals.
- log_prob_prior = self._prior.log_prob(atomic_theta)
- log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
- assert_all_finite(log_prob_prior, "prior eval")
-
- # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals.
- atomic_theta = reshape_to_sample_batch_event(
- atomic_theta, atomic_theta.shape[1:]
- )
- repeated_x = reshape_to_batch_event(
- repeated_x, self._neural_net.condition_shape
- )
- log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x)
- assert_all_finite(log_prob_posterior, "posterior eval")
- log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)
-
- # Compute unnormalized proposal posterior.
- unnormalized_log_prob = log_prob_posterior - log_prob_prior
-
- # Normalize proposal posterior across discrete set of atoms.
- log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp(
- unnormalized_log_prob, dim=-1
- )
- assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval")
-
- # XXX This evaluates the posterior on _all_ prior samples
- if self._use_combined_loss:
- theta = reshape_to_sample_batch_event(theta, self._neural_net.input_shape)
- x = reshape_to_batch_event(x, self._neural_net.condition_shape)
- log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x)
- # squeeze to remove sample dimension, which is always one during the loss
- # evaluation of `SNPE_C` (because we have one theta vector per x vector).
- log_prob_posterior_non_atomic = log_prob_posterior_non_atomic.squeeze(dim=0)
- masks = masks.reshape(-1)
- log_prob_proposal_posterior = (
- masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior
- )
-
- return log_prob_proposal_posterior
-
- def _log_prob_proposal_posterior_mog(
- self, theta: Tensor, x: Tensor, proposal: DirectPosterior
- ) -> Tensor:
- """Return log-probability of the proposal posterior for MoG proposal.
-
- For MoG proposals and MoG density estimators, this can be done in closed form
- and does not require atomic loss (i.e. there will be no leakage issues).
-
- Notation:
-
- m are mean vectors.
- prec are precision matrices.
- cov are covariance matrices.
-
- _p at the end indicates that it is the proposal.
- _d indicates that it is the density estimator.
- _pp indicates the proposal posterior.
-
- All tensors will have shapes (batch_dim, num_components, ...)
-
- Args:
- theta: Batch of parameters θ.
- x: Batch of data.
- proposal: Proposal distribution.
-
- Returns:
- Log-probability of the proposal posterior.
- """
- # Get the proposal MoG at the default_x
- assert isinstance(proposal.posterior_estimator, MixtureDensityEstimator)
- assert proposal.default_x is not None, "Proposal must have default_x set"
- mog_p = proposal.posterior_estimator.get_uncorrected_mog(proposal.default_x)
- norm_logits_p = mog_p.log_weights # Already normalized
- m_p = mog_p.means
- prec_p = mog_p.precisions
-
- # Get the density estimator MoG at the training data x
- assert isinstance(self._neural_net, MixtureDensityEstimator)
- mog_d = self._neural_net.get_uncorrected_mog(x)
- norm_logits_d = mog_d.log_weights # Already normalized
- m_d = mog_d.means
- prec_d = mog_d.precisions
-
- # z-score theta if z-scoring was requested.
- theta = self._maybe_z_score_theta(theta)
-
- # Compute the MoG parameters of the proposal posterior.
- (
- logits_pp,
- m_pp,
- prec_pp,
- cov_pp,
- ) = self._automatic_posterior_transformation(
- norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d
- )
-
- # Create MoG for proposal posterior and compute log_prob
- # We need precision_factors for MoG, compute via Cholesky
- precf_pp = torch.linalg.cholesky(prec_pp, upper=True)
- mog_pp = MoG(
- logits=logits_pp,
- means=m_pp,
- precisions=prec_pp,
- precision_factors=precf_pp,
- )
-
- # Compute the log_prob of theta under the product.
- log_prob_proposal_posterior = mog_pp.log_prob(theta)
- assert_all_finite(
- log_prob_proposal_posterior,
- """the evaluation of the MoG proposal posterior. This is likely due to a
- numerical instability in the training procedure. Please create an issue on
- Github.""",
- )
-
- return log_prob_proposal_posterior
-
- def _automatic_posterior_transformation(
- self,
- logits_p: Tensor,
- means_p: Tensor,
- precisions_p: Tensor,
- logits_d: Tensor,
- means_d: Tensor,
- precisions_d: Tensor,
- ):
- r"""Returns the MoG parameters of the proposal posterior.
-
- The proposal posterior is:
- $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$
- In words: proposal posterior = posterior estimate * proposal / prior.
-
- If the posterior estimate and the proposal are MoG and the prior is either
- Gaussian or uniform, we can solve this in closed-form. The is implemented in
- this function.
-
- This function implements Appendix A1 from Greenberg et al. 2019.
-
- We have to build L*K components. How do we do this?
- Example: proposal has two components, density estimator has three components.
- Let's call the two components of the proposal i,j and the three components
- of the density estimator x,y,z. We have to multiply every component of the
- proposal with every component of the density estimator. So, what we do is:
- 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave()
- 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat()
- 3) Multiply them with simple matrix operations.
-
- Args:
- logits_p: Component weight of each Gaussian of the proposal.
- means_p: Mean of each Gaussian of the proposal.
- precisions_p: Precision matrix of each Gaussian of the proposal.
- logits_d: Component weight for each Gaussian of the density estimator.
- means_d: Mean of each Gaussian of the density estimator.
- precisions_d: Precision matrix of each Gaussian of the density estimator.
-
- Returns: (Component weight, mean, precision matrix, covariance matrix) of each
- Gaussian of the proposal posterior. Has L*K terms (proposal has L terms,
- density estimator has K terms).
- """
-
- precisions_pp, covariances_pp = self._precisions_proposal_posterior(
- precisions_p, precisions_d
- )
-
- means_pp = self._means_proposal_posterior(
- covariances_pp, means_p, precisions_p, means_d, precisions_d
- )
-
- logits_pp = self._logits_proposal_posterior(
- means_pp,
- precisions_pp,
- covariances_pp,
- logits_p,
- means_p,
- precisions_p,
- logits_d,
- means_d,
- precisions_d,
- )
-
- return logits_pp, means_pp, precisions_pp, covariances_pp
-
- def _precisions_proposal_posterior(
- self, precisions_p: Tensor, precisions_d: Tensor
- ):
- """Return the precisions and covariances of the proposal posterior.
-
- Args:
- precisions_p: Precision matrices of the proposal distribution.
- precisions_d: Precision matrices of the density estimator.
-
- Returns: (Precisions, Covariances) of the proposal posterior. L*K terms.
- """
-
- num_comps_p = precisions_p.shape[1]
- num_comps_d = precisions_d.shape[1]
-
- precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1)
- precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1)
-
- precisions_pp = precisions_p_rep + precisions_d_rep
- if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
- precisions_pp -= self._maybe_z_scored_prior.precision_matrix
-
- covariances_pp = torch.inverse(precisions_pp)
-
- return precisions_pp, covariances_pp
-
- def _means_proposal_posterior(
- self,
- covariances_pp: Tensor,
- means_p: Tensor,
- precisions_p: Tensor,
- means_d: Tensor,
- precisions_d: Tensor,
- ):
- """Return the means of the proposal posterior.
-
- means_pp = C_ix * (P_i * m_i + P_x * m_x - P_o * m_o).
-
- Args:
- covariances_pp: Covariance matrices of the proposal posterior.
- means_p: Means of the proposal distribution.
- precisions_p: Precision matrices of the proposal distribution.
- means_d: Means of the density estimator.
- precisions_d: Precision matrices of the density estimator.
-
- Returns: Means of the proposal posterior. L*K terms.
- """
-
- num_comps_p = precisions_p.shape[1]
- num_comps_d = precisions_d.shape[1]
-
- # First, compute the product P_i * m_i and P_j * m_j
- prec_m_prod_p = batched_mixture_mv(precisions_p, means_p)
- prec_m_prod_d = batched_mixture_mv(precisions_d, means_d)
-
- # Repeat them to allow for matrix operations: same trick as for the precisions.
- prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1)
- prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1)
-
- # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o).
- summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep
- if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
- summed_cov_m_prod_rep -= self.prec_m_prod_prior
-
- means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep)
-
- return means_pp
-
- def _maybe_z_score_theta(self, theta: Tensor) -> Tensor:
- """Return potentially standardized theta if z-scoring was requested."""
-
- if self.z_score_theta:
- assert isinstance(self._neural_net, MixtureDensityEstimator)
- theta = self._neural_net._transform_input(theta)
-
- return theta
-
- @staticmethod
- def _logits_proposal_posterior(
- means_pp: Tensor,
- precisions_pp: Tensor,
- covariances_pp: Tensor,
- logits_p: Tensor,
- means_p: Tensor,
- precisions_p: Tensor,
- logits_d: Tensor,
- means_d: Tensor,
- precisions_d: Tensor,
- ):
- """Return the component weights (i.e. logits) of the proposal posterior.
-
- Args:
- means_pp: Means of the proposal posterior.
- precisions_pp: Precision matrices of the proposal posterior.
- covariances_pp: Covariance matrices of the proposal posterior.
- logits_p: Component weights (i.e. logits) of the proposal distribution.
- means_p: Means of the proposal distribution.
- precisions_p: Precision matrices of the proposal distribution.
- logits_d: Component weights (i.e. logits) of the density estimator.
- means_d: Means of the density estimator.
- precisions_d: Precision matrices of the density estimator.
-
- Returns: Component weights of the proposal posterior. L*K terms.
- """
-
- num_comps_p = precisions_p.shape[1]
- num_comps_d = precisions_d.shape[1]
-
- # Compute log(alpha_i * beta_j)
- logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1)
- logits_d_rep = logits_d.repeat(1, num_comps_p)
- logit_factors = logits_p_rep + logits_d_rep
-
- # Compute sqrt(det()/(det()*det()))
- logdet_covariances_pp = torch.logdet(covariances_pp)
- logdet_covariances_p = -torch.logdet(precisions_p)
- logdet_covariances_d = -torch.logdet(precisions_d)
-
- # Repeat the proposal and density estimator terms such that there are LK terms.
- # Same trick as has been used above.
- logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave(
- num_comps_d, dim=1
- )
- logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p)
-
- log_sqrt_det_ratio = 0.5 * (
- logdet_covariances_pp
- - (logdet_covariances_p_rep + logdet_covariances_d_rep)
- )
-
- # Compute for proposal, density estimator, and proposal posterior:
- # mu_i.T * P_i * mu_i
- exponent_p = batched_mixture_vmv(precisions_p, means_p)
- exponent_d = batched_mixture_vmv(precisions_d, means_d)
- exponent_pp = batched_mixture_vmv(precisions_pp, means_pp)
-
- # Extend proposal and density estimator exponents to get LK terms.
- exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1)
- exponent_d_rep = exponent_d.repeat(1, num_comps_p)
- exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp)
-
- logits_pp = logit_factors + log_sqrt_det_ratio + exponent
- return logits_pp
diff --git a/sbi/inference/trainers/npe/npe_loss.py b/sbi/inference/trainers/npe/npe_loss.py
new file mode 100644
index 000000000..af6ddcfa8
--- /dev/null
+++ b/sbi/inference/trainers/npe/npe_loss.py
@@ -0,0 +1,527 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+
+from typing import Any, List, Optional
+
+try:
+ from typing import Protocol
+except ImportError:
+ from typing_extensions import Protocol
+
+import torch
+from torch import Tensor, eye, ones
+from torch.distributions import Distribution, MultivariateNormal
+
+import sbi.utils as utils
+from sbi.inference.posteriors.direct_posterior import DirectPosterior
+from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
+from sbi.neural_nets.estimators.mixture_density_estimator import (
+ MixtureDensityEstimator,
+)
+from sbi.neural_nets.estimators.mog import MoG
+from sbi.neural_nets.estimators.shape_handling import (
+ reshape_to_batch_event,
+ reshape_to_sample_batch_event,
+)
+from sbi.utils.sbiutils import (
+ batched_mixture_mv,
+ batched_mixture_vmv,
+ clamp_and_warn,
+)
+from sbi.utils.torchutils import assert_all_finite, repeat_rows
+
+
+class NPELossStrategy(Protocol):
+ """Protocol for NPE loss strategies."""
+
+ # Replaces the use_non_atomic_loss flag and the hasattr(self, "_ran_final_round")
+ uses_only_latest_round: bool
+
+ def __call__(
+ self,
+ theta: Tensor,
+ x: Tensor,
+ masks: Tensor,
+ proposal: Optional[Any],
+ **kwargs,
+ ) -> Tensor:
+ """Calculate the log-probability of the proposal posterior.
+
+ Args:
+ theta: Batch of parameters θ.
+ x: Batch of data.
+ masks: Mask that is True for prior samples in the batch.
+ proposal: Proposal distribution.
+
+ Returns:
+ Log-probability of the proposal posterior.
+ """
+ ...
+
+
+class AtomicLoss:
+ """Atomic loss for NPE-C (sample-based)."""
+
+ uses_only_latest_round: bool = False
+
+ def __init__(
+ self,
+ neural_net: ConditionalDensityEstimator,
+ prior: Distribution,
+ num_atoms: int = 10,
+ use_combined_loss: bool = False,
+ ):
+ self._neural_net = neural_net
+ self._prior = prior
+ self._num_atoms = num_atoms
+ self._use_combined_loss = use_combined_loss
+
+ def __call__(
+ self,
+ theta: Tensor,
+ x: Tensor,
+ masks: Tensor,
+ proposal: Optional[Any],
+ **kwargs,
+ ) -> Tensor:
+ """Return log probability of the proposal posterior for atomic proposals.
+
+ We have two main options when evaluating the proposal posterior.
+ (1) Generate atoms from the proposal prior.
+ (2) Generate atoms from a more targeted distribution, such as the most
+ recent posterior.
+ If we choose the latter, it is likely beneficial not to do this in the first
+ round, since we would be sampling from a randomly-initialized neural density
+ estimator.
+
+ Args:
+ theta: Batch of parameters θ.
+ x: Batch of data.
+ masks: Mask that is True for prior samples in the batch in order to train
+ them with prior loss.
+ proposal: Proposal distribution.
+ **kwargs: Extra arguments.
+
+ Returns:
+ Log-probability of the proposal posterior.
+ """
+ batch_size = theta.shape[0]
+
+ num_atoms = int(
+ clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size)
+ )
+
+ # Each set of parameter atoms is evaluated using the same x
+ repeated_x = repeat_rows(x, num_atoms)
+
+ # To generate the full set of atoms for a given item in the batch,
+ # we sample without replacement num_atoms - 1 times from the rest
+ # of the theta in the batch.
+ probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
+
+ choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
+ contrasting_theta = theta[choices]
+
+ # We can now create our sets of atoms from the contrasting parameter sets
+ atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
+ batch_size * num_atoms, -1
+ )
+
+ # Get (batch_size * num_atoms) log prob prior evals.
+ log_prob_prior = self._prior.log_prob(atomic_theta)
+ log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
+ assert_all_finite(log_prob_prior, "prior eval")
+
+ # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals.
+ atomic_theta = reshape_to_sample_batch_event(
+ atomic_theta, atomic_theta.shape[1:]
+ )
+ repeated_x = reshape_to_batch_event(
+ repeated_x, self._neural_net.condition_shape
+ )
+ log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x)
+ assert_all_finite(log_prob_posterior, "posterior eval")
+ log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)
+
+ # Compute unnormalized proposal posterior.
+ unnormalized_log_prob = log_prob_posterior - log_prob_prior
+
+ # Normalize proposal posterior across discrete set of atoms.
+ log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp(
+ unnormalized_log_prob, dim=-1
+ )
+ assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval")
+
+ # combined loss helps prevent density leaking with bounded priors.
+ if self._use_combined_loss:
+ theta = reshape_to_sample_batch_event(theta, self._neural_net.input_shape)
+ x = reshape_to_batch_event(x, self._neural_net.condition_shape)
+ log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x)
+ # squeeze to remove sample dimension, which is always one during the loss
+ # evaluation of `SNPE_C`.
+ log_prob_posterior_non_atomic = log_prob_posterior_non_atomic.squeeze(dim=0)
+ masks = masks.reshape(-1)
+ log_prob_proposal_posterior = (
+ masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior
+ )
+
+ return log_prob_proposal_posterior
+
+
+class NonAtomicGaussianLoss:
+ """Non-atomic loss for NPE-C (analytical MoG)."""
+
+ uses_only_latest_round: bool = True
+
+ def __init__(
+ self,
+ neural_net: MixtureDensityEstimator,
+ maybe_z_scored_prior: Distribution,
+ prec_m_prod_prior: Optional[Tensor] = None,
+ z_score_theta: bool = False,
+ ):
+ self._neural_net = neural_net
+ self._maybe_z_scored_prior = maybe_z_scored_prior
+ self.prec_m_prod_prior = prec_m_prod_prior
+ self.z_score_theta = z_score_theta
+
+ def __call__(
+ self,
+ theta: Tensor,
+ x: Tensor,
+ masks: Tensor,
+ proposal: DirectPosterior,
+ **kwargs,
+ ) -> Tensor:
+ """Return log-probability of the proposal posterior for MoG proposal.
+
+ For MoG proposals and MoG density estimators, this can be done in closed form
+ and does not require atomic loss (i.e. there will be no leakage issues).
+
+ Notation:
+
+ m are mean vectors.
+ prec are precision matrices.
+ cov are covariance matrices.
+
+ _p at the end indicates that it is the proposal.
+ _d indicates that it is the density estimator.
+ _pp indicates the proposal posterior.
+
+ All tensors will have shapes (batch_dim, num_components, ...)
+
+ Args:
+ theta: Batch of parameters θ.
+ x: Batch of data.
+ masks: Mask that is True for prior samples in the batch.
+ proposal: Proposal distribution.
+ **kwargs: Extra arguments.
+
+ Returns:
+ Log-probability of the proposal posterior.
+ """
+ # Get the proposal MoG at the default_x
+ assert isinstance(proposal.posterior_estimator, MixtureDensityEstimator), (
+ "Proposal posterior_estimator must be MixtureDensityEstimator for "
+ "non-atomic loss."
+ )
+ assert proposal.default_x is not None, "Proposal must have default_x set"
+ mog_p = proposal.posterior_estimator.get_uncorrected_mog(proposal.default_x)
+ norm_logits_p = mog_p.log_weights # Already normalized
+ m_p = mog_p.means
+ prec_p = mog_p.precisions
+
+ # Get the density estimator MoG at the training data x
+ mog_d = self._neural_net.get_uncorrected_mog(x)
+ norm_logits_d = mog_d.log_weights # Already normalized
+ m_d = mog_d.means
+ prec_d = mog_d.precisions
+
+ # z-score theta if z-scoring had been requested.
+ if self.z_score_theta:
+ theta = self._neural_net._transform_input(theta)
+
+ # Compute the MoG parameters of the proposal posterior.
+ (
+ logits_pp,
+ m_pp,
+ prec_pp,
+ cov_pp,
+ ) = self._automatic_posterior_transformation(
+ norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d
+ )
+
+ # Create MoG for proposal posterior and compute log_prob
+ # We need precision_factors for MoG, compute via Cholesky
+ precf_pp = torch.linalg.cholesky(prec_pp, upper=True)
+ mog_pp = MoG(
+ logits=logits_pp,
+ means=m_pp,
+ precisions=prec_pp,
+ precision_factors=precf_pp,
+ )
+
+ # Compute the log_prob of theta under the product.
+ log_prob_proposal_posterior = mog_pp.log_prob(theta)
+ assert_all_finite(
+ log_prob_proposal_posterior,
+ """the evaluation of the MoG proposal posterior. This is likely due to a
+ numerical instability in the training procedure. Please create an issue on
+ Github.""",
+ )
+
+ return log_prob_proposal_posterior
+
+ def _automatic_posterior_transformation(
+ self,
+ logits_p: Tensor,
+ means_p: Tensor,
+ precisions_p: Tensor,
+ logits_d: Tensor,
+ means_d: Tensor,
+ precisions_d: Tensor,
+ ):
+ r"""Returns the MoG parameters of the proposal posterior.
+
+ The proposal posterior is:
+ $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$
+ In words: proposal posterior = posterior estimate * proposal / prior.
+
+ If the posterior estimate and the proposal are MoG and the prior is either
+ Gaussian or uniform, we can solve this in closed-form. This is implemented in
+ this function.
+
+ This function implements Appendix A1 from Greenberg et al. 2019.
+
+ We have to build L*K components. How do we do this?
+ Example: proposal has two components, density estimator has three components.
+ Let's call the two components of the proposal i,j and the three components
+ of the density estimator x,y,z. We have to multiply every component of the
+ proposal with every component of the density estimator. So, what we do is:
+ 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave()
+ 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat()
+ 3) Multiply them with simple matrix operations.
+
+ Args:
+ logits_p: Component weight of each Gaussian of the proposal.
+ means_p: Mean of each Gaussian of the proposal.
+ precisions_p: Precision matrix of each Gaussian of the proposal.
+ logits_d: Component weight for each Gaussian of the density estimator.
+ means_d: Mean of each Gaussian of the density estimator.
+ precisions_d: Precision matrix of each Gaussian of the density estimator.
+
+ Returns: (Component weight, mean, precision matrix, covariance matrix) of each
+ Gaussian of the proposal posterior. Has L*K terms (proposal has L terms,
+ density estimator has K terms).
+ """
+ precisions_pp, covariances_pp = self._precisions_proposal_posterior(
+ precisions_p, precisions_d
+ )
+
+ means_pp = self._means_proposal_posterior(
+ covariances_pp, means_p, precisions_p, means_d, precisions_d
+ )
+
+ logits_pp = self._logits_proposal_posterior(
+ means_pp,
+ precisions_pp,
+ covariances_pp,
+ logits_p,
+ means_p,
+ precisions_p,
+ logits_d,
+ means_d,
+ precisions_d,
+ )
+
+ return logits_pp, means_pp, precisions_pp, covariances_pp
+
+ def _precisions_proposal_posterior(
+ self, precisions_p: Tensor, precisions_d: Tensor
+ ):
+ """Return the precisions and covariances of the proposal posterior."""
+ num_comps_p = precisions_p.shape[1]
+ num_comps_d = precisions_d.shape[1]
+
+ precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1)
+ precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1)
+
+ precisions_pp = precisions_p_rep + precisions_d_rep
+ if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
+ precisions_pp -= self._maybe_z_scored_prior.precision_matrix
+
+ covariances_pp = torch.inverse(precisions_pp)
+
+ return precisions_pp, covariances_pp
+
+ def _means_proposal_posterior(
+ self,
+ covariances_pp: Tensor,
+ means_p: Tensor,
+ precisions_p: Tensor,
+ means_d: Tensor,
+ precisions_d: Tensor,
+ ):
+ """Return the means of the proposal posterior."""
+ num_comps_p = precisions_p.shape[1]
+ num_comps_d = precisions_d.shape[1]
+
+ # First, compute the product P_i * m_i and P_j * m_j
+ prec_m_prod_p = batched_mixture_mv(precisions_p, means_p)
+ prec_m_prod_d = batched_mixture_mv(precisions_d, means_d)
+
+ # Repeat them to allow for matrix operations
+ prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1)
+ prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1)
+
+ # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o).
+ summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep
+
+ if self.prec_m_prod_prior is not None:
+ summed_cov_m_prod_rep -= self.prec_m_prod_prior
+
+ means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep)
+
+ return means_pp
+
+ @staticmethod
+ def _logits_proposal_posterior(
+ means_pp: Tensor,
+ precisions_pp: Tensor,
+ covariances_pp: Tensor,
+ logits_p: Tensor,
+ means_p: Tensor,
+ precisions_p: Tensor,
+ logits_d: Tensor,
+ means_d: Tensor,
+ precisions_d: Tensor,
+ ):
+ """Return the component weights (i.e. logits) of the proposal posterior."""
+ num_comps_p = precisions_p.shape[1]
+ num_comps_d = precisions_d.shape[1]
+
+ # Compute log(alpha_i * beta_j)
+ logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1)
+ logits_d_rep = logits_d.repeat(1, num_comps_p)
+ logit_factors = logits_p_rep + logits_d_rep
+
+ # Compute sqrt(det()/(det()*det()))
+ logdet_covariances_pp = torch.logdet(covariances_pp)
+ logdet_covariances_p = -torch.logdet(precisions_p)
+ logdet_covariances_d = -torch.logdet(precisions_d)
+
+ # Repeat the proposal and density estimator terms
+ logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave(
+ num_comps_d, dim=1
+ )
+ logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p)
+
+ log_sqrt_det_ratio = 0.5 * (
+ logdet_covariances_pp
+ - (logdet_covariances_p_rep + logdet_covariances_d_rep)
+ )
+
+ # Compute for proposal, density estimator, and proposal posterior:
+ exponent_p = batched_mixture_vmv(precisions_p, means_p)
+ exponent_d = batched_mixture_vmv(precisions_d, means_d)
+ exponent_pp = batched_mixture_vmv(precisions_pp, means_pp)
+
+ # Extend proposal and density estimator exponents
+ exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1)
+ exponent_d_rep = exponent_d.repeat(1, num_comps_p)
+ exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp)
+
+ logits_pp = logit_factors + log_sqrt_det_ratio + exponent
+
+ return logits_pp
+
+
+class ImportanceWeightedLoss:
+ """Importance-weighted loss for NPE-B."""
+
+ uses_only_latest_round: bool = False
+
+ def __init__(
+ self,
+ neural_net: ConditionalDensityEstimator,
+ prior: Distribution,
+ round_idx: int,
+ theta_roundwise: List[Tensor],
+ proposal_roundwise: List[Any],
+ ):
+ self._neural_net = neural_net
+ self._prior = prior
+ self._round_idx = round_idx
+ self._theta_roundwise = theta_roundwise
+ self._proposal_roundwise = proposal_roundwise
+
+ def __call__(
+ self,
+ theta: Tensor,
+ x: Tensor,
+ masks: Tensor,
+ proposal: Optional[Any],
+ **kwargs,
+ ) -> Tensor:
+ """
+ Return importance-weighted log probability (Lueckmann, Goncalves et al., 2017).
+
+ Args:
+ theta: Batch of parameters θ.
+ x: Batch of data.
+ masks: Mask that is True for prior samples in the batch in order to train
+ them with prior loss.
+ proposal: Proposal distribution.
+ **kwargs: Extra arguments.
+
+ Returns:
+ Importance-weighted log-probability of the proposal posterior.
+ """
+ # Evaluate prior
+ # we accept prior log prob to be -Inf at theta
+ # meaning that theta is out of the prior range (the weight is thus 0)
+ utils.assert_not_nan_or_plus_inf(
+ self._prior.log_prob(theta), "prior log probs of proposal samples"
+ )
+ prior = torch.exp(self._prior.log_prob(theta))
+
+ # Evaluate proposal
+ # (as theta comes from prior and proposal from previous rounds,
+ # the last proposal is actually a mixture of the prior
+ # and of all the previous proposals with coefficients representing
+ # the proportion of the new theta added at each round)
+ prop = torch.zeros(self._round_idx + 1, device=theta.device)
+ nb_samples = 0 # total number of theta from all the rounds
+
+ for k in range(self._round_idx + 1):
+ nb_samples += self._theta_roundwise[k].size(0)
+ # the number of new theta sampled in round k
+ prop[k] = self._theta_roundwise[k].size(0)
+
+ prop /= nb_samples
+ log_prop = torch.log(prop).repeat(theta.size(0), 1)
+
+ log_previous_proposals = torch.zeros(
+ (theta.size(0), self._round_idx + 1), device=theta.device
+ )
+ for k, density in enumerate(self._proposal_roundwise):
+ # we accept the k-th proposal log prob to be -Inf at theta
+ # meaning that theta is out of the k-th proposal range
+ log_previous_proposals[:, k] = density.log_prob(theta)
+ utils.assert_not_nan_or_plus_inf(
+ log_previous_proposals[:, k], "proposal log probs of proposal samples"
+ )
+
+ log_proposal = torch.logsumexp(log_prop + log_previous_proposals, dim=1)
+ proposal_weighted = torch.exp(log_proposal)
+
+ # Construct the importance weights and normalize them
+ importance_weights = prior / proposal_weighted
+ importance_weights /= importance_weights.sum()
+
+ theta = reshape_to_sample_batch_event(theta, theta.shape[1:])
+ # Reshape the density estimator log probs
+ # from (sample_shape, batch_shape) to (batch_shape)
+ posterior_log_probs = self._neural_net.log_prob(theta, x).squeeze(dim=0)
+
+ return importance_weights * posterior_log_probs
diff --git a/sbi/inference/trainers/nre/bnre.py b/sbi/inference/trainers/nre/bnre.py
index 933b6c74d..0f8d84b20 100644
--- a/sbi/inference/trainers/nre/bnre.py
+++ b/sbi/inference/trainers/nre/bnre.py
@@ -3,18 +3,17 @@
from typing import Dict, Optional, Sequence, Union
-import torch
-from torch import Tensor, nn, ones
+from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
from sbi.inference.trainers._contracts import LossArgs, LossArgsBNRE
from sbi.inference.trainers.nre.nre_a import NRE_A
+from sbi.inference.trainers.nre.nre_loss import BNRELoss, NRELossStrategy
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
-from sbi.utils.torchutils import assert_all_finite
class BNRE(NRE_A):
@@ -113,6 +112,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
+ loss_strategy: Optional[NRELossStrategy] = None,
) -> RatioEstimator:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Args:
@@ -156,44 +156,11 @@ def train(
regularization_strength=kwargs.pop("regularization_strength"),
)
- return super().train(**kwargs)
-
- def _loss(
- self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float
- ) -> Tensor:
- """Returns the binary cross-entropy loss for the trained classifier.
-
- The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
- if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
- pair was sampled from the marginals $p(\theta)p(x)$.
- """
-
- assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
- batch_size = theta.shape[0]
+ if loss_strategy is None:
+ kwargs["loss_strategy"] = BNRELoss()
- logits = self._classifier_logits(theta, x, num_atoms)
- likelihood = torch.sigmoid(logits).squeeze()
-
- # Alternating pairs where there is one sampled from the joint and one
- # sampled from the marginals. The first element is sampled from the
- # joint p(theta, x) and is labelled 1. The second element is sampled
- # from the marginals p(theta)p(x) and is labelled 0. And so on.
- labels = ones(2 * batch_size, device=self._device) # two atoms
- labels[1::2] = 0.0
-
- # Binary cross entropy to learn the likelihood (AALR-specific)
- bce = nn.BCELoss()(likelihood, labels)
-
- # Balancing regularizer
- regularizer = (
- (torch.sigmoid(logits[0::2]) + torch.sigmoid(logits[1::2]) - 1)
- .mean()
- .square()
- )
+ return super().train(**kwargs)
- loss = bce + regularization_strength * regularizer
- assert_all_finite(loss, "BNRE loss")
- return loss
def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
"""Overrides the parent class method to check the type of loss_args."""
diff --git a/sbi/inference/trainers/nre/nre_a.py b/sbi/inference/trainers/nre/nre_a.py
index 26653ce3a..02b7dc584 100644
--- a/sbi/inference/trainers/nre/nre_a.py
+++ b/sbi/inference/trainers/nre/nre_a.py
@@ -3,8 +3,6 @@
from typing import Dict, Optional, Union
-import torch
-from torch import Tensor, nn, ones
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
@@ -12,11 +10,11 @@
from sbi.inference.trainers.nre.nre_base import (
RatioEstimatorTrainer,
)
+from sbi.inference.trainers.nre.nre_loss import AALRLoss, NRELossStrategy
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
-from sbi.utils.torchutils import assert_all_finite
class NRE_A(RatioEstimatorTrainer):
@@ -111,6 +109,7 @@ def train(
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
loss_kwargs: Optional[LossArgsNRE_A] = None,
+ loss_strategy: Optional[NRELossStrategy] = None,
) -> RatioEstimator:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
@@ -154,30 +153,8 @@ def train(
f" but got {type(loss_kwargs)}"
)
- return super().train(**kwargs)
-
- def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
- """Returns the binary cross-entropy loss for the trained classifier.
-
- The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
- if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
- pair was sampled from the marginals $p(\theta)p(x)$.
- """
+ if loss_strategy is None:
+ kwargs["loss_strategy"] = AALRLoss()
- assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
- batch_size = theta.shape[0]
-
- logits = self._classifier_logits(theta, x, num_atoms)
- likelihood = torch.sigmoid(logits).squeeze()
-
- # Alternating pairs where there is one sampled from the joint and one
- # sampled from the marginals. The first element is sampled from the
- # joint p(theta, x) and is labelled 1. The second element is sampled
- # from the marginals p(theta)p(x) and is labelled 0. And so on.
- labels = ones(2 * batch_size, device=self._device) # two atoms
- labels[1::2] = 0.0
+ return super().train(**kwargs)
- # Binary cross entropy to learn the likelihood (AALR-specific)
- loss = nn.BCELoss()(likelihood, labels)
- assert_all_finite(loss, "NRE-A loss")
- return loss
diff --git a/sbi/inference/trainers/nre/nre_b.py b/sbi/inference/trainers/nre/nre_b.py
index 87e08edc8..7892abfa9 100644
--- a/sbi/inference/trainers/nre/nre_b.py
+++ b/sbi/inference/trainers/nre/nre_b.py
@@ -3,8 +3,6 @@
from typing import Dict, Optional, Union
-import torch
-from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
@@ -12,11 +10,11 @@
from sbi.inference.trainers.nre.nre_base import (
RatioEstimatorTrainer,
)
+from sbi.inference.trainers.nre.nre_loss import NRELossStrategy, SRELoss
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
-from sbi.utils.torchutils import assert_all_finite
class NRE_B(RatioEstimatorTrainer):
@@ -111,6 +109,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
+ loss_strategy: Optional[NRELossStrategy] = None,
) -> RatioEstimator:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
@@ -146,31 +145,8 @@ def train(
kwargs = del_entries(locals(), entries=("self", "__class__"))
kwargs["loss_kwargs"] = LossArgsNRE(num_atoms=kwargs.pop("num_atoms"))
- return super().train(**kwargs)
-
- def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
- r"""Return cross-entropy (via softmax activation) loss for 1-out-of-`num_atoms`
- classification.
-
- The classifier takes as input `num_atoms` $(\theta,x)$ pairs. Out of these
- pairs, one pair was sampled from the joint $p(\theta,x)$ and all others from the
- marginals $p(\theta)p(x)$. The classifier is trained to predict which of the
- pairs was sampled from the joint $p(\theta,x)$.
- """
+ if loss_strategy is None:
+ kwargs["loss_strategy"] = SRELoss()
- assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
- batch_size = theta.shape[0]
- logits = self._classifier_logits(theta, x, num_atoms)
-
- # For 1-out-of-`num_atoms` classification each datapoint consists
- # of `num_atoms` points, with one of them being the correct one.
- # We have a batch of `batch_size` such datapoints.
- logits = logits.reshape(batch_size, num_atoms)
-
- # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the
- # "correct" one for the 1-out-of-N classification.
- log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
+ return super().train(**kwargs)
- loss = -torch.mean(log_prob)
- assert_all_finite(loss, "NRE-B loss")
- return loss
diff --git a/sbi/inference/trainers/nre/nre_base.py b/sbi/inference/trainers/nre/nre_base.py
index 07b2cda0b..baa8b74f0 100644
--- a/sbi/inference/trainers/nre/nre_base.py
+++ b/sbi/inference/trainers/nre/nre_base.py
@@ -2,12 +2,11 @@
# under the Apache License Version 2.0, see
import warnings
-from abc import ABC, abstractmethod
+from abc import ABC
from dataclasses import asdict, replace
from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Union
-import torch
-from torch import Tensor, eye, ones
+from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
from typing_extensions import Self
@@ -30,6 +29,7 @@
LossArgs,
NeuralInference,
)
+from sbi.inference.trainers.nre.nre_loss import NRELossStrategy
from sbi.neural_nets import classifier_nn
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
@@ -38,7 +38,6 @@
check_estimator_arg,
clamp_and_warn,
)
-from sbi.utils.torchutils import repeat_rows
class RatioEstimatorTrainer(NeuralInference[RatioEstimator], ABC):
@@ -90,6 +89,8 @@ def __init__(
show_progress_bars=show_progress_bars,
)
+ self._loss_strategy: Optional[NRELossStrategy] = None
+
# As detailed in the docstring, `density_estimator` is either a string or
# a callable. The function creating the neural network is attached to
# `_build_neural_net`. It will be called in the first round and receive
@@ -101,9 +102,6 @@ def __init__(
else:
self._build_neural_net = classifier
- @abstractmethod
- def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor: ...
-
def append_simulations(
self,
theta: Tensor,
@@ -170,6 +168,7 @@ def train(
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
loss_kwargs: Optional[LossArgsNRE] = None,
+ loss_strategy: Optional[NRELossStrategy] = None,
) -> RatioEstimator:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
@@ -372,27 +371,6 @@ def build_posterior(
importance_sampling_parameters=importance_sampling_parameters,
)
- def _classifier_logits(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
- """Return logits obtained through classifier forward pass.
-
- The logits are obtained from atomic sets of (theta,x) pairs.
- """
- batch_size = theta.shape[0]
- repeated_x = repeat_rows(x, num_atoms)
-
- # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
- probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
-
- choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
-
- contrasting_theta = theta[choices]
-
- atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
- batch_size * num_atoms, -1
- )
-
- return self._neural_net(atomic_theta, repeated_x)
-
def _get_potential_function(
self, prior: Distribution, estimator: RatioEstimator
) -> Tuple[RatioBasedPotential, TorchTransform]:
@@ -465,6 +443,9 @@ def _initialize_neural_network(
del x, theta
+ def _loss(self, *args, **kwargs) -> Tensor:
+ raise NotImplementedError("NRE trainers use _loss_strategy inside _get_losses.")
+
def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
"""
Compute losses for a batch of data.
@@ -488,6 +469,8 @@ def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
f" but got {type(loss_args)}"
)
- losses = self._loss(theta_batch, x_batch, **asdict(loss_args))
+ if self._loss_strategy is None:
+ raise RuntimeError("Loss strategy not initialized.")
+ losses = self._loss_strategy(theta_batch, x_batch, **asdict(loss_args))
return losses
diff --git a/sbi/inference/trainers/nre/nre_c.py b/sbi/inference/trainers/nre/nre_c.py
index c49a76fbe..edfeadef8 100644
--- a/sbi/inference/trainers/nre/nre_c.py
+++ b/sbi/inference/trainers/nre/nre_c.py
@@ -1,9 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
-from typing import Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Union
-import torch
from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
@@ -12,11 +11,11 @@
from sbi.inference.trainers.nre.nre_base import (
RatioEstimatorTrainer,
)
+from sbi.inference.trainers.nre.nre_loss import CNRELoss, NRELossStrategy
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
-from sbi.utils.torchutils import assert_all_finite
class NRE_C(RatioEstimatorTrainer):
@@ -109,6 +108,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
+ loss_strategy: Optional[NRELossStrategy] = None,
) -> RatioEstimator:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
@@ -157,89 +157,8 @@ def train(
num_atoms=kwargs.pop("num_classes") + 1, gamma=kwargs.pop("gamma")
)
- return super().train(**kwargs)
-
- def _loss(
- self, theta: Tensor, x: Tensor, num_atoms: int, gamma: float
- ) -> torch.Tensor:
- r"""Return cross-entropy loss (via ''multi-class sigmoid'' activation) for
- 1-out-of-`K + 1` classification.
-
- At optimum, this loss function returns the exact likelihood-to-evidence ratio
- in the first round.
- Details of loss computation are described in Contrastive Neural Ratio
- Estimation[1]. The paper does not discuss the sequential case.
-
- [1] _Contrastive Neural Ratio Estimation_, Benajmin Kurt Miller, et. al.,
- NeurIPS 2022, https://arxiv.org/abs/2210.06170
- """
-
- # Reminder: K = num_classes
- # The algorithm is written with K, so we convert back to K format rather than
- # reasoning in num_atoms.
- num_classes = num_atoms - 1
- assert num_classes >= 1, f"num_classes = {num_classes} must be greater than 1."
-
- assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
- batch_size = theta.shape[0]
-
- # We append a contrastive theta to the marginal case because we will remove
- # the jointly drawn
- # sample in the logits_marginal[:, 0] position. That makes the remaining sample
- # marginally drawn.
- # We have a batch of `batch_size` datapoints.
- logits_marginal = self._classifier_logits(theta, x, num_classes + 1).reshape(
- batch_size, num_classes + 1
- )
- logits_joint = self._classifier_logits(theta, x, num_classes).reshape(
- batch_size, num_classes
- )
-
- dtype = logits_marginal.dtype
- device = logits_marginal.device
-
- # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence
- # we remove the jointly drawn sample from the logits_marginal
- logits_marginal = logits_marginal[:, 1:]
- # ... and retain it in the logits_joint. Now we have two arrays with K choices.
-
- # To use logsumexp, we extend the denominator logits with loggamma
- loggamma = torch.tensor(gamma, dtype=dtype, device=device).log()
- logK = torch.tensor(num_classes, dtype=dtype, device=device).log()
- denominator_marginal = torch.concat(
- [loggamma + logits_marginal, logK.expand((batch_size, 1))],
- dim=-1,
- )
- denominator_joint = torch.concat(
- [loggamma + logits_joint, logK.expand((batch_size, 1))],
- dim=-1,
- )
-
- # Compute the contributions to the loss from each term in the classification.
- log_prob_marginal = logK - torch.logsumexp(denominator_marginal, dim=-1)
- log_prob_joint = (
- loggamma + logits_joint[:, 0] - torch.logsumexp(denominator_joint, dim=-1)
- )
-
- # relative weights. p_marginal := p_0, and p_joint := p_K * K from the notation.
- p_marginal, p_joint = self._get_prior_probs_marginal_and_joint(gamma)
-
- loss = -torch.mean(p_marginal * log_prob_marginal + p_joint * log_prob_joint)
- assert_all_finite(loss, "NRE-C loss")
- return loss
-
- @staticmethod
- def _get_prior_probs_marginal_and_joint(gamma: float) -> Tuple[float, float]:
- r"""Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$,
- `p_joint := `$p_K \cdot K$.
-
- We let the joint (dependently drawn) class to be equally likely across K
- options. The marginal class is therefore restricted to get the remaining
- probability.
- """
- p_joint = gamma / (1 + gamma)
- p_marginal = 1 / (1 + gamma)
- return p_marginal, p_joint
+ if loss_strategy is None:
+ kwargs["loss_strategy"] = CNRELoss()
def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
"""Overrides the parent class method to check the type of loss_args."""
diff --git a/sbi/inference/trainers/nre/nre_loss.py b/sbi/inference/trainers/nre/nre_loss.py
new file mode 100644
index 000000000..6a9f47c1c
--- /dev/null
+++ b/sbi/inference/trainers/nre/nre_loss.py
@@ -0,0 +1,221 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+
+from typing import Protocol, Tuple
+
+import torch
+from torch import Tensor, eye, nn, ones
+
+from sbi.neural_nets.ratio_estimators import RatioEstimator
+from sbi.utils.torchutils import assert_all_finite, repeat_rows
+
+
class NRELossStrategy(Protocol):
    """Protocol defining the interface for all Neural Ratio Estimation loss strategies.

    A strategy must implement a forward pass (__call__) mapping parameters (theta),
    observations (x), and algorithm-specific keyword arguments (like num_atoms) to a
    scalar loss tensor.
    """

    def __call__(
        self,
        neural_net: RatioEstimator,
        device: str,
        theta: Tensor,
        x: Tensor,
        **kwargs
    ) -> Tensor:
        """Compute the scalar training loss.

        Args:
            neural_net: Classifier-based ratio estimator being trained.
            device: Device on which auxiliary tensors (e.g. labels) are created.
            theta: Batch of parameters; first dimension is the batch size.
            x: Batch of observations; first dimension must match theta's.
            **kwargs: Algorithm-specific options (e.g. ``num_atoms``, ``gamma``,
                ``regularization_strength``).

        Returns:
            Scalar loss tensor.
        """
        ...
+
+
+def _classifier_logits(
+ neural_net: RatioEstimator, theta: Tensor, x: Tensor, num_atoms: int
+) -> Tensor:
+ """Return logits obtained through classifier forward pass.
+
+ The logits are obtained from atomic sets of (theta,x) pairs.
+ """
+ batch_size = theta.shape[0]
+ repeated_x = repeat_rows(x, num_atoms)
+
+ # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x.
+ probs = (
+ ones(batch_size, batch_size, device=theta.device)
+ * (1 - eye(batch_size, device=theta.device))
+ / (batch_size - 1)
+ )
+
+ choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
+
+ contrasting_theta = theta[choices]
+
+ atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
+ batch_size * num_atoms, -1
+ )
+
+ return neural_net(atomic_theta, repeated_x)
+
+
class AALRLoss:
    """Neural Ratio Estimation (NRE-A / AALR) loss strategy.

    Returns the binary cross-entropy loss for the trained classifier. AALR
    trains a binary classifier to distinguish pairs from the joint
    p(theta, x) (label 1) from pairs from the marginals p(theta)p(x)
    (label 0), which restricts it to exactly two atoms.
    """

    def __call__(
        self,
        neural_net: RatioEstimator,
        device: str,
        theta: Tensor,
        x: Tensor,
        num_atoms: int,
        **kwargs,
    ) -> Tensor:
        """Return the AALR binary cross-entropy loss.

        Args:
            neural_net: Ratio estimator to train.
            device: Device on which the classification labels are created.
            theta: Batch of parameters, shape (batch_size, parameter_dim).
            x: Batch of observations; first dimension must equal theta's.
            num_atoms: Must be 2 — AALR is a binary classification scheme.

        Returns:
            Scalar binary cross-entropy loss.
        """
        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        # The label construction below hard-codes one joint and one marginal
        # sample per observation; other values would fail with an opaque shape
        # mismatch inside BCELoss, so reject them explicitly.
        assert num_atoms == 2, f"AALR requires num_atoms=2, got {num_atoms}."
        batch_size = theta.shape[0]

        logits = _classifier_logits(neural_net, theta, x, num_atoms)
        # Squeeze only the trailing singleton so a batch axis is never dropped.
        likelihood = torch.sigmoid(logits).squeeze(dim=-1)

        # Alternating pairs: even indices hold the jointly drawn (theta, x)
        # pair (label 1), odd indices the marginally drawn pair (label 0).
        labels = ones(2 * batch_size, device=device)  # two atoms
        labels[1::2] = 0.0

        # Binary cross entropy to learn the likelihood (AALR-specific).
        loss = nn.BCELoss()(likelihood, labels)
        assert_all_finite(loss, "NRE-A loss")
        return loss
+
+
class SRELoss:
    """Neural Ratio Estimation (NRE-B / SRE) loss strategy.

    Trains the classifier via cross-entropy (softmax activation) on a
    1-out-of-`num_atoms` classification task, where exactly one atom per
    observation is the jointly sampled parameter.
    """

    def __call__(
        self,
        neural_net: RatioEstimator,
        device: str,
        theta: Tensor,
        x: Tensor,
        num_atoms: int,
        **kwargs,
    ) -> Tensor:
        """Return the SRE multi-class cross-entropy loss."""
        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        num_examples = theta.shape[0]

        # One datapoint per batch element: `num_atoms` candidate thetas, the
        # jointly drawn ("correct") one always sitting at atom index 0.
        atom_logits = _classifier_logits(neural_net, theta, x, num_atoms).reshape(
            num_examples, num_atoms
        )

        # Cross-entropy with the correct class fixed at index 0.
        log_prob_correct = atom_logits[:, 0] - torch.logsumexp(atom_logits, dim=-1)
        loss = -log_prob_correct.mean()

        assert_all_finite(loss, "NRE-B loss")
        return loss
+
+
class CNRELoss:
    """Neural Ratio Estimation (NRE-C / CNRE) loss strategy.

    Returns cross-entropy loss (via 'multi-class sigmoid' activation) for
    1-out-of-`K + 1` classification, with K = num_atoms - 1 contrastive
    classes plus one marginal class.
    """

    @staticmethod
    def _get_prior_probs_marginal_and_joint(gamma: float) -> Tuple[float, float]:
        r"""Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$,
        `p_joint := `$p_K \cdot K$.

        The joint (dependently drawn) class is equally likely across the K
        options; the marginal class receives the remaining probability.
        """
        p_joint = gamma / (1 + gamma)
        p_marginal = 1 / (1 + gamma)
        return p_marginal, p_joint

    def __call__(
        self,
        neural_net: RatioEstimator,
        device: str,
        theta: Tensor,
        x: Tensor,
        num_atoms: int,
        gamma: float,
        **kwargs,
    ) -> Tensor:
        """Return the NRE-C classification loss.

        Args:
            neural_net: Ratio estimator to train.
            device: Kept for interface compatibility; constants are created on
                the logits' device to avoid cross-device errors.
            theta: Batch of parameters, shape (batch_size, parameter_dim).
            x: Batch of observations; first dimension must equal theta's.
            num_atoms: Number of atoms; K = num_atoms - 1 must be at least 1.
            gamma: Relative weight of the joint classes vs. the marginal class.

        Returns:
            Scalar NRE-C loss.
        """
        # Reminder: K = num_classes
        num_classes = num_atoms - 1
        # Message fixed: the check enforces K >= 1, not K > 1.
        assert num_classes >= 1, f"num_classes = {num_classes} must be at least 1."

        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        batch_size = theta.shape[0]

        logits_marginal = _classifier_logits(
            neural_net, theta, x, num_classes + 1
        ).reshape(batch_size, num_classes + 1)

        logits_joint = _classifier_logits(
            neural_net, theta, x, num_classes
        ).reshape(batch_size, num_classes)

        dtype = logits_marginal.dtype
        # Create constants on the logits' device: the `device` argument may
        # differ from where the estimator lives, and torch.concat below would
        # raise on mixed-device inputs.
        logits_device = logits_marginal.device

        # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence
        # we remove the jointly drawn sample from the logits_marginal ...
        logits_marginal = logits_marginal[:, 1:]
        # ... and retain it in logits_joint. Now both arrays offer K choices.

        # To use logsumexp, we extend the denominator logits with loggamma.
        loggamma = torch.tensor(gamma, dtype=dtype, device=logits_device).log()
        logK = torch.tensor(num_classes, dtype=dtype, device=logits_device).log()
        denominator_marginal = torch.concat(
            [loggamma + logits_marginal, logK.expand((batch_size, 1))],
            dim=-1,
        )
        denominator_joint = torch.concat(
            [loggamma + logits_joint, logK.expand((batch_size, 1))],
            dim=-1,
        )

        # Compute the contributions to the loss from each term in the classification.
        log_prob_marginal = logK - torch.logsumexp(denominator_marginal, dim=-1)
        log_prob_joint = (
            loggamma + logits_joint[:, 0] - torch.logsumexp(denominator_joint, dim=-1)
        )

        # Relative weights: p_marginal := p_0, and p_joint := p_K * K.
        p_marginal, p_joint = self._get_prior_probs_marginal_and_joint(gamma)

        loss = -torch.mean(p_marginal * log_prob_marginal + p_joint * log_prob_joint)
        assert_all_finite(loss, "NRE-C loss")
        return loss
+
+
class BNRELoss:
    """Balanced Neural Ratio Estimation (BNRE) loss strategy.

    A variation of NRE-A that adds a balancing regularizer to the binary
    cross-entropy loss.
    """

    def __call__(
        self,
        neural_net: RatioEstimator,
        device: str,
        theta: Tensor,
        x: Tensor,
        num_atoms: int,
        regularization_strength: float,
        **kwargs,
    ) -> Tensor:
        """Return the BNRE loss (binary cross-entropy + balance regularizer).

        Args:
            neural_net: Ratio estimator to train.
            device: Device on which the classification labels are created.
            theta: Batch of parameters, shape (batch_size, parameter_dim).
            x: Batch of observations; first dimension must equal theta's.
            num_atoms: Must be 2 — BNRE inherits AALR's binary scheme.
            regularization_strength: Weight of the balancing regularizer.

        Returns:
            Scalar BNRE loss.
        """
        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        # The label construction below hard-codes one joint and one marginal
        # sample per observation; reject other values up front rather than
        # failing with an opaque shape mismatch inside BCELoss.
        assert num_atoms == 2, f"BNRE requires num_atoms=2, got {num_atoms}."
        batch_size = theta.shape[0]

        logits = _classifier_logits(neural_net, theta, x, num_atoms)
        likelihood = torch.sigmoid(logits).squeeze()

        # Alternating pairs: even indices hold the jointly drawn (theta, x)
        # pair (label 1), odd indices the marginally drawn pair (label 0).
        labels = ones(2 * batch_size, device=device)  # two atoms
        labels[1::2] = 0.0

        # Binary cross entropy to learn the likelihood (AALR-specific).
        bce = nn.BCELoss()(likelihood, labels)

        # Balancing regularizer; reuse the sigmoids computed above instead of
        # re-applying torch.sigmoid to the raw logits.
        regularizer = (likelihood[0::2] + likelihood[1::2] - 1).mean().square()

        loss = bce + regularization_strength * regularizer
        assert_all_finite(loss, "BNRE loss")
        return loss
diff --git a/tests/inference/npe_loss_test.py b/tests/inference/npe_loss_test.py
new file mode 100644
index 000000000..7e22c9532
--- /dev/null
+++ b/tests/inference/npe_loss_test.py
@@ -0,0 +1,131 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+
+import pytest
+import torch
+from torch.distributions import MultivariateNormal
+from typing import Any, Optional
+
+from sbi.inference.posteriors.direct_posterior import DirectPosterior
+from sbi.inference.trainers.npe.npe_loss import (
+ AtomicLoss,
+ ImportanceWeightedLoss,
+ NonAtomicGaussianLoss,
+)
+from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
+from sbi.neural_nets.estimators.mixture_density_estimator import MixtureDensityEstimator
+from sbi.neural_nets.estimators.mog import MoG
+from sbi.utils.torchutils import BoxUniform
+
+
class DummyEstimator(ConditionalDensityEstimator):
    """Minimal density-estimator stand-in for exercising NPE loss strategies."""

    def __init__(self, is_mixture=False):
        event_shape = torch.Size((2,))
        super().__init__(
            net=torch.nn.Identity(),
            input_shape=event_shape,
            condition_shape=event_shape,
        )
        self.is_mixture = is_mixture
        self.condition_shape = event_shape
        self.input_shape = event_shape

    def log_prob(self, theta: torch.Tensor, condition: torch.Tensor, **kwargs):
        """Return all-zero log-probs sized to the larger of the two batches."""
        n = max(theta.shape[0], condition.shape[0])
        return torch.zeros(n, device=theta.device)

    def sample(self, sample_shape: torch.Size, condition: torch.Tensor, **kwargs):
        """Return all-zero samples of shape sample_shape + event shape."""
        out_shape = sample_shape + self.input_shape
        return torch.zeros(out_shape, device=condition.device)

    def loss(self, theta, condition, **kwargs):
        """Return a zero loss per theta in the batch."""
        return torch.zeros(theta.shape[0])

    def get_uncorrected_mog(self, condition: torch.Tensor) -> MoG:
        """Return a trivial single-component MoG matching the condition batch."""
        assert self.is_mixture
        n = condition.shape[0]
        dim = self.input_shape[0]
        precisions = torch.eye(dim).view(1, 1, dim, dim).expand(n, 1, dim, dim)
        factors = torch.eye(dim).view(1, 1, dim, dim).expand(n, 1, dim, dim)
        return MoG(
            torch.zeros(n, 1),
            torch.zeros(n, 1, dim),
            precisions,
            prec_factors=factors,
        )

    @property
    def has_input_transform(self):
        return False
+
# Proxy dummy specifically simulating MixtureDensityEstimator to pass isinstance
class DummyMixtureEstimator(DummyEstimator, MixtureDensityEstimator): # type: ignore
    """DummyEstimator that also satisfies isinstance(..., MixtureDensityEstimator)."""

    def __init__(self):
        # Initialize via DummyEstimator with mixture behaviour enabled.
        DummyEstimator.__init__(self, is_mixture=True)
+
+
class DummyPosterior(DirectPosterior):
    """Posterior stand-in that skips DirectPosterior's initialization.

    Deliberately does NOT call ``super().__init__`` so the parent's strict
    estimator/prior checks are bypassed; only ``posterior_estimator``,
    ``prior``, and ``default_x`` are set.
    """

    def __init__(self, estimator, prior):
        # We bypass DirectPosterior's strict checks for testing MoG extraction
        self.posterior_estimator = estimator
        self.prior = prior
        # Fixed default observation: batch of one (2,)-dim condition, matching
        # the dummy estimators' condition_shape.
        self.default_x = torch.zeros(1, 2)
+
+
@pytest.fixture
def theta():
    """Batch of 5 random 2D parameter vectors."""
    return torch.randn(5, 2)
+
+
@pytest.fixture
def x():
    """Batch of 5 random 2D observations."""
    return torch.randn(5, 2)
+
+
@pytest.fixture
def masks():
    """All-ones mask for 5 samples (a 1 marks a prior sample in NPE losses)."""
    return torch.ones(5)
+
+
@pytest.fixture
def prior():
    """Standard-normal prior over 2D parameters."""
    return MultivariateNormal(torch.zeros(2), torch.eye(2))
+
+
def test_atomic_loss_initialization_and_call(theta, x, masks, prior):
    """AtomicLoss builds, is multi-round, and returns per-sample losses."""
    estimator = DummyEstimator()
    atomic_strategy = AtomicLoss(neural_net=estimator, prior=prior, num_atoms=2)

    assert atomic_strategy.uses_only_latest_round is False

    per_sample_loss = atomic_strategy(theta, x, masks, proposal=None)
    assert per_sample_loss.shape == (5,)
+
+
def test_non_atomic_loss_initialization_and_call(theta, x, masks, prior):
    """NonAtomicGaussianLoss builds, is single-round, and returns per-sample losses."""
    mixture_estimator = DummyMixtureEstimator()
    dummy_proposal = DummyPosterior(mixture_estimator, prior)

    gaussian_strategy = NonAtomicGaussianLoss(
        neural_net=mixture_estimator,
        maybe_z_scored_prior=prior,
    )

    assert gaussian_strategy.uses_only_latest_round is True

    per_sample_loss = gaussian_strategy(theta, x, masks, proposal=dummy_proposal)
    assert per_sample_loss.shape == (5,)
+
+
def test_importance_weighted_loss_initialization_and_call(theta, x, masks, prior):
    """ImportanceWeightedLoss builds, is multi-round, and returns per-sample losses."""
    estimator = DummyEstimator()
    dummy_proposal = DummyPosterior(estimator, prior)

    per_round_thetas = [torch.randn(10, 2), torch.randn(10, 2)]
    per_round_proposals = [prior, dummy_proposal]

    iw_strategy = ImportanceWeightedLoss(
        neural_net=estimator,
        prior=prior,
        round_idx=1,
        theta_roundwise=per_round_thetas,
        proposal_roundwise=per_round_proposals,
    )

    assert iw_strategy.uses_only_latest_round is False

    per_sample_loss = iw_strategy(theta, x, masks, proposal=dummy_proposal)
    assert per_sample_loss.shape == (5,)