|
| 1 | +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed |
| 2 | +# under the Apache License Version 2.0, see <https://www.apache.org/licenses/> |
| 3 | + |
| 4 | +import warnings |
| 5 | +from typing import Callable, Literal, Optional, Union |
| 6 | + |
| 7 | +import torch |
| 8 | +from torch import Tensor |
| 9 | +from torch.distributions import Distribution |
| 10 | + |
| 11 | +from sbi.inference.posteriors.direct_posterior import DirectPosterior |
| 12 | +from sbi.neural_nets.estimators.tabpfn_flow import TabPFNFlow |
| 13 | +from sbi.sbi_types import Shape |
| 14 | + |
# Names of the built-in context filtering strategies.
FilterMode = Literal["knn", "first"]

# Signature of a user-supplied filter. It is called positionally as
# `fn(embedded_query, full_context_input, embedded_full_context_condition, k)`
# and must return a tensor of indices into the full context.
FilterFn = Callable[
    [Tensor, Tensor, Tensor, int],
    Tensor,
]

# Anything accepted as `filter_type`: a built-in mode name or a custom callable.
FilterType = Union[FilterMode, FilterFn]
| 18 | + |
| 19 | + |
class FilteredDirectPosterior(DirectPosterior):
    r"""Direct posterior with context filtering for TabPFN estimators.

    For every queried condition `x`, this posterior selects a subset of context
    simulations and updates the underlying `TabPFNFlow` context before delegating to
    `DirectPosterior` sampling / log-probability logic.

    Note:
        `sample` and `log_prob` call `set_context` on the estimator on every
        invocation, so the estimator is left holding the context selected for the
        most recently queried observation.
    """

    def __init__(
        self,
        estimator: TabPFNFlow,
        prior: Distribution,
        full_context_input: Tensor,
        full_context_condition: Tensor,
        max_sampling_batch_size: int = 10_000,
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
        enable_transform: bool = True,
        filter_type: FilterType = "knn",
        filter_size: int = 2048,
    ):
        r"""Initialize a direct posterior with observation-dependent context filtering.

        Args:
            estimator: TabPFN-based posterior estimator used for evaluation.
            prior: Prior distribution over parameters.
            full_context_input: Full set of context inputs (typically `theta`).
            full_context_condition: Full set of context conditions (typically `x`).
            max_sampling_batch_size: Maximum number of samples drawn per internal batch.
            device: Device on which posterior computations are performed.
            x_shape: Optional event shape for observations.
            enable_transform: Whether to use unconstrained-space transforms for MAP.
            filter_type: Context filtering strategy. Either `"knn"`, `"first"`,
                or a callable returning selected indices. A callable is invoked as
                `fn(embedded_query, full_context_input,
                embedded_full_context_condition, k)`.
            filter_size: Maximum number of context points retained per observation.

        Raises:
            ValueError: If `filter_size` is not greater than 1.
        """
        if filter_size <= 1:
            raise ValueError(f"filter_size must be greater than 1, got {filter_size}.")

        super().__init__(
            posterior_estimator=estimator,
            prior=prior,
            max_sampling_batch_size=max_sampling_batch_size,
            device=device,
            x_shape=x_shape,
            enable_transform=enable_transform,
        )

        self.filter_size = int(filter_size)
        self.filtering = filter_type
        self._full_context_input = full_context_input
        self._full_context_condition = full_context_condition
        # Embed the full context once up front; each query then only has to embed
        # its own observation before filtering.
        self._full_context_condition_embedded = estimator.embed(full_context_condition)

    def _validate_filter_indices(self, indices: Tensor, num_context: int) -> Tensor:
        """Validate and normalize context indices returned by a filter.

        Args:
            indices: Indices into the full context as returned by a filter. The
                tensor is treated as a flat collection regardless of its shape.
            num_context: Number of entries in the full context; indices must lie
                in `[-num_context, num_context)`.

        Returns:
            Tensor of unique, non-negative `torch.long` indices on the device of
            the full context input. The order is not guaranteed.

        Raises:
            ValueError: If fewer than two indices remain, before or after
                removing duplicates.
            IndexError: If any index lies outside `[-num_context, num_context)`.
        """
        if indices.numel() < 2:
            raise ValueError("Filtering function must return at least two indices.")

        indices = indices.to(device=self._full_context_input.device, dtype=torch.long)

        # Fail early with a clear message instead of relying on the later tensor
        # indexing to detect out-of-range values.
        out_of_range = (indices < -num_context) | (indices >= num_context)
        if bool(out_of_range.any()):
            raise IndexError(
                "Filtering function returned indices outside the valid range "
                f"[{-num_context}, {num_context - 1}]."
            )

        # Map valid negative indices onto their non-negative equivalent so that,
        # e.g., `-1` and `num_context - 1` are recognized as the same entry.
        indices = torch.where(indices < 0, indices + num_context, indices)

        unique_indices = torch.unique(indices, sorted=False)

        # Re-check after deduplication: `[5, 5]` has two elements but would
        # otherwise produce a single-point context.
        if unique_indices.numel() < 2:
            raise ValueError(
                "Filtering function must return at least two distinct indices."
            )

        if unique_indices.numel() < indices.numel():
            warnings.warn(
                "Filtering function returned duplicate indices. Duplicates were "
                "removed before setting context.",
                stacklevel=2,
            )

        return unique_indices

    def _select_context_indices(self, condition_embedded: Tensor) -> Tensor:
        """Select context indices according to the configured filtering strategy.

        Args:
            condition_embedded: Embedded observation(s) to select context for.

        Returns:
            Indices into the full context. If the full context does not exceed
            `filter_size`, all indices are returned without filtering.

        Raises:
            RuntimeError: If `self.filtering` is a string other than `"knn"` or
                `"first"`.
        """
        num_context = self._full_context_condition_embedded.shape[0]
        k = min(self.filter_size, num_context)

        # Nothing to filter: the entire context already fits into the budget.
        if k >= num_context:
            return torch.arange(num_context, device=self._full_context_input.device)

        if isinstance(self.filtering, str):
            if self.filtering == "knn":
                indices = _knn_filter_indices(
                    condition_embedded, self._full_context_condition_embedded, k
                )
            elif self.filtering == "first":
                indices = _first_filter_indices(k, self._full_context_input.device)
            else:
                raise RuntimeError(f"Unsupported filtering mode: {self.filtering}")
        else:
            # User-supplied callable; its output is validated below.
            indices = self.filtering(
                condition_embedded,
                self._full_context_input,
                self._full_context_condition_embedded,
                k,
            )

        return self._validate_filter_indices(indices, num_context)

    def _set_context_for_x_o(self, x_o: Tensor) -> None:
        """Filter and set estimator context for a single queried observation.

        Args:
            x_o: Observation whose embedding drives the context selection.
        """
        condition_embedded = self.posterior_estimator.embed(x_o)
        unique_indices = self._select_context_indices(condition_embedded)

        # The estimator receives the raw (non-embedded) conditions for the
        # selected entries.
        self.posterior_estimator.set_context(
            self._full_context_input[unique_indices],
            self._full_context_condition[unique_indices],
        )

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Sample from the posterior after setting context for the queried `x`.

        Args:
            sample_shape: Shape of the returned sample batch.
            x: Observation to condition on. Uses the default observation if `None`.
            max_sampling_batch_size: Maximum internal sampling batch size.
            show_progress_bars: Whether to display progress bars.
            reject_outside_prior: Whether to reject samples outside prior support.
            max_sampling_time: Optional timeout in seconds.
            return_partial_on_timeout: Whether to return collected samples on timeout.

        Returns:
            Samples from the filtered direct posterior.
        """
        # Resolve the observation first so that the context is filtered for the
        # same `x` the parent class will condition on.
        x_for_context = self._x_else_default_x(x)
        self._set_context_for_x_o(x_for_context)
        return super().sample(
            sample_shape=sample_shape,
            x=x,
            max_sampling_batch_size=max_sampling_batch_size,
            show_progress_bars=show_progress_bars,
            reject_outside_prior=reject_outside_prior,
            max_sampling_time=max_sampling_time,
            return_partial_on_timeout=return_partial_on_timeout,
        )

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        """Batched sampling is not supported for observation-dependent filtering.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Filtering makes the context observation dependent. "
            "Batched inference requires sharing context, "
            "which is currently not supported."
        )

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        r"""Evaluate posterior log-probability after setting context for `x`.

        Args:
            theta: Parameters at which to evaluate log-probability.
            x: Observation to condition on. Uses the default observation if `None`.
            norm_posterior: Whether to include leakage correction normalization.
            track_gradients: Whether to evaluate with gradient tracking.
            leakage_correction_params: Optional parameters for leakage correction.

        Returns:
            Posterior log-probabilities for ``theta`` conditioned on ``x``.
        """
        # Resolve the observation first so that the context is filtered for the
        # same `x` the parent class will condition on.
        x_for_context = self._x_else_default_x(x)
        self._set_context_for_x_o(x_for_context)
        return super().log_prob(
            theta=theta,
            x=x,
            norm_posterior=norm_posterior,
            track_gradients=track_gradients,
            leakage_correction_params=leakage_correction_params,
        )

    def log_prob_batched(
        self,
        theta: Tensor,
        x: Tensor,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        """Batched log-probability is unsupported with per-observation filtering.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Filtering makes the context observation dependent. "
            "Batched inference requires sharing context, "
            "which is currently not supported."
        )

    def map(
        self,
        x=None,
        num_iter=1000,
        num_to_optimize=100,
        learning_rate=0.01,
        init_method="posterior",
        num_init_samples=1000,
        save_best_every=10,
        show_progress_bars=False,
        force_update=False,
    ):
        """MAP is not supported because gradient-based optimization is unavailable.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Computing the MAP requires gradients, which are currently not supported "
            "for NPE-PFN."
        )
