diff --git a/sbi/inference/trainers/npe/npe_a.py b/sbi/inference/trainers/npe/npe_a.py index e263a5797..4cb080c30 100644 --- a/sbi/inference/trainers/npe/npe_a.py +++ b/sbi/inference/trainers/npe/npe_a.py @@ -462,28 +462,7 @@ def build_posterior( return self._posterior - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - ) -> Tensor: - """Return the log-probability of the proposal posterior. - - For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in - `_loss()` to be found in `snpe_base.py`. - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - proposal: Proposal distribution. - - Returns: Log-probability of the proposal posterior. - """ - return self._neural_net.log_prob(theta, x) def _correct_for_proposal( diff --git a/sbi/inference/trainers/npe/npe_b.py b/sbi/inference/trainers/npe/npe_b.py index 2ba1f2461..de6c8069b 100644 --- a/sbi/inference/trainers/npe/npe_b.py +++ b/sbi/inference/trainers/npe/npe_b.py @@ -1,22 +1,22 @@ # This file is part of sbi, a toolkit for simulation-based inference. 
sbi is licensed # under the Apache License Version 2.0, see -from typing import Any, Literal, Optional, Union +from typing import Callable, Dict, Literal, Optional, Union -import torch -from torch import Tensor from torch.distributions import Distribution from torch.utils.tensorboard.writer import SummaryWriter -import sbi.utils as utils from sbi.inference.trainers.npe.npe_base import ( PosteriorEstimatorTrainer, ) +from sbi.inference.trainers.npe.npe_loss import ( + ImportanceWeightedLoss, + NPELossStrategy, +) from sbi.neural_nets.estimators.base import ( ConditionalDensityEstimator, ConditionalEstimatorBuilder, ) -from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries @@ -107,72 +107,48 @@ def __init__( kwargs = del_entries(locals(), entries=("self", "__class__")) super().__init__(**kwargs) - def _log_prob_proposal_posterior( + def train( self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - ) -> Tensor: - """ - Return importance-weighted log probability (Lueckmann, Goncalves et al., 2017). - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - proposal: Proposal distribution. - - Returns: - Importance-weighted log-probability of the proposal posterior. 
- """ - - # Evaluate prior - # we accept prior log prob to be -Inf at theta - # meaning that theta is out of the prior range (the weight is thus 0) - utils.assert_not_nan_or_plus_inf( - self._prior.log_prob(theta), "prior log probs of proposal samples" + training_batch_size: int = 200, + learning_rate: float = 5e-4, + validation_fraction: float = 0.1, + stop_after_epochs: int = 20, + max_num_epochs: int = 2**31 - 1, + clip_max_norm: Optional[float] = 5.0, + calibration_kernel: Optional[Callable] = None, + resume_training: bool = False, + force_first_round_loss: bool = False, + discard_prior_samples: bool = False, + retrain_from_scratch: bool = False, + show_train_summary: bool = False, + dataloader_kwargs: Optional[Dict] = None, + loss_strategy: Optional[NPELossStrategy] = None, + ) -> ConditionalDensityEstimator: + + kwargs = del_entries( + locals(), + entries=("self", "__class__", "loss_strategy"), ) - prior = torch.exp(self._prior.log_prob(theta)) - - # Evaluate proposal - # (as theta comes from prior and proposal from previous rounds, - # the last proposal is actually a mixture of the prior - # and of all the previous proposals with coefficients representing - # the proportion of the new theta added at each round) - prop = torch.zeros(self._round + 1, device=theta.device) - nb_samples = 0 # total number of theta from all the rounds - - for k in range(self._round + 1): - nb_samples += self._theta_roundwise[k].size(0) - # the number of new theta sampled in the round k - prop[k] = self._theta_roundwise[k].size(0) - - prop /= nb_samples - log_prop = torch.log(prop).repeat(theta.size(0), 1) - - log_previous_proposals = torch.zeros( - (theta.size(0), self._round + 1), device=theta.device - ) - for k, density in enumerate(self._proposal_roundwise): - # we accept the k th proposal log prob to be -Inf at theta - # meaning that theta is out of the k th proposal range - log_previous_proposals[:, k] = density.log_prob(theta) - utils.assert_not_nan_or_plus_inf( - 
log_previous_proposals[:, k], "proposal log probs of proposal samples" - ) - - log_proposal = torch.logsumexp(log_prop + log_previous_proposals, dim=1) - proposal = torch.exp(log_proposal) - # Construct the importance weights and normalize them - importance_weights = prior / proposal - importance_weights /= importance_weights.sum() + if len(self._data_round_index) == 0: + raise RuntimeError( + "No simulations found. You must call .append_simulations() " + "before calling .train()." + ) - theta = reshape_to_sample_batch_event(theta, theta.shape[1:]) - # Reshape the density estimator log probs - # from (sample_shape, batch_shape) to (batch_shape) - posterior_log_probs = self._neural_net.log_prob(theta, x).squeeze(dim=0) + self._round = max(self._data_round_index) + + if loss_strategy is not None: + self._loss_strategy = loss_strategy + elif self._round > 0: + self._loss_strategy = ImportanceWeightedLoss( + neural_net=self._neural_net, + prior=self._prior, + round_idx=self._round, + theta_roundwise=self._theta_roundwise, + proposal_roundwise=self._proposal_roundwise, + ) + else: + self._loss_strategy = None - return importance_weights * posterior_log_probs + return super().train(**kwargs) diff --git a/sbi/inference/trainers/npe/npe_base.py b/sbi/inference/trainers/npe/npe_base.py index 4105bf72f..8f168500c 100644 --- a/sbi/inference/trainers/npe/npe_base.py +++ b/sbi/inference/trainers/npe/npe_base.py @@ -2,7 +2,7 @@ # under the Apache License Version 2.0, see import warnings -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import asdict from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple, Union @@ -32,6 +32,7 @@ NeuralInference, check_if_proposal_has_default_x, ) +from sbi.inference.trainers.npe.npe_loss import NPELossStrategy from sbi.neural_nets import posterior_nn from sbi.neural_nets.estimators import ConditionalDensityEstimator from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder @@ -111,16 
+112,9 @@ def __init__( self._build_neural_net = density_estimator self._proposal_roundwise = [] - self.use_non_atomic_loss = False + self._loss_strategy: Optional[NPELossStrategy] = None + - @abstractmethod - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: Optional[Any], - ) -> Tensor: ... def append_simulations( self, @@ -194,10 +188,11 @@ def append_simulations( # Check for problematic z-scoring warn_if_invalid_for_zscoring(x) + is_atomic = self._loss_strategy is None or not getattr(self._loss_strategy, "uses_only_latest_round", True) if ( - type(self).__name__ == "SNPE_C" + type(self).__name__ in ("SNPE_C", "NPE_C") and current_round > 0 - and not self.use_non_atomic_loss + and is_atomic ): nle_nre_apt_msg_on_invalid_x( num_nans, @@ -254,6 +249,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[dict] = None, + loss_strategy: Optional[NPELossStrategy] = None, ) -> ConditionalDensityEstimator: r"""Return density estimator that approximates the distribution $p(\theta|x)$. @@ -550,9 +546,11 @@ def _loss( # Use posterior log prob (without proposal correction) for first round. loss = self._neural_net.loss(theta, x) else: + if self._loss_strategy is None: + raise ValueError("Loss strategy is None but round > 0.") # Currently only works for `DensityEstimator` objects. # Must be extended ones other Estimators are implemented. See #966, - loss = -self._log_prob_proposal_posterior(theta, x, masks, proposal) + loss = -self._loss_strategy(theta, x, masks, proposal) assert_all_finite(loss, "NPE loss") return calibration_kernel(x) * loss @@ -650,7 +648,10 @@ def _get_start_index(self, context: StartIndexContext) -> int: # SNPE-A can, by construction of the algorithm, only use samples from the last # round. SNPE-A is the only algorithm that has an attribute `_ran_final_round`, # so this is how we check for whether or not we are using SNPE-A. 
- if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"): + if ( + self._loss_strategy is not None + and self._loss_strategy.uses_only_latest_round + ) or hasattr(self, "_ran_final_round"): start_idx = self._round return start_idx diff --git a/sbi/inference/trainers/npe/npe_c.py b/sbi/inference/trainers/npe/npe_c.py index 9accdc2c6..bad251fcb 100644 --- a/sbi/inference/trainers/npe/npe_c.py +++ b/sbi/inference/trainers/npe/npe_c.py @@ -4,7 +4,6 @@ from typing import Callable, Dict, Literal, Optional, Union import torch -from torch import Tensor, eye, ones from torch.distributions import Distribution, MultivariateNormal, Uniform from torch.utils.tensorboard.writer import SummaryWriter @@ -12,6 +11,11 @@ from sbi.inference.trainers.npe.npe_base import ( PosteriorEstimatorTrainer, ) +from sbi.inference.trainers.npe.npe_loss import ( + AtomicLoss, + NPELossStrategy, + NonAtomicGaussianLoss, +) from sbi.neural_nets.estimators.base import ( ConditionalDensityEstimator, ConditionalEstimatorBuilder, @@ -19,21 +23,12 @@ from sbi.neural_nets.estimators.mixture_density_estimator import ( MixtureDensityEstimator, ) -from sbi.neural_nets.estimators.mog import MoG -from sbi.neural_nets.estimators.shape_handling import ( - reshape_to_batch_event, - reshape_to_sample_batch_event, -) from sbi.sbi_types import Tracker from sbi.utils import ( - batched_mixture_mv, - batched_mixture_vmv, check_dist_class, - clamp_and_warn, del_entries, - repeat_rows, ) -from sbi.utils.torchutils import BoxUniform, assert_all_finite +from sbi.utils.torchutils import BoxUniform class NPE_C(PosteriorEstimatorTrainer): @@ -141,6 +136,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, + loss_strategy: Optional[NPELossStrategy] = None, ) -> ConditionalDensityEstimator: r"""Return density estimator that approximates the distribution $p(\theta|x)$. 
@@ -183,18 +179,12 @@ def train( Returns: Density estimator that approximates the distribution $p(\theta|x)$. """ - if len(self._data_round_index) == 0: raise RuntimeError( "No simulations found. You must call .append_simulations() " "before calling .train()." ) - # WARNING: sneaky trick ahead. We proxy the parent's `train` here, - # requiring the signature to have `num_atoms`, save it for use below, and - # continue. It's sneaky because we are using the object (self) as a namespace - # to pass arguments between functions, and that's implicit state management. - self._num_atoms = num_atoms - self._use_combined_loss = use_combined_loss + kwargs = del_entries( locals(), entries=("self", "__class__", "num_atoms", "use_combined_loss"), @@ -202,13 +192,12 @@ def train( self._round = max(self._data_round_index) - if self._round > 0: - # Set the proposal to the last proposal that was passed by the user. For - # atomic SNPE, it does not matter what the proposal is. For non-atomic - # SNPE, we only use the latest data that was passed, i.e. the one from the - # last proposal. + if loss_strategy is not None: + self._loss_strategy = loss_strategy + elif self._round > 0: + # Set the proposal to the last proposal that was passed by the user. proposal = self._proposal_roundwise[-1] - self.use_non_atomic_loss = ( + use_non_atomic_loss = ( isinstance(proposal, DirectPosterior) and isinstance(proposal.posterior_estimator, MixtureDensityEstimator) and isinstance(self._neural_net, MixtureDensityEstimator) @@ -217,75 +206,69 @@ def train( )[0] ) - algorithm = "non-atomic" if self.use_non_atomic_loss else "atomic" + algorithm = "non-atomic" if use_non_atomic_loss else "atomic" print(f"Using SNPE-C with {algorithm} loss") - if self.use_non_atomic_loss: + if use_non_atomic_loss: # Take care of z-scoring, pre-compute and store prior terms. 
self._set_state_for_mog_proposal() + # Instantiate Non-Atomic Strategy + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + prec_m_prod_prior = torch.mv( + self._maybe_z_scored_prior.precision_matrix, + self._maybe_z_scored_prior.loc, + ) + else: + prec_m_prod_prior = None + + self._loss_strategy = NonAtomicGaussianLoss( + neural_net=self._neural_net, + maybe_z_scored_prior=self._maybe_z_scored_prior, + prec_m_prod_prior=prec_m_prod_prior, + z_score_theta=self.z_score_theta, + ) + else: + # Instantiate Atomic Strategy + self._loss_strategy = AtomicLoss( + neural_net=self._neural_net, + prior=self._prior, + num_atoms=num_atoms, + use_combined_loss=use_combined_loss, + ) + else: + # Default to None for first round (equivalent to MLE) + self._loss_strategy = None + return super().train(**kwargs) def _set_state_for_mog_proposal(self) -> None: """Set state variables that are used at each training step of non-atomic SNPE-C. Three things are computed: - 1) Check if z-scoring was requested. We check if the MixtureDensityEstimator - has an input transform enabled via the has_input_transform property. - 2) Define a (potentially standardized) prior. It's standardized if z-scoring - had been requested. - 3) Compute (Precision * mean) for the prior. This quantity is used at every - training step if the prior is Gaussian. + 1) Check if z-scoring was requested. + 2) Define a (potentially standardized) prior. + 3) Compute (Precision * mean) for the prior. 
""" - # Check if z-scoring is enabled on the MixtureDensityEstimator assert isinstance(self._neural_net, MixtureDensityEstimator) self.z_score_theta = self._neural_net.has_input_transform self._set_maybe_z_scored_prior() - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - self.prec_m_prod_prior = torch.mv( - self._maybe_z_scored_prior.precision_matrix, # type: ignore - self._maybe_z_scored_prior.loc, # type: ignore - ) - def _set_maybe_z_scored_prior(self) -> None: - r"""Compute and store potentially standardized prior (if z-scoring was done). - - The proposal posterior is: - $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - - Let's denote z-scored theta by `a`: a = (theta - mean) / std - Then pp'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$ - - The ' indicates that the evaluation occurs in standardized space. The constant - scaling factor has been absorbed into Z_2. - From the above equation, we see that we need to evaluate the prior **in - standardized space**. We build the standardized prior in this function. - - The standardize transform that is applied to the samples theta does not use - the exact prior mean and std (due to implementation issues). Hence, the z-scored - prior will not be exactly have mean=0 and std=1. 
- """ + r"""Compute and store potentially standardized prior (if z-scoring was done).""" if self.z_score_theta: # Get z-score parameters from the MixtureDensityEstimator - # The transform is: z = (theta - shift) / scale - # where shift = mean (estimated from samples) and scale = std (estimated) assert isinstance(self._neural_net, MixtureDensityEstimator) shift = self._neural_net._transform_shift scale = self._neural_net._transform_scale - # The MixtureDensityEstimator uses: z = (theta - shift) / scale - # where shift = mean and scale = std (estimated from training data) estim_prior_mean = shift estim_prior_std = scale # Compute the discrepancy of the true prior mean and std and the mean and # std that was empirically estimated from samples. - # N(theta|m,s) = N((theta-m_e)/s_e|(m-m_e)/s_e, s/s_e) - # Above: m,s are true prior mean and std. m_e,s_e are estimated prior mean - # and std (estimated from samples and used to build standardize transform). almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std @@ -301,415 +284,4 @@ def _set_maybe_z_scored_prior(self) -> None: else: self._maybe_z_scored_prior = self._prior - def _log_prob_proposal_posterior( - self, - theta: Tensor, - x: Tensor, - masks: Tensor, - proposal: DirectPosterior, - ) -> Tensor: - """Return the log-probability of the proposal posterior. - - If the proposal is a MoG, the density estimator is a MoG, and the prior is - either Gaussian or uniform, we use non-atomic loss. Else, use atomic loss (which - suffers from leakage). - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - proposal: Proposal distribution. - - Returns: Log-probability of the proposal posterior. 
- """ - - if self.use_non_atomic_loss: - if not isinstance(self._neural_net, MixtureDensityEstimator): - raise ValueError( - "The density estimator must be a MixtureDensityEstimator " - "for non-atomic loss." - ) - - return self._log_prob_proposal_posterior_mog(theta, x, proposal) - else: - if not hasattr(self._neural_net, "log_prob"): - raise ValueError( - "The neural estimator must have a log_prob method, for\ - atomic loss. It should at best follow the \ - sbi.neural_nets 'DensityEstiamtor' interface." - ) - return self._log_prob_proposal_posterior_atomic(theta, x, masks) - - def _log_prob_proposal_posterior_atomic( - self, theta: Tensor, x: Tensor, masks: Tensor - ): - """Return log probability of the proposal posterior for atomic proposals. - - We have two main options when evaluating the proposal posterior. - (1) Generate atoms from the proposal prior. - (2) Generate atoms from a more targeted distribution, such as the most - recent posterior. - If we choose the latter, it is likely beneficial not to do this in the first - round, since we would be sampling from a randomly-initialized neural density - estimator. - - Args: - theta: Batch of parameters θ. - x: Batch of data. - masks: Mask that is True for prior samples in the batch in order to train - them with prior loss. - - Returns: - Log-probability of the proposal posterior. - """ - batch_size = theta.shape[0] - - num_atoms = int( - clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size) - ) - - # Each set of parameter atoms is evaluated using the same x, - # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2] - repeated_x = repeat_rows(x, num_atoms) - - # To generate the full set of atoms for a given item in the batch, - # we sample without replacement num_atoms - 1 times from the rest - # of the theta in the batch. 
- probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1) - - choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False) - contrasting_theta = theta[choices] - - # We can now create our sets of atoms from the contrasting parameter sets - # we have generated. - atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape( - batch_size * num_atoms, -1 - ) - - # Get (batch_size * num_atoms) log prob prior evals. - log_prob_prior = self._prior.log_prob(atomic_theta) - log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms) - assert_all_finite(log_prob_prior, "prior eval") - - # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals. - atomic_theta = reshape_to_sample_batch_event( - atomic_theta, atomic_theta.shape[1:] - ) - repeated_x = reshape_to_batch_event( - repeated_x, self._neural_net.condition_shape - ) - log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x) - assert_all_finite(log_prob_posterior, "posterior eval") - log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms) - - # Compute unnormalized proposal posterior. - unnormalized_log_prob = log_prob_posterior - log_prob_prior - - # Normalize proposal posterior across discrete set of atoms. - log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp( - unnormalized_log_prob, dim=-1 - ) - assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval") - - # XXX This evaluates the posterior on _all_ prior samples - if self._use_combined_loss: - theta = reshape_to_sample_batch_event(theta, self._neural_net.input_shape) - x = reshape_to_batch_event(x, self._neural_net.condition_shape) - log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x) - # squeeze to remove sample dimension, which is always one during the loss - # evaluation of `SNPE_C` (because we have one theta vector per x vector). 
- log_prob_posterior_non_atomic = log_prob_posterior_non_atomic.squeeze(dim=0) - masks = masks.reshape(-1) - log_prob_proposal_posterior = ( - masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior - ) - - return log_prob_proposal_posterior - - def _log_prob_proposal_posterior_mog( - self, theta: Tensor, x: Tensor, proposal: DirectPosterior - ) -> Tensor: - """Return log-probability of the proposal posterior for MoG proposal. - - For MoG proposals and MoG density estimators, this can be done in closed form - and does not require atomic loss (i.e. there will be no leakage issues). - - Notation: - - m are mean vectors. - prec are precision matrices. - cov are covariance matrices. - - _p at the end indicates that it is the proposal. - _d indicates that it is the density estimator. - _pp indicates the proposal posterior. - - All tensors will have shapes (batch_dim, num_components, ...) - - Args: - theta: Batch of parameters θ. - x: Batch of data. - proposal: Proposal distribution. - - Returns: - Log-probability of the proposal posterior. - """ - # Get the proposal MoG at the default_x - assert isinstance(proposal.posterior_estimator, MixtureDensityEstimator) - assert proposal.default_x is not None, "Proposal must have default_x set" - mog_p = proposal.posterior_estimator.get_uncorrected_mog(proposal.default_x) - norm_logits_p = mog_p.log_weights # Already normalized - m_p = mog_p.means - prec_p = mog_p.precisions - - # Get the density estimator MoG at the training data x - assert isinstance(self._neural_net, MixtureDensityEstimator) - mog_d = self._neural_net.get_uncorrected_mog(x) - norm_logits_d = mog_d.log_weights # Already normalized - m_d = mog_d.means - prec_d = mog_d.precisions - - # z-score theta if z-scoring was requested. - theta = self._maybe_z_score_theta(theta) - - # Compute the MoG parameters of the proposal posterior. 
- ( - logits_pp, - m_pp, - prec_pp, - cov_pp, - ) = self._automatic_posterior_transformation( - norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d - ) - - # Create MoG for proposal posterior and compute log_prob - # We need precision_factors for MoG, compute via Cholesky - precf_pp = torch.linalg.cholesky(prec_pp, upper=True) - mog_pp = MoG( - logits=logits_pp, - means=m_pp, - precisions=prec_pp, - precision_factors=precf_pp, - ) - - # Compute the log_prob of theta under the product. - log_prob_proposal_posterior = mog_pp.log_prob(theta) - assert_all_finite( - log_prob_proposal_posterior, - """the evaluation of the MoG proposal posterior. This is likely due to a - numerical instability in the training procedure. Please create an issue on - Github.""", - ) - - return log_prob_proposal_posterior - - def _automatic_posterior_transformation( - self, - logits_p: Tensor, - means_p: Tensor, - precisions_p: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - r"""Returns the MoG parameters of the proposal posterior. - - The proposal posterior is: - $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ - In words: proposal posterior = posterior estimate * proposal / prior. - - If the posterior estimate and the proposal are MoG and the prior is either - Gaussian or uniform, we can solve this in closed-form. The is implemented in - this function. - - This function implements Appendix A1 from Greenberg et al. 2019. - - We have to build L*K components. How do we do this? - Example: proposal has two components, density estimator has three components. - Let's call the two components of the proposal i,j and the three components - of the density estimator x,y,z. We have to multiply every component of the - proposal with every component of the density estimator. So, what we do is: - 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave() - 2) for the density estimator, build: x,y,z,x,y,z. 
Done with torch.repeat() - 3) Multiply them with simple matrix operations. - - Args: - logits_p: Component weight of each Gaussian of the proposal. - means_p: Mean of each Gaussian of the proposal. - precisions_p: Precision matrix of each Gaussian of the proposal. - logits_d: Component weight for each Gaussian of the density estimator. - means_d: Mean of each Gaussian of the density estimator. - precisions_d: Precision matrix of each Gaussian of the density estimator. - - Returns: (Component weight, mean, precision matrix, covariance matrix) of each - Gaussian of the proposal posterior. Has L*K terms (proposal has L terms, - density estimator has K terms). - """ - - precisions_pp, covariances_pp = self._precisions_proposal_posterior( - precisions_p, precisions_d - ) - - means_pp = self._means_proposal_posterior( - covariances_pp, means_p, precisions_p, means_d, precisions_d - ) - - logits_pp = self._logits_proposal_posterior( - means_pp, - precisions_pp, - covariances_pp, - logits_p, - means_p, - precisions_p, - logits_d, - means_d, - precisions_d, - ) - - return logits_pp, means_pp, precisions_pp, covariances_pp - - def _precisions_proposal_posterior( - self, precisions_p: Tensor, precisions_d: Tensor - ): - """Return the precisions and covariances of the proposal posterior. - - Args: - precisions_p: Precision matrices of the proposal distribution. - precisions_d: Precision matrices of the density estimator. - - Returns: (Precisions, Covariances) of the proposal posterior. L*K terms. 
- """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1) - precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) - - precisions_pp = precisions_p_rep + precisions_d_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - precisions_pp -= self._maybe_z_scored_prior.precision_matrix - - covariances_pp = torch.inverse(precisions_pp) - - return precisions_pp, covariances_pp - - def _means_proposal_posterior( - self, - covariances_pp: Tensor, - means_p: Tensor, - precisions_p: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - """Return the means of the proposal posterior. - - means_pp = C_ix * (P_i * m_i + P_x * m_x - P_o * m_o). - - Args: - covariances_pp: Covariance matrices of the proposal posterior. - means_p: Means of the proposal distribution. - precisions_p: Precision matrices of the proposal distribution. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Means of the proposal posterior. L*K terms. - """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - # First, compute the product P_i * m_i and P_j * m_j - prec_m_prod_p = batched_mixture_mv(precisions_p, means_p) - prec_m_prod_d = batched_mixture_mv(precisions_d, means_d) - - # Repeat them to allow for matrix operations: same trick as for the precisions. - prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1) - prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1) - - # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o). 
- summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep - if isinstance(self._maybe_z_scored_prior, MultivariateNormal): - summed_cov_m_prod_rep -= self.prec_m_prod_prior - - means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep) - - return means_pp - - def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: - """Return potentially standardized theta if z-scoring was requested.""" - - if self.z_score_theta: - assert isinstance(self._neural_net, MixtureDensityEstimator) - theta = self._neural_net._transform_input(theta) - - return theta - - @staticmethod - def _logits_proposal_posterior( - means_pp: Tensor, - precisions_pp: Tensor, - covariances_pp: Tensor, - logits_p: Tensor, - means_p: Tensor, - precisions_p: Tensor, - logits_d: Tensor, - means_d: Tensor, - precisions_d: Tensor, - ): - """Return the component weights (i.e. logits) of the proposal posterior. - - Args: - means_pp: Means of the proposal posterior. - precisions_pp: Precision matrices of the proposal posterior. - covariances_pp: Covariance matrices of the proposal posterior. - logits_p: Component weights (i.e. logits) of the proposal distribution. - means_p: Means of the proposal distribution. - precisions_p: Precision matrices of the proposal distribution. - logits_d: Component weights (i.e. logits) of the density estimator. - means_d: Means of the density estimator. - precisions_d: Precision matrices of the density estimator. - - Returns: Component weights of the proposal posterior. L*K terms. 
- """ - - num_comps_p = precisions_p.shape[1] - num_comps_d = precisions_d.shape[1] - - # Compute log(alpha_i * beta_j) - logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1) - logits_d_rep = logits_d.repeat(1, num_comps_p) - logit_factors = logits_p_rep + logits_d_rep - - # Compute sqrt(det()/(det()*det())) - logdet_covariances_pp = torch.logdet(covariances_pp) - logdet_covariances_p = -torch.logdet(precisions_p) - logdet_covariances_d = -torch.logdet(precisions_d) - - # Repeat the proposal and density estimator terms such that there are LK terms. - # Same trick as has been used above. - logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave( - num_comps_d, dim=1 - ) - logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p) - - log_sqrt_det_ratio = 0.5 * ( - logdet_covariances_pp - - (logdet_covariances_p_rep + logdet_covariances_d_rep) - ) - - # Compute for proposal, density estimator, and proposal posterior: - # mu_i.T * P_i * mu_i - exponent_p = batched_mixture_vmv(precisions_p, means_p) - exponent_d = batched_mixture_vmv(precisions_d, means_d) - exponent_pp = batched_mixture_vmv(precisions_pp, means_pp) - - # Extend proposal and density estimator exponents to get LK terms. - exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1) - exponent_d_rep = exponent_d.repeat(1, num_comps_p) - exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp) - - logits_pp = logit_factors + log_sqrt_det_ratio + exponent - return logits_pp diff --git a/sbi/inference/trainers/npe/npe_loss.py b/sbi/inference/trainers/npe/npe_loss.py new file mode 100644 index 000000000..af6ddcfa8 --- /dev/null +++ b/sbi/inference/trainers/npe/npe_loss.py @@ -0,0 +1,527 @@ +# This file is part of sbi, a toolkit for simulation-based inference. 
sbi is licensed +# under the Apache License Version 2.0, see + +from typing import Any, List, Optional + +try: + from typing import Protocol +except ImportError: + from typing_extensions import Protocol + +import torch +from torch import Tensor, eye, ones +from torch.distributions import Distribution, MultivariateNormal + +import sbi.utils as utils +from sbi.inference.posteriors.direct_posterior import DirectPosterior +from sbi.neural_nets.estimators.base import ConditionalDensityEstimator +from sbi.neural_nets.estimators.mixture_density_estimator import ( + MixtureDensityEstimator, +) +from sbi.neural_nets.estimators.mog import MoG +from sbi.neural_nets.estimators.shape_handling import ( + reshape_to_batch_event, + reshape_to_sample_batch_event, +) +from sbi.utils.sbiutils import ( + batched_mixture_mv, + batched_mixture_vmv, + clamp_and_warn, +) +from sbi.utils.torchutils import assert_all_finite, repeat_rows + + +class NPELossStrategy(Protocol): + """Protocol for NPE loss strategies.""" + + # Replaces the use_non_atomic_loss flag and the hasattr(self, "_ran_final_round") + uses_only_latest_round: bool + + def __call__( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + **kwargs, + ) -> Tensor: + """Calculate the log-probability of the proposal posterior. + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch. + proposal: Proposal distribution. + + Returns: + Log-probability of the proposal posterior. + """ + ... 
+ + +class AtomicLoss: + """Atomic loss for NPE-C (sample-based).""" + + uses_only_latest_round: bool = False + + def __init__( + self, + neural_net: ConditionalDensityEstimator, + prior: Distribution, + num_atoms: int = 10, + use_combined_loss: bool = False, + ): + self._neural_net = neural_net + self._prior = prior + self._num_atoms = num_atoms + self._use_combined_loss = use_combined_loss + + def __call__( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + **kwargs, + ) -> Tensor: + """Return log probability of the proposal posterior for atomic proposals. + + We have two main options when evaluating the proposal posterior. + (1) Generate atoms from the proposal prior. + (2) Generate atoms from a more targeted distribution, such as the most + recent posterior. + If we choose the latter, it is likely beneficial not to do this in the first + round, since we would be sampling from a randomly-initialized neural density + estimator. + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + proposal: Proposal distribution. + **kwargs: Extra arguments. + + Returns: + Log-probability of the proposal posterior. + """ + batch_size = theta.shape[0] + + num_atoms = int( + clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size) + ) + + # Each set of parameter atoms is evaluated using the same x + repeated_x = repeat_rows(x, num_atoms) + + # To generate the full set of atoms for a given item in the batch, + # we sample without replacement num_atoms - 1 times from the rest + # of the theta in the batch. 
+        probs = (
+            ones(batch_size, batch_size, device=theta.device)
+            * (1 - eye(batch_size, device=theta.device))
+            / (batch_size - 1)
+        )
+
+        choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
+        contrasting_theta = theta[choices]
+
+        # We can now create our sets of atoms from the contrasting parameter sets
+        atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
+            batch_size * num_atoms, -1
+        )
+
+        # Get (batch_size * num_atoms) log prob prior evals.
+        log_prob_prior = self._prior.log_prob(atomic_theta)
+        log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
+        assert_all_finite(log_prob_prior, "prior eval")
+
+        # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals.
+        atomic_theta = reshape_to_sample_batch_event(
+            atomic_theta, atomic_theta.shape[1:]
+        )
+        repeated_x = reshape_to_batch_event(
+            repeated_x, self._neural_net.condition_shape
+        )
+        log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x)
+        assert_all_finite(log_prob_posterior, "posterior eval")
+        log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)
+
+        # Compute unnormalized proposal posterior.
+        unnormalized_log_prob = log_prob_posterior - log_prob_prior
+
+        # Normalize proposal posterior across discrete set of atoms.
+        log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp(
+            unnormalized_log_prob, dim=-1
+        )
+        assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval")
+
+        # combined loss helps prevent density leaking with bounded priors.
+        if self._use_combined_loss:
+            theta = reshape_to_sample_batch_event(theta, self._neural_net.input_shape)
+            x = reshape_to_batch_event(x, self._neural_net.condition_shape)
+            log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x)
+            # squeeze to remove sample dimension, which is always one during the loss
+            # evaluation of `SNPE_C`.
+ log_prob_posterior_non_atomic = log_prob_posterior_non_atomic.squeeze(dim=0) + masks = masks.reshape(-1) + log_prob_proposal_posterior = ( + masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior + ) + + return log_prob_proposal_posterior + + +class NonAtomicGaussianLoss: + """Non-atomic loss for NPE-C (analytical MoG).""" + + uses_only_latest_round: bool = True + + def __init__( + self, + neural_net: MixtureDensityEstimator, + maybe_z_scored_prior: Distribution, + prec_m_prod_prior: Optional[Tensor] = None, + z_score_theta: bool = False, + ): + self._neural_net = neural_net + self._maybe_z_scored_prior = maybe_z_scored_prior + self.prec_m_prod_prior = prec_m_prod_prior + self.z_score_theta = z_score_theta + + def __call__( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: DirectPosterior, + **kwargs, + ) -> Tensor: + """Return log-probability of the proposal posterior for MoG proposal. + + For MoG proposals and MoG density estimators, this can be done in closed form + and does not require atomic loss (i.e. there will be no leakage issues). + + Notation: + + m are mean vectors. + prec are precision matrices. + cov are covariance matrices. + + _p at the end indicates that it is the proposal. + _d indicates that it is the density estimator. + _pp indicates the proposal posterior. + + All tensors will have shapes (batch_dim, num_components, ...) + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch. + proposal: Proposal distribution. + **kwargs: Extra arguments. + + Returns: + Log-probability of the proposal posterior. + """ + # Get the proposal MoG at the default_x + assert isinstance(proposal.posterior_estimator, MixtureDensityEstimator), ( + "Proposal posterior_estimator must be MixtureDensityEstimator for " + "non-atomic loss." 
+ ) + assert proposal.default_x is not None, "Proposal must have default_x set" + mog_p = proposal.posterior_estimator.get_uncorrected_mog(proposal.default_x) + norm_logits_p = mog_p.log_weights # Already normalized + m_p = mog_p.means + prec_p = mog_p.precisions + + # Get the density estimator MoG at the training data x + mog_d = self._neural_net.get_uncorrected_mog(x) + norm_logits_d = mog_d.log_weights # Already normalized + m_d = mog_d.means + prec_d = mog_d.precisions + + # z-score theta if it z-scoring had been requested. + if self.z_score_theta: + theta = self._neural_net._transform_input(theta) + + # Compute the MoG parameters of the proposal posterior. + ( + logits_pp, + m_pp, + prec_pp, + cov_pp, + ) = self._automatic_posterior_transformation( + norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d + ) + + # Create MoG for proposal posterior and compute log_prob + # We need precision_factors for MoG, compute via Cholesky + precf_pp = torch.linalg.cholesky(prec_pp, upper=True) + mog_pp = MoG( + logits=logits_pp, + means=m_pp, + precisions=prec_pp, + precision_factors=precf_pp, + ) + + # Compute the log_prob of theta under the product. + log_prob_proposal_posterior = mog_pp.log_prob(theta) + assert_all_finite( + log_prob_proposal_posterior, + """the evaluation of the MoG proposal posterior. This is likely due to a + numerical instability in the training procedure. Please create an issue on + Github.""", + ) + + return log_prob_proposal_posterior + + def _automatic_posterior_transformation( + self, + logits_p: Tensor, + means_p: Tensor, + precisions_p: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + r"""Returns the MoG parameters of the proposal posterior. + + The proposal posterior is: + $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$ + In words: proposal posterior = posterior estimate * proposal / prior. 
+ + If the posterior estimate and the proposal are MoG and the prior is either + Gaussian or uniform, we can solve this in closed-form. The is implemented in + this function. + + This function implements Appendix A1 from Greenberg et al. 2019. + + We have to build L*K components. How do we do this? + Example: proposal has two components, density estimator has three components. + Let's call the two components of the proposal i,j and the three components + of the density estimator x,y,z. We have to multiply every component of the + proposal with every component of the density estimator. So, what we do is: + 1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave() + 2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat() + 3) Multiply them with simple matrix operations. + + Args: + logits_p: Component weight of each Gaussian of the proposal. + means_p: Mean of each Gaussian of the proposal. + precisions_p: Precision matrix of each Gaussian of the proposal. + logits_d: Component weight for each Gaussian of the density estimator. + means_d: Mean of each Gaussian of the density estimator. + precisions_d: Precision matrix of each Gaussian of the density estimator. + + Returns: (Component weight, mean, precision matrix, covariance matrix) of each + Gaussian of the proposal posterior. Has L*K terms (proposal has L terms, + density estimator has K terms). 
+ """ + precisions_pp, covariances_pp = self._precisions_proposal_posterior( + precisions_p, precisions_d + ) + + means_pp = self._means_proposal_posterior( + covariances_pp, means_p, precisions_p, means_d, precisions_d + ) + + logits_pp = self._logits_proposal_posterior( + means_pp, + precisions_pp, + covariances_pp, + logits_p, + means_p, + precisions_p, + logits_d, + means_d, + precisions_d, + ) + + return logits_pp, means_pp, precisions_pp, covariances_pp + + def _precisions_proposal_posterior( + self, precisions_p: Tensor, precisions_d: Tensor + ): + """Return the precisions and covariances of the proposal posterior.""" + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1) + precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1) + + precisions_pp = precisions_p_rep + precisions_d_rep + if isinstance(self._maybe_z_scored_prior, MultivariateNormal): + precisions_pp -= self._maybe_z_scored_prior.precision_matrix + + covariances_pp = torch.inverse(precisions_pp) + + return precisions_pp, covariances_pp + + def _means_proposal_posterior( + self, + covariances_pp: Tensor, + means_p: Tensor, + precisions_p: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + """Return the means of the proposal posterior.""" + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + # First, compute the product P_i * m_i and P_j * m_j + prec_m_prod_p = batched_mixture_mv(precisions_p, means_p) + prec_m_prod_d = batched_mixture_mv(precisions_d, means_d) + + # Repeat them to allow for matrix operations + prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1) + prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1) + + # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o). 
+ summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep + + if self.prec_m_prod_prior is not None: + summed_cov_m_prod_rep -= self.prec_m_prod_prior + + means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep) + + return means_pp + + @staticmethod + def _logits_proposal_posterior( + means_pp: Tensor, + precisions_pp: Tensor, + covariances_pp: Tensor, + logits_p: Tensor, + means_p: Tensor, + precisions_p: Tensor, + logits_d: Tensor, + means_d: Tensor, + precisions_d: Tensor, + ): + """Return the component weights (i.e. logits) of the proposal posterior.""" + num_comps_p = precisions_p.shape[1] + num_comps_d = precisions_d.shape[1] + + # Compute log(alpha_i * beta_j) + logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1) + logits_d_rep = logits_d.repeat(1, num_comps_p) + logit_factors = logits_p_rep + logits_d_rep + + # Compute sqrt(det()/(det()*det())) + logdet_covariances_pp = torch.logdet(covariances_pp) + logdet_covariances_p = -torch.logdet(precisions_p) + logdet_covariances_d = -torch.logdet(precisions_d) + + # Repeat the proposal and density estimator terms + logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave( + num_comps_d, dim=1 + ) + logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p) + + log_sqrt_det_ratio = 0.5 * ( + logdet_covariances_pp + - (logdet_covariances_p_rep + logdet_covariances_d_rep) + ) + + # Compute for proposal, density estimator, and proposal posterior: + exponent_p = batched_mixture_vmv(precisions_p, means_p) + exponent_d = batched_mixture_vmv(precisions_d, means_d) + exponent_pp = batched_mixture_vmv(precisions_pp, means_pp) + + # Extend proposal and density estimator exponents + exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1) + exponent_d_rep = exponent_d.repeat(1, num_comps_p) + exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp) + + logits_pp = logit_factors + log_sqrt_det_ratio + exponent + + return logits_pp + + +class 
ImportanceWeightedLoss: + """Importance-weighted loss for NPE-B.""" + + uses_only_latest_round: bool = False + + def __init__( + self, + neural_net: ConditionalDensityEstimator, + prior: Distribution, + round_idx: int, + theta_roundwise: List[Tensor], + proposal_roundwise: List[Any], + ): + self._neural_net = neural_net + self._prior = prior + self._round_idx = round_idx + self._theta_roundwise = theta_roundwise + self._proposal_roundwise = proposal_roundwise + + def __call__( + self, + theta: Tensor, + x: Tensor, + masks: Tensor, + proposal: Optional[Any], + **kwargs, + ) -> Tensor: + """ + Return importance-weighted log probability (Lueckmann, Goncalves et al., 2017). + + Args: + theta: Batch of parameters θ. + x: Batch of data. + masks: Mask that is True for prior samples in the batch in order to train + them with prior loss. + proposal: Proposal distribution. + **kwargs: Extra arguments. + + Returns: + Importance-weighted log-probability of the proposal posterior. + """ + # Evaluate prior + # we accept prior log prob to be -Inf at theta + # meaning that theta is out of the prior range (the weight is thus 0) + utils.assert_not_nan_or_plus_inf( + self._prior.log_prob(theta), "prior log probs of proposal samples" + ) + prior = torch.exp(self._prior.log_prob(theta)) + + # Evaluate proposal + # (as theta comes from prior and proposal from previous rounds, + # the last proposal is actually a mixture of the prior + # and of all the previous proposals with coefficients representing + # the proportion of the new theta added at each round) + prop = torch.zeros(self._round_idx + 1, device=theta.device) + nb_samples = 0 # total number of theta from all the rounds + + for k in range(self._round_idx + 1): + nb_samples += self._theta_roundwise[k].size(0) + # the number of new theta sampled in the round k + prop[k] = self._theta_roundwise[k].size(0) + + prop /= nb_samples + log_prop = torch.log(prop).repeat(theta.size(0), 1) + + log_previous_proposals = torch.zeros( + 
(theta.size(0), self._round_idx + 1), device=theta.device + ) + for k, density in enumerate(self._proposal_roundwise): + # we accept the k th proposal log prob to be -Inf at theta + # meaning that theta is out of the k th proposal range + log_previous_proposals[:, k] = density.log_prob(theta) + utils.assert_not_nan_or_plus_inf( + log_previous_proposals[:, k], "proposal log probs of proposal samples" + ) + + log_proposal = torch.logsumexp(log_prop + log_previous_proposals, dim=1) + proposal_weighted = torch.exp(log_proposal) + + # Construct the importance weights and normalize them + importance_weights = prior / proposal_weighted + importance_weights /= importance_weights.sum() + + theta = reshape_to_sample_batch_event(theta, theta.shape[1:]) + # Reshape the density estimator log probs + # from (sample_shape, batch_shape) to (batch_shape) + posterior_log_probs = self._neural_net.log_prob(theta, x).squeeze(dim=0) + + return importance_weights * posterior_log_probs diff --git a/sbi/inference/trainers/nre/bnre.py b/sbi/inference/trainers/nre/bnre.py index 933b6c74d..0f8d84b20 100644 --- a/sbi/inference/trainers/nre/bnre.py +++ b/sbi/inference/trainers/nre/bnre.py @@ -3,18 +3,17 @@ from typing import Dict, Optional, Sequence, Union -import torch -from torch import Tensor, nn, ones +from torch import Tensor from torch.distributions import Distribution from torch.utils.tensorboard.writer import SummaryWriter from sbi.inference.trainers._contracts import LossArgs, LossArgsBNRE from sbi.inference.trainers.nre.nre_a import NRE_A +from sbi.inference.trainers.nre.nre_loss import BNRELoss, NRELossStrategy from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries -from sbi.utils.torchutils import assert_all_finite class BNRE(NRE_A): @@ -113,6 +112,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = 
False, dataloader_kwargs: Optional[Dict] = None, + loss_strategy: Optional[NRELossStrategy] = None, ) -> RatioEstimator: r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. Args: @@ -156,44 +156,11 @@ def train( regularization_strength=kwargs.pop("regularization_strength"), ) - return super().train(**kwargs) - - def _loss( - self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float - ) -> Tensor: - """Returns the binary cross-entropy loss for the trained classifier. - - The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1 - if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the - pair was sampled from the marginals $p(\theta)p(x)$. - """ - - assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." - batch_size = theta.shape[0] + if loss_strategy is None: + kwargs["loss_strategy"] = BNRELoss() - logits = self._classifier_logits(theta, x, num_atoms) - likelihood = torch.sigmoid(logits).squeeze() - - # Alternating pairs where there is one sampled from the joint and one - # sampled from the marginals. The first element is sampled from the - # joint p(theta, x) and is labelled 1. The second element is sampled - # from the marginals p(theta)p(x) and is labelled 0. And so on. 
- labels = ones(2 * batch_size, device=self._device) # two atoms - labels[1::2] = 0.0 - - # Binary cross entropy to learn the likelihood (AALR-specific) - bce = nn.BCELoss()(likelihood, labels) - - # Balancing regularizer - regularizer = ( - (torch.sigmoid(logits[0::2]) + torch.sigmoid(logits[1::2]) - 1) - .mean() - .square() - ) + return super().train(**kwargs) - loss = bce + regularization_strength * regularizer - assert_all_finite(loss, "BNRE loss") - return loss def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor: """Overrides the parent class method to check the type of loss_args.""" diff --git a/sbi/inference/trainers/nre/nre_a.py b/sbi/inference/trainers/nre/nre_a.py index 26653ce3a..02b7dc584 100644 --- a/sbi/inference/trainers/nre/nre_a.py +++ b/sbi/inference/trainers/nre/nre_a.py @@ -3,8 +3,6 @@ from typing import Dict, Optional, Union -import torch -from torch import Tensor, nn, ones from torch.distributions import Distribution from torch.utils.tensorboard.writer import SummaryWriter @@ -12,11 +10,11 @@ from sbi.inference.trainers.nre.nre_base import ( RatioEstimatorTrainer, ) +from sbi.inference.trainers.nre.nre_loss import AALRLoss, NRELossStrategy from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries -from sbi.utils.torchutils import assert_all_finite class NRE_A(RatioEstimatorTrainer): @@ -111,6 +109,7 @@ def train( show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, loss_kwargs: Optional[LossArgsNRE_A] = None, + loss_strategy: Optional[NRELossStrategy] = None, ) -> RatioEstimator: r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. 
@@ -154,30 +153,8 @@ def train( f" but got {type(loss_kwargs)}" ) - return super().train(**kwargs) - - def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor: - """Returns the binary cross-entropy loss for the trained classifier. - - The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1 - if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the - pair was sampled from the marginals $p(\theta)p(x)$. - """ + if loss_strategy is None: + kwargs["loss_strategy"] = AALRLoss() - assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." - batch_size = theta.shape[0] - - logits = self._classifier_logits(theta, x, num_atoms) - likelihood = torch.sigmoid(logits).squeeze() - - # Alternating pairs where there is one sampled from the joint and one - # sampled from the marginals. The first element is sampled from the - # joint p(theta, x) and is labelled 1. The second element is sampled - # from the marginals p(theta)p(x) and is labelled 0. And so on. 
- labels = ones(2 * batch_size, device=self._device) # two atoms - labels[1::2] = 0.0 + return super().train(**kwargs) - # Binary cross entropy to learn the likelihood (AALR-specific) - loss = nn.BCELoss()(likelihood, labels) - assert_all_finite(loss, "NRE-A loss") - return loss diff --git a/sbi/inference/trainers/nre/nre_b.py b/sbi/inference/trainers/nre/nre_b.py index 87e08edc8..7892abfa9 100644 --- a/sbi/inference/trainers/nre/nre_b.py +++ b/sbi/inference/trainers/nre/nre_b.py @@ -3,8 +3,6 @@ from typing import Dict, Optional, Union -import torch -from torch import Tensor from torch.distributions import Distribution from torch.utils.tensorboard.writer import SummaryWriter @@ -12,11 +10,11 @@ from sbi.inference.trainers.nre.nre_base import ( RatioEstimatorTrainer, ) +from sbi.inference.trainers.nre.nre_loss import NRELossStrategy, SRELoss from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries -from sbi.utils.torchutils import assert_all_finite class NRE_B(RatioEstimatorTrainer): @@ -111,6 +109,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, + loss_strategy: Optional[NRELossStrategy] = None, ) -> RatioEstimator: r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. @@ -146,31 +145,8 @@ def train( kwargs = del_entries(locals(), entries=("self", "__class__")) kwargs["loss_kwargs"] = LossArgsNRE(num_atoms=kwargs.pop("num_atoms")) - return super().train(**kwargs) - - def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor: - r"""Return cross-entropy (via softmax activation) loss for 1-out-of-`num_atoms` - classification. - - The classifier takes as input `num_atoms` $(\theta,x)$ pairs. 
Out of these - pairs, one pair was sampled from the joint $p(\theta,x)$ and all others from the - marginals $p(\theta)p(x)$. The classifier is trained to predict which of the - pairs was sampled from the joint $p(\theta,x)$. - """ + if loss_strategy is None: + kwargs["loss_strategy"] = SRELoss() - assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." - batch_size = theta.shape[0] - logits = self._classifier_logits(theta, x, num_atoms) - - # For 1-out-of-`num_atoms` classification each datapoint consists - # of `num_atoms` points, with one of them being the correct one. - # We have a batch of `batch_size` such datapoints. - logits = logits.reshape(batch_size, num_atoms) - - # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the - # "correct" one for the 1-out-of-N classification. - log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1) + return super().train(**kwargs) - loss = -torch.mean(log_prob) - assert_all_finite(loss, "NRE-B loss") - return loss diff --git a/sbi/inference/trainers/nre/nre_base.py b/sbi/inference/trainers/nre/nre_base.py index 07b2cda0b..baa8b74f0 100644 --- a/sbi/inference/trainers/nre/nre_base.py +++ b/sbi/inference/trainers/nre/nre_base.py @@ -2,12 +2,11 @@ # under the Apache License Version 2.0, see import warnings -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import asdict, replace from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Union -import torch -from torch import Tensor, eye, ones +from torch import Tensor from torch.distributions import Distribution from torch.utils.tensorboard.writer import SummaryWriter from typing_extensions import Self @@ -30,6 +29,7 @@ LossArgs, NeuralInference, ) +from sbi.inference.trainers.nre.nre_loss import NRELossStrategy from sbi.neural_nets import classifier_nn from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator @@ -38,7 +38,6 @@ 
check_estimator_arg, clamp_and_warn, ) -from sbi.utils.torchutils import repeat_rows class RatioEstimatorTrainer(NeuralInference[RatioEstimator], ABC): @@ -90,6 +89,8 @@ def __init__( show_progress_bars=show_progress_bars, ) + self._loss_strategy: Optional[NRELossStrategy] = None + # As detailed in the docstring, `density_estimator` is either a string or # a callable. The function creating the neural network is attached to # `_build_neural_net`. It will be called in the first round and receive @@ -101,9 +102,6 @@ def __init__( else: self._build_neural_net = classifier - @abstractmethod - def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor: ... - def append_simulations( self, theta: Tensor, @@ -170,6 +168,7 @@ def train( show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, loss_kwargs: Optional[LossArgsNRE] = None, + loss_strategy: Optional[NRELossStrategy] = None, ) -> RatioEstimator: r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. @@ -372,27 +371,6 @@ def build_posterior( importance_sampling_parameters=importance_sampling_parameters, ) - def _classifier_logits(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor: - """Return logits obtained through classifier forward pass. - - The logits are obtained from atomic sets of (theta,x) pairs. - """ - batch_size = theta.shape[0] - repeated_x = repeat_rows(x, num_atoms) - - # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x. 
-        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)
-
-        choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
-
-        contrasting_theta = theta[choices]
-
-        atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
-            batch_size * num_atoms, -1
-        )
-
-        return self._neural_net(atomic_theta, repeated_x)
-
     def _get_potential_function(
         self, prior: Distribution, estimator: RatioEstimator
     ) -> Tuple[RatioBasedPotential, TorchTransform]:
@@ -465,6 +443,9 @@ def _initialize_neural_network(

         del x, theta

+    def _loss(self, *args, **kwargs) -> Tensor:
+        raise NotImplementedError("NRE trainers use _loss_strategy inside _get_losses.")
+
     def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
         """
         Compute losses for a batch of data.
