sbi/inference/trainers/npe/npe_a.py (21 changes: 0 additions & 21 deletions)

@@ -462,28 +462,7 @@ def build_posterior(
 
         return self._posterior
 
-    def _log_prob_proposal_posterior(
-        self,
-        theta: Tensor,
-        x: Tensor,
-        masks: Tensor,
-        proposal: Optional[Any],
-    ) -> Tensor:
-        """Return the log-probability of the proposal posterior.
-
-        For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in
-        `_loss()` to be found in `snpe_base.py`.
-
-        Args:
-            theta: Batch of parameters θ.
-            x: Batch of data.
-            masks: Mask that is True for prior samples in the batch in order to train
-                them with prior loss.
-            proposal: Proposal distribution.
-
-        Returns: Log-probability of the proposal posterior.
-        """
-        return self._neural_net.log_prob(theta, x)
 
     def _correct_for_proposal(
sbi/inference/trainers/npe/npe_b.py (114 changes: 45 additions & 69 deletions)

@@ -1,22 +1,22 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-from typing import Any, Literal, Optional, Union
+from typing import Callable, Dict, Literal, Optional, Union
 
 import torch
 from torch import Tensor
 from torch.distributions import Distribution
 from torch.utils.tensorboard.writer import SummaryWriter
 
-import sbi.utils as utils
 from sbi.inference.trainers.npe.npe_base import (
     PosteriorEstimatorTrainer,
 )
+from sbi.inference.trainers.npe.npe_loss import (
+    ImportanceWeightedLoss,
+    NPELossStrategy,
+)
 from sbi.neural_nets.estimators.base import (
     ConditionalDensityEstimator,
     ConditionalEstimatorBuilder,
 )
-from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
+from sbi.sbi_types import Tracker
 from sbi.utils.sbiutils import del_entries

@@ -107,72 +107,48 @@ def __init__(
         kwargs = del_entries(locals(), entries=("self", "__class__"))
         super().__init__(**kwargs)
 
-    def _log_prob_proposal_posterior(
+    def train(
         self,
-        theta: Tensor,
-        x: Tensor,
-        masks: Tensor,
-        proposal: Optional[Any],
-    ) -> Tensor:
-        """
-        Return importance-weighted log probability (Lueckmann, Goncalves et al., 2017).
-
-        Args:
-            theta: Batch of parameters θ.
-            x: Batch of data.
-            masks: Mask that is True for prior samples in the batch in order to train
-                them with prior loss.
-            proposal: Proposal distribution.
-
-        Returns:
-            Importance-weighted log-probability of the proposal posterior.
-        """
-
-        # Evaluate prior
-        # we accept prior log prob to be -Inf at theta
-        # meaning that theta is out of the prior range (the weight is thus 0)
-        utils.assert_not_nan_or_plus_inf(
-            self._prior.log_prob(theta), "prior log probs of proposal samples"
-        )
-        prior = torch.exp(self._prior.log_prob(theta))
-
-        # Evaluate proposal
-        # (as theta comes from prior and proposal from previous rounds,
-        # the last proposal is actually a mixture of the prior
-        # and of all the previous proposals with coefficients representing
-        # the proportion of the new theta added at each round)
-        prop = torch.zeros(self._round + 1, device=theta.device)
-        nb_samples = 0  # total number of theta from all the rounds
-
-        for k in range(self._round + 1):
-            nb_samples += self._theta_roundwise[k].size(0)
-            # the number of new theta sampled in the round k
-            prop[k] = self._theta_roundwise[k].size(0)
-
-        prop /= nb_samples
-        log_prop = torch.log(prop).repeat(theta.size(0), 1)
-
-        log_previous_proposals = torch.zeros(
-            (theta.size(0), self._round + 1), device=theta.device
-        )
-        for k, density in enumerate(self._proposal_roundwise):
-            # we accept the k th proposal log prob to be -Inf at theta
-            # meaning that theta is out of the k th proposal range
-            log_previous_proposals[:, k] = density.log_prob(theta)
-            utils.assert_not_nan_or_plus_inf(
-                log_previous_proposals[:, k], "proposal log probs of proposal samples"
-            )
-
-        log_proposal = torch.logsumexp(log_prop + log_previous_proposals, dim=1)
-        proposal = torch.exp(log_proposal)
-
-        # Construct the importance weights and normalize them
-        importance_weights = prior / proposal
-        importance_weights /= importance_weights.sum()
-
-        theta = reshape_to_sample_batch_event(theta, theta.shape[1:])
-        # Reshape the density estimator log probs
-        # from (sample_shape, batch_shape) to (batch_shape)
-        posterior_log_probs = self._neural_net.log_prob(theta, x).squeeze(dim=0)
-
-        return importance_weights * posterior_log_probs
+        training_batch_size: int = 200,
+        learning_rate: float = 5e-4,
+        validation_fraction: float = 0.1,
+        stop_after_epochs: int = 20,
+        max_num_epochs: int = 2**31 - 1,
+        clip_max_norm: Optional[float] = 5.0,
+        calibration_kernel: Optional[Callable] = None,
+        resume_training: bool = False,
+        force_first_round_loss: bool = False,
+        discard_prior_samples: bool = False,
+        retrain_from_scratch: bool = False,
+        show_train_summary: bool = False,
+        dataloader_kwargs: Optional[Dict] = None,
+        loss_strategy: Optional[NPELossStrategy] = None,
+    ) -> ConditionalDensityEstimator:
+        kwargs = del_entries(
+            locals(),
+            entries=("self", "__class__", "loss_strategy"),
+        )
+
+        if len(self._data_round_index) == 0:
+            raise RuntimeError(
+                "No simulations found. You must call .append_simulations() "
+                "before calling .train()."
+            )
+
+        self._round = max(self._data_round_index)
+
+        if loss_strategy is not None:
+            self._loss_strategy = loss_strategy
+        elif self._round > 0:
+            self._loss_strategy = ImportanceWeightedLoss(
+                neural_net=self._neural_net,
+                prior=self._prior,
+                round_idx=self._round,
+                theta_roundwise=self._theta_roundwise,
+                proposal_roundwise=self._proposal_roundwise,
+            )
+        else:
+            self._loss_strategy = None
+
+        return super().train(**kwargs)
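
For orientation, here is a minimal sketch of what the referenced sbi/inference/trainers/npe/npe_loss.py could contain. That new module is not part of this diff, so the interface (a callable strategy with a `uses_only_latest_round` flag) is inferred from the call sites above, and the body of `ImportanceWeightedLoss` simply repackages the importance-weighting logic deleted from `NPE_B._log_prob_proposal_posterior`; treat all names and signatures as assumptions.

# Assumed sketch of sbi/inference/trainers/npe/npe_loss.py; signatures are
# inferred from the call sites in this diff, not taken from the actual module.
from typing import Any, List, Optional, Protocol

import torch
from torch import Tensor


class NPELossStrategy(Protocol):
    """Callable returning the per-sample log prob that the trainer maximizes."""

    # Read by `_get_start_index` and `append_simulations` in npe_base.py.
    uses_only_latest_round: bool

    def __call__(
        self, theta: Tensor, x: Tensor, masks: Tensor, proposal: Optional[Any]
    ) -> Tensor: ...


class ImportanceWeightedLoss:
    """NPE-B importance weighting (Lueckmann, Goncalves et al., 2017)."""

    # NPE-B trains on samples from all rounds, so `_get_start_index` must not
    # restrict the training data to the latest round.
    uses_only_latest_round: bool = False

    def __init__(
        self,
        neural_net,
        prior,
        round_idx: int,
        theta_roundwise: List[Tensor],
        proposal_roundwise: List[Any],
    ) -> None:
        self.neural_net = neural_net
        self.prior = prior
        self.round_idx = round_idx
        self.theta_roundwise = theta_roundwise
        self.proposal_roundwise = proposal_roundwise

    def __call__(
        self, theta: Tensor, x: Tensor, masks: Tensor, proposal: Optional[Any]
    ) -> Tensor:
        # `masks` and `proposal` are unused by this strategy; they are kept so
        # all strategies share the signature expected by `npe_base._loss`.
        # Prior density at theta; a log prob of -inf yields a zero weight.
        prior_probs = torch.exp(self.prior.log_prob(theta))

        # The effective proposal is a mixture of the prior and all previous
        # proposals, with mixture weights given by the fraction of samples
        # that each round contributed.
        thetas = self.theta_roundwise[: self.round_idx + 1]
        counts = torch.tensor(
            [t.size(0) for t in thetas], dtype=theta.dtype, device=theta.device
        )
        log_mix_weights = torch.log(counts / counts.sum())
        log_proposal_probs = torch.stack(
            [d.log_prob(theta) for d in self.proposal_roundwise], dim=1
        )
        mixture_probs = torch.exp(
            torch.logsumexp(log_mix_weights + log_proposal_probs, dim=1)
        )

        # Normalized importance weights times the estimator's log prob. A real
        # implementation would also need the sample/batch/event reshaping the
        # deleted code performed via `reshape_to_sample_batch_event`.
        weights = prior_probs / mixture_probs
        weights = weights / weights.sum()
        return weights * self.neural_net.log_prob(theta, x)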
sbi/inference/trainers/npe/npe_base.py (29 changes: 15 additions & 14 deletions)

@@ -2,7 +2,7 @@
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
 import warnings
-from abc import ABC, abstractmethod
+from abc import ABC
 from dataclasses import asdict
 from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple, Union

@@ -32,6 +32,7 @@
     NeuralInference,
     check_if_proposal_has_default_x,
 )
+from sbi.inference.trainers.npe.npe_loss import NPELossStrategy
 from sbi.neural_nets import posterior_nn
 from sbi.neural_nets.estimators import ConditionalDensityEstimator
 from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder

@@ -111,16 +112,9 @@ def __init__(
         self._build_neural_net = density_estimator
 
         self._proposal_roundwise = []
-        self.use_non_atomic_loss = False
+        self._loss_strategy: Optional[NPELossStrategy] = None
 
-    @abstractmethod
-    def _log_prob_proposal_posterior(
-        self,
-        theta: Tensor,
-        x: Tensor,
-        masks: Tensor,
-        proposal: Optional[Any],
-    ) -> Tensor: ...
-
     def append_simulations(
         self,

@@ -194,10 +188,11 @@ def append_simulations(

         # Check for problematic z-scoring
         warn_if_invalid_for_zscoring(x)
+        is_atomic = self._loss_strategy is None or not getattr(self._loss_strategy, "uses_only_latest_round", True)
         if (
-            type(self).__name__ == "SNPE_C"
+            type(self).__name__ in ("SNPE_C", "NPE_C")
             and current_round > 0
-            and not self.use_non_atomic_loss
+            and is_atomic
         ):
             nle_nre_apt_msg_on_invalid_x(
                 num_nans,
@@ -254,6 +249,7 @@ def train(
         retrain_from_scratch: bool = False,
         show_train_summary: bool = False,
         dataloader_kwargs: Optional[dict] = None,
+        loss_strategy: Optional[NPELossStrategy] = None,
     ) -> ConditionalDensityEstimator:
         r"""Return density estimator that approximates the distribution $p(\theta|x)$.

@@ -550,9 +546,11 @@ def _loss(
             # Use posterior log prob (without proposal correction) for first round.
             loss = self._neural_net.loss(theta, x)
         else:
+            if self._loss_strategy is None:
+                raise ValueError("Loss strategy is None but round > 0.")
             # Currently only works for `DensityEstimator` objects.
             # Must be extended once other Estimators are implemented. See #966.
-            loss = -self._log_prob_proposal_posterior(theta, x, masks, proposal)
+            loss = -self._loss_strategy(theta, x, masks, proposal)
 
         assert_all_finite(loss, "NPE loss")
         return calibration_kernel(x) * loss

@@ -650,7 +648,10 @@ def _get_start_index(self, context: StartIndexContext) -> int:
         # SNPE-A can, by construction of the algorithm, only use samples from the last
         # round. SNPE-A is the only algorithm that has an attribute `_ran_final_round`,
         # so this is how we check for whether or not we are using SNPE-A.
-        if self.use_non_atomic_loss or hasattr(self, "_ran_final_round"):
+        if (
+            self._loss_strategy is not None
+            and self._loss_strategy.uses_only_latest_round
+        ) or hasattr(self, "_ran_final_round"):
             start_idx = self._round
 
         return start_idx
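
The practical payoff of the refactor is that `train(loss_strategy=...)` becomes an extension point: any object with the strategy call signature can replace the built-in proposal correction. A hypothetical example follows; the class and the commented setup are illustrative only and not part of sbi.

# Hypothetical custom strategy passed via the new `loss_strategy` argument.
# `PlainPosteriorLoss` is illustrative only; sbi does not ship this class.


class PlainPosteriorLoss:
    # False tells `_get_start_index` to keep samples from all rounds.
    uses_only_latest_round = False

    def __init__(self, neural_net):
        self.neural_net = neural_net

    def __call__(self, theta, x, masks, proposal):
        # `npe_base._loss` negates this value, so return the log prob itself,
        # i.e., fall back to plain maximum-likelihood training of q(theta|x).
        return self.neural_net.log_prob(theta, x)


# After a first training round the density estimator exists, so a custom
# strategy could be supplied for subsequent rounds, e.g.:
#     trainer.train(loss_strategy=PlainPosteriorLoss(trainer._neural_net))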