| 244 | + |
| 245 | + |
def _knn_filter_indices(
    condition_embedded: Tensor,
    full_context_condition: Tensor,
    filter_size: int,
) -> Tensor:
    """Return flattened k-nearest-neighbor context indices.

    Neighbors are ranked by Euclidean distance between each row of
    `condition_embedded` and every row of `full_context_condition`.

    Args:
        condition_embedded: Query points, one per row.
        full_context_condition: Context points to search, one per row.
        filter_size: Number of neighbors requested per query row. Clamped to the
            number of available context rows.

    Returns:
        One-dimensional tensor holding, for each query row in turn, the indices of
        its nearest context rows ordered from closest to farthest.
    """
    # `torch.topk` raises when `k` exceeds the size of the searched dimension, so
    # never request more neighbors than there are context rows.
    num_neighbors = min(filter_size, full_context_condition.shape[0])
    distances = torch.cdist(condition_embedded, full_context_condition, p=2)
    nn_indices = torch.topk(distances, k=num_neighbors, largest=False, dim=1).indices
    return nn_indices.reshape(-1)
| 255 | + |
| 256 | + |
def _first_filter_indices(filter_size: int, device: torch.device) -> Tensor:
    """Build the index tensor `0, 1, ..., filter_size - 1` on `device`.

    Args:
        filter_size: Number of leading context entries to select.
        device: Device on which the index tensor is created.

    Returns:
        One-dimensional integer tensor selecting the first `filter_size` entries.
    """
    leading_indices = torch.arange(0, filter_size, device=device)
    return leading_indices
0 commit comments