@@ -488,6 +469,8 @@ def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
             f" but got {type(loss_args)}"
         )

-        losses = self._loss(theta_batch, x_batch, **asdict(loss_args))
+        if self._loss_strategy is None:
+            raise RuntimeError("Loss strategy not initialized.")
+        losses = self._loss_strategy(self._neural_net, self._device, theta_batch, x_batch, **asdict(loss_args))

         return losses
diff --git a/sbi/inference/trainers/nre/nre_c.py b/sbi/inference/trainers/nre/nre_c.py
index c49a76fbe..edfeadef8 100644
--- a/sbi/inference/trainers/nre/nre_c.py
+++ b/sbi/inference/trainers/nre/nre_c.py
@@ -1,9 +1,8 @@
 # This file is part of sbi, a toolkit for simulation-based inference.
sbi is licensed # under the Apache License Version 2.0, see -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Dict, Optional, Sequence, Union -import torch from torch import Tensor from torch.distributions import Distribution from torch.utils.tensorboard.writer import SummaryWriter @@ -12,11 +11,11 @@ from sbi.inference.trainers.nre.nre_base import ( RatioEstimatorTrainer, ) +from sbi.inference.trainers.nre.nre_loss import CNRELoss, NRELossStrategy from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder from sbi.neural_nets.ratio_estimators import RatioEstimator from sbi.sbi_types import Tracker from sbi.utils.sbiutils import del_entries -from sbi.utils.torchutils import assert_all_finite class NRE_C(RatioEstimatorTrainer): @@ -109,6 +108,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, + loss_strategy: Optional[NRELossStrategy] = None, ) -> RatioEstimator: r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$. @@ -157,89 +157,8 @@ def train( num_atoms=kwargs.pop("num_classes") + 1, gamma=kwargs.pop("gamma") ) - return super().train(**kwargs) - - def _loss( - self, theta: Tensor, x: Tensor, num_atoms: int, gamma: float - ) -> torch.Tensor: - r"""Return cross-entropy loss (via ''multi-class sigmoid'' activation) for - 1-out-of-`K + 1` classification. - - At optimum, this loss function returns the exact likelihood-to-evidence ratio - in the first round. - Details of loss computation are described in Contrastive Neural Ratio - Estimation[1]. The paper does not discuss the sequential case. - - [1] _Contrastive Neural Ratio Estimation_, Benajmin Kurt Miller, et. al., - NeurIPS 2022, https://arxiv.org/abs/2210.06170 - """ - - # Reminder: K = num_classes - # The algorithm is written with K, so we convert back to K format rather than - # reasoning in num_atoms. 
- num_classes = num_atoms - 1 - assert num_classes >= 1, f"num_classes = {num_classes} must be greater than 1." - - assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." - batch_size = theta.shape[0] - - # We append a contrastive theta to the marginal case because we will remove - # the jointly drawn - # sample in the logits_marginal[:, 0] position. That makes the remaining sample - # marginally drawn. - # We have a batch of `batch_size` datapoints. - logits_marginal = self._classifier_logits(theta, x, num_classes + 1).reshape( - batch_size, num_classes + 1 - ) - logits_joint = self._classifier_logits(theta, x, num_classes).reshape( - batch_size, num_classes - ) - - dtype = logits_marginal.dtype - device = logits_marginal.device - - # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence - # we remove the jointly drawn sample from the logits_marginal - logits_marginal = logits_marginal[:, 1:] - # ... and retain it in the logits_joint. Now we have two arrays with K choices. - - # To use logsumexp, we extend the denominator logits with loggamma - loggamma = torch.tensor(gamma, dtype=dtype, device=device).log() - logK = torch.tensor(num_classes, dtype=dtype, device=device).log() - denominator_marginal = torch.concat( - [loggamma + logits_marginal, logK.expand((batch_size, 1))], - dim=-1, - ) - denominator_joint = torch.concat( - [loggamma + logits_joint, logK.expand((batch_size, 1))], - dim=-1, - ) - - # Compute the contributions to the loss from each term in the classification. - log_prob_marginal = logK - torch.logsumexp(denominator_marginal, dim=-1) - log_prob_joint = ( - loggamma + logits_joint[:, 0] - torch.logsumexp(denominator_joint, dim=-1) - ) - - # relative weights. p_marginal := p_0, and p_joint := p_K * K from the notation. 
-        p_marginal, p_joint = self._get_prior_probs_marginal_and_joint(gamma)
-
-        loss = -torch.mean(p_marginal * log_prob_marginal + p_joint * log_prob_joint)
-        assert_all_finite(loss, "NRE-C loss")
-        return loss
-
-    @staticmethod
-    def _get_prior_probs_marginal_and_joint(gamma: float) -> Tuple[float, float]:
-        r"""Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$,
-        `p_joint := `$p_K \cdot K$.
-
-        We let the joint (dependently drawn) class to be equally likely across K
-        options. The marginal class is therefore restricted to get the remaining
-        probability.
-        """
-        p_joint = gamma / (1 + gamma)
-        p_marginal = 1 / (1 + gamma)
-        return p_marginal, p_joint
+        if loss_strategy is None:
+            kwargs["loss_strategy"] = CNRELoss()
+
+        return super().train(**kwargs)

     def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
         """Overrides the parent class method to check the type of loss_args."""
diff --git a/sbi/inference/trainers/nre/nre_loss.py b/sbi/inference/trainers/nre/nre_loss.py
new file mode 100644
index 000000000..6a9f47c1c
--- /dev/null
+++ b/sbi/inference/trainers/nre/nre_loss.py
@@ -0,0 +1,221 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see
+
+from typing import Protocol, Tuple
+
+import torch
+from torch import Tensor, eye, nn, ones
+
+from sbi.neural_nets.ratio_estimators import RatioEstimator
+from sbi.utils.torchutils import assert_all_finite, repeat_rows
+
+
+class NRELossStrategy(Protocol):
+    """Protocol defining the interface for all Neural Ratio Estimation loss strategies.
+
+    A strategy must implement a forward pass (__call__) mapping parameters (theta),
+    observations (x), and algorithm-specific keyword arguments (like num_atoms) to a
+    scalar loss tensor.
+    """
+    def __call__(
+        self,
+        neural_net: RatioEstimator,
+        device: str,
+        theta: Tensor,
+        x: Tensor,
+        **kwargs
+    ) -> Tensor: ...
+ + +def _classifier_logits( + neural_net: RatioEstimator, theta: Tensor, x: Tensor, num_atoms: int +) -> Tensor: + """Return logits obtained through classifier forward pass. + + The logits are obtained from atomic sets of (theta,x) pairs. + """ + batch_size = theta.shape[0] + repeated_x = repeat_rows(x, num_atoms) + + # Choose `1` or `num_atoms - 1` thetas from the rest of the batch for each x. + probs = ( + ones(batch_size, batch_size, device=theta.device) + * (1 - eye(batch_size, device=theta.device)) + / (batch_size - 1) + ) + + choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False) + + contrasting_theta = theta[choices] + + atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape( + batch_size * num_atoms, -1 + ) + + return neural_net(atomic_theta, repeated_x) + + +class AALRLoss: + """Neural Ratio Estimation (NRE-A / AALR) loss strategy. + + Returns the binary cross-entropy loss for the trained classifier. + """ + def __call__(self, neural_net: RatioEstimator, device: str, theta: Tensor, x: Tensor, num_atoms: int, **kwargs) -> Tensor: + assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." + batch_size = theta.shape[0] + + logits = _classifier_logits(neural_net, theta, x, num_atoms) + likelihood = torch.sigmoid(logits).squeeze() + + # Alternating pairs where there is one sampled from the joint and one + # sampled from the marginals. The first element is sampled from the + # joint p(theta, x) and is labelled 1. The second element is sampled + # from the marginals p(theta)p(x) and is labelled 0. And so on. + labels = ones(2 * batch_size, device=device) # two atoms + labels[1::2] = 0.0 + + # Binary cross entropy to learn the likelihood (AALR-specific) + loss = nn.BCELoss()(likelihood, labels) + assert_all_finite(loss, "NRE-A loss") + return loss + + +class SRELoss: + """Neural Ratio Estimation (NRE-B / SRE) loss strategy. 
+ + Returns cross-entropy (via softmax activation) loss for 1-out-of-`num_atoms` + classification. + """ + def __call__(self, neural_net: RatioEstimator, device: str, theta: Tensor, x: Tensor, num_atoms: int, **kwargs) -> Tensor: + assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." + batch_size = theta.shape[0] + logits = _classifier_logits(neural_net, theta, x, num_atoms) + + # For 1-out-of-`num_atoms` classification each datapoint consists + # of `num_atoms` points, with one of them being the correct one. + # We have a batch of `batch_size` such datapoints. + logits = logits.reshape(batch_size, num_atoms) + + # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the + # "correct" one for the 1-out-of-N classification. + log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1) + + loss = -torch.mean(log_prob) + assert_all_finite(loss, "NRE-B loss") + return loss + + +class CNRELoss: + """Neural Ratio Estimation (NRE-C / CNRE) loss strategy. + + Returns cross-entropy loss (via 'multi-class sigmoid' activation) for + 1-out-of-`K + 1` classification. + """ + @staticmethod + def _get_prior_probs_marginal_and_joint(gamma: float) -> Tuple[float, float]: + r"""Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$, + `p_joint := `$p_K \cdot K$. + """ + p_joint = gamma / (1 + gamma) + p_marginal = 1 / (1 + gamma) + return p_marginal, p_joint + + def __call__( + self, + neural_net: RatioEstimator, + device: str, + theta: Tensor, + x: Tensor, + num_atoms: int, + gamma: float, + **kwargs + ) -> Tensor: + # Reminder: K = num_classes + num_classes = num_atoms - 1 + assert num_classes >= 1, f"num_classes = {num_classes} must be greater than 1." + + assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." 
+ batch_size = theta.shape[0] + + logits_marginal = _classifier_logits( + neural_net, theta, x, num_classes + 1 + ).reshape(batch_size, num_classes + 1) + + logits_joint = _classifier_logits( + neural_net, theta, x, num_classes + ).reshape(batch_size, num_classes) + + dtype = logits_marginal.dtype + + # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence + # we remove the jointly drawn sample from the logits_marginal + logits_marginal = logits_marginal[:, 1:] + + # To use logsumexp, we extend the denominator logits with loggamma + loggamma = torch.tensor(gamma, dtype=dtype, device=device).log() + logK = torch.tensor(num_classes, dtype=dtype, device=device).log() + denominator_marginal = torch.concat( + [loggamma + logits_marginal, logK.expand((batch_size, 1))], + dim=-1, + ) + denominator_joint = torch.concat( + [loggamma + logits_joint, logK.expand((batch_size, 1))], + dim=-1, + ) + + # Compute the contributions to the loss from each term in the classification. + log_prob_marginal = logK - torch.logsumexp(denominator_marginal, dim=-1) + log_prob_joint = ( + loggamma + logits_joint[:, 0] - torch.logsumexp(denominator_joint, dim=-1) + ) + + # relative weights. p_marginal := p_0, and p_joint := p_K * K from the notation. + p_marginal, p_joint = CNRELoss._get_prior_probs_marginal_and_joint(gamma) + + loss = -torch.mean(p_marginal * log_prob_marginal + p_joint * log_prob_joint) + assert_all_finite(loss, "NRE-C loss") + return loss + + +class BNRELoss: + """Balanced Neural Ratio Estimation (BNRE) loss strategy. + + A variation of NRE-A that adds a balancing regularizer to the + binary cross-entropy loss. + """ + def __call__( + self, + neural_net: RatioEstimator, + device: str, + theta: Tensor, + x: Tensor, + num_atoms: int, + regularization_strength: float, + **kwargs + ) -> Tensor: + assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." 
+ batch_size = theta.shape[0] + + logits = _classifier_logits(neural_net, theta, x, num_atoms) + likelihood = torch.sigmoid(logits).squeeze() + + # Alternating pairs where there is one sampled from the joint and one + # sampled from the marginals. The first element is sampled from the + # joint p(theta, x) and is labelled 1. The second element is sampled + # from the marginals p(theta)p(x) and is labelled 0. And so on. + labels = ones(2 * batch_size, device=device) # two atoms + labels[1::2] = 0.0 + + # Binary cross entropy to learn the likelihood (AALR-specific) + bce = nn.BCELoss()(likelihood, labels) + + # Balancing regularizer + regularizer = ( + (torch.sigmoid(logits[0::2]) + torch.sigmoid(logits[1::2]) - 1) + .mean() + .square() + ) + + loss = bce + regularization_strength * regularizer + assert_all_finite(loss, "BNRE loss") + return loss diff --git a/tests/inference/npe_loss_test.py b/tests/inference/npe_loss_test.py new file mode 100644 index 000000000..7e22c9532 --- /dev/null +++ b/tests/inference/npe_loss_test.py @@ -0,0 +1,131 @@ +# This file is part of sbi, a toolkit for simulation-based inference. 
# sbi is licensed under the Apache License Version 2.0,
# see <https://www.apache.org/licenses/>

import pytest
import torch
from torch.distributions import MultivariateNormal

from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.trainers.npe.npe_loss import (
    AtomicLoss,
    ImportanceWeightedLoss,
    NonAtomicGaussianLoss,
)
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.neural_nets.estimators.mixture_density_estimator import (
    MixtureDensityEstimator,
)
from sbi.neural_nets.estimators.mog import MoG


class DummyEstimator(ConditionalDensityEstimator):
    """Minimal density-estimator stub with deterministic zero outputs."""

    def __init__(self, is_mixture=False):
        super().__init__(
            net=torch.nn.Identity(),
            input_shape=torch.Size((2,)),
            condition_shape=torch.Size((2,)),
        )
        self.is_mixture = is_mixture
        self.condition_shape = torch.Size((2,))
        self.input_shape = torch.Size((2,))

    def log_prob(self, theta: torch.Tensor, condition: torch.Tensor, **kwargs):
        # Dummy implementation returning zeros matching the batch shape.
        batch_size = max(theta.shape[0], condition.shape[0])
        return torch.zeros(batch_size, device=theta.device)

    def sample(self, sample_shape: torch.Size, condition: torch.Tensor, **kwargs):
        return torch.zeros(sample_shape + self.input_shape, device=condition.device)

    def loss(self, theta, condition, **kwargs):
        return torch.zeros(theta.shape[0])

    def get_uncorrected_mog(self, condition: torch.Tensor) -> MoG:
        assert self.is_mixture
        # Returns a simple one-component MoG.
        batch_size = condition.shape[0]
        dim = self.input_shape[0]
        logits = torch.zeros(batch_size, 1)
        means = torch.zeros(batch_size, 1, dim)
        precisions = (
            torch.eye(dim).view(1, 1, dim, dim).expand(batch_size, 1, dim, dim)
        )
        precf = torch.eye(dim).view(1, 1, dim, dim).expand(batch_size, 1, dim, dim)
        return MoG(logits, means, precisions, prec_factors=precf)

    @property
    def has_input_transform(self):
        return False


# Proxy dummy specifically simulating MixtureDensityEstimator to pass isinstance.
class DummyMixtureEstimator(DummyEstimator, MixtureDensityEstimator):  # type: ignore
    def __init__(self):
        DummyEstimator.__init__(self, is_mixture=True)


class DummyPosterior(DirectPosterior):
    """Posterior stub used only for MoG extraction in the loss strategies."""

    def __init__(self, estimator, prior):
        # We bypass DirectPosterior's strict checks for testing MoG
        # extraction; only the attributes the loss strategies read are set.
        self.posterior_estimator = estimator
        self.prior = prior
        self.default_x = torch.zeros(1, 2)


@pytest.fixture
def theta():
    return torch.randn(5, 2)


@pytest.fixture
def x():
    return torch.randn(5, 2)


@pytest.fixture
def masks():
    return torch.ones(5)


@pytest.fixture
def prior():
    return MultivariateNormal(torch.zeros(2), torch.eye(2))


def test_atomic_loss_initialization_and_call(theta, x, masks, prior):
    neural_net = DummyEstimator()
    strategy = AtomicLoss(neural_net=neural_net, prior=prior, num_atoms=2)

    assert strategy.uses_only_latest_round is False

    loss = strategy(theta, x, masks, proposal=None)
    assert loss.shape == (5,)


def test_non_atomic_loss_initialization_and_call(theta, x, masks, prior):
    neural_net = DummyMixtureEstimator()
    proposal = DummyPosterior(neural_net, prior)

    strategy = NonAtomicGaussianLoss(
        neural_net=neural_net,
        maybe_z_scored_prior=prior,
    )

    assert strategy.uses_only_latest_round is True

    loss = strategy(theta, x, masks, proposal=proposal)
    assert loss.shape == (5,)


def test_importance_weighted_loss_initialization_and_call(theta, x, masks, prior):
    neural_net = DummyEstimator()
    proposal = DummyPosterior(neural_net, prior)

    theta_roundwise = [torch.randn(10, 2), torch.randn(10, 2)]
    proposal_roundwise = [prior, proposal]

    strategy = ImportanceWeightedLoss(
        neural_net=neural_net,
        prior=prior,
        round_idx=1,
        theta_roundwise=theta_roundwise,
        proposal_roundwise=proposal_roundwise,
    )

    assert strategy.uses_only_latest_round is False

    loss = strategy(theta, x, masks, proposal=proposal)
    assert loss.shape == (5,)