Skip to content

Commit 839762d

Browse files
jsvetterdgedonmanuelgloecklermanuelgloeckler
authored
Add NPE-PFN (#1778)
* wip: update pyproject toml and copy over npe_pfn implementation * first working pure density estimator, with some rough edges still * wip: adding builder, some decisions necessary soon * first kinda working version, many rough edges * working, with inheritance from neural inference * revert some unnecessary changes * some device handling * wip, don't allow standardizing x for now * no z scoring in default * very strict handling of standardization * first working filtering logic * some renaming * cleaner via build_posterior * completely get rid of train * add TODO * simplify stuff, add max context * add flexible filtering * update docstrings * implement sample_and_log_prob * more docstrings * fix filter_size validation lower bound * fix all reported precommit issues * run mini bm and add imports * deal with TabPFN license * use embedded dataset for filtering * small fix in posterior parameters * address comments from review * add tests for NPE-PFN * resolve versioning of tabpfn regressor * update fail safe without tabpfn dependency * fix missing header and ruff/pyright compliance * resolve review comments * one small CPU test * remove unnecessary test * remove dummy fit to trigger license agreement early * fix for repeated warning in build_posterior() * explicit device handling for tabpfnregressor --------- Co-authored-by: dgedon <daniel.gedon@gmx.de> Co-authored-by: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com> Co-authored-by: manuelgloeckler <manug@Manuels-MacBook-Pro.local>
1 parent cd69050 commit 839762d

20 files changed

Lines changed: 1491 additions & 19 deletions

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ dependencies = [
4444
]
4545

4646
[project.optional-dependencies]
47+
tabpfn = ["tabpfn"]
48+
notebook = ["notebook"]
4749
pymc = ["pymc>=5.0.0"]
4850
pyro = ["pyro-ppl>=1.3.1"]
49-
all = ["sbi[pymc]", "sbi[pyro]"]
50-
notebook = ["notebook"]
51+
all = ["sbi[pymc]", "sbi[pyro]", "sbi[tabpfn]"]
5152
doc = [
5253
"sbi[notebook]",
5354
# Documentation

sbi/inference/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010
from sbi.inference.trainers.marginal import MarginalTrainer
1111
from sbi.inference.trainers.nle import MNLE, NLE_A
12-
from sbi.inference.trainers.npe import MNPE, NPE_A, NPE_B, NPE_C # noqa: F401
12+
from sbi.inference.trainers.npe import MNPE, NPE_A, NPE_B, NPE_C, NPE_PFN # noqa: F401
1313
from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C # noqa: F401
1414
from sbi.inference.trainers.vfpe import FMPE, NPSE
1515

@@ -20,7 +20,7 @@
2020
SNPE_A = NPE_A
2121
SNPE_B = NPE_B
2222
SNPE = APT = SNPE_C = NPE = NPE_C
23-
_npe_family = ["NPE_A", "NPE_B", "NPE_C"]
23+
_npe_family = ["NPE_A", "NPE_B", "NPE_C", "NPE_PFN"]
2424

2525

2626
SRE = SNRE = SNRE_B = NRE = NRE_B

sbi/inference/posteriors/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
from sbi.inference.posteriors.direct_posterior import DirectPosterior
55
from sbi.inference.posteriors.ensemble_posterior import EnsemblePosterior
6+
from sbi.inference.posteriors.filtered_direct_posterior import FilteredDirectPosterior
67
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
78
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
89
from sbi.inference.posteriors.npe_a_posterior import NPE_A_Posterior
910
from sbi.inference.posteriors.posterior_parameters import (
1011
DirectPosteriorParameters,
12+
FilteredDirectPosteriorParameters,
1113
ImportanceSamplingPosteriorParameters,
1214
MCMCPosteriorParameters,
1315
RejectionPosteriorParameters,
@@ -28,9 +30,11 @@
2830
"VectorFieldPosterior",
2931
"VIPosterior",
3032
"DirectPosteriorParameters",
33+
"FilteredDirectPosteriorParameters",
3134
"ImportanceSamplingPosteriorParameters",
3235
"MCMCPosteriorParameters",
3336
"RejectionPosteriorParameters",
3437
"VectorFieldPosteriorParameters",
3538
"VIPosteriorParameters",
39+
"FilteredDirectPosterior",
3640
]
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
2+
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
3+
4+
import warnings
5+
from typing import Callable, Literal, Optional, Union
6+
7+
import torch
8+
from torch import Tensor
9+
from torch.distributions import Distribution
10+
11+
from sbi.inference.posteriors.direct_posterior import DirectPosterior
12+
from sbi.neural_nets.estimators.tabpfn_flow import TabPFNFlow
13+
from sbi.sbi_types import Shape
14+
15+
# Names of the built-in context filtering strategies.
FilterMode = Literal["knn", "first"]
16+
# Custom filter callable. `FilteredDirectPosterior` invokes it with
# (embedded query condition, full context inputs, embedded full context
# conditions, k) and expects a tensor of selected context indices back.
FilterFn = Callable[[Tensor, Tensor, Tensor, int], Tensor]
17+
# Either a built-in strategy name or a custom filter callable.
FilterType = Union[FilterMode, FilterFn]
18+
19+
20+
class FilteredDirectPosterior(DirectPosterior):
    r"""Direct posterior with context filtering for TabPFN estimators.

    For every queried condition `x`, this posterior selects a subset of context
    simulations and updates the underlying `TabPFNFlow` context before delegating to
    `DirectPosterior` sampling / log-probability logic.

    Because the context is re-selected per observation, the batched entry points
    (`sample_batched`, `log_prob_batched`) and `map` raise `NotImplementedError`.
    """

    def __init__(
        self,
        estimator: TabPFNFlow,
        prior: Distribution,
        full_context_input: Tensor,
        full_context_condition: Tensor,
        max_sampling_batch_size: int = 10_000,
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
        enable_transform: bool = True,
        filter_type: FilterType = "knn",
        filter_size: int = 2048,
    ):
        r"""Initialize a direct posterior with observation-dependent context filtering.

        Args:
            estimator: TabPFN-based posterior estimator used for evaluation.
            prior: Prior distribution over parameters.
            full_context_input: Full set of context inputs (typically `theta`).
            full_context_condition: Full set of context conditions (typically `x`).
            max_sampling_batch_size: Maximum number of samples drawn per internal batch.
            device: Device on which posterior computations are performed.
            x_shape: Optional event shape for observations.
            enable_transform: Whether to use unconstrained-space transforms for MAP.
            filter_type: Context filtering strategy. Either `"knn"`, `"first"`,
                or a callable returning selected indices.
            filter_size: Maximum number of context points retained per observation.

        Raises:
            ValueError: If `filter_size` is not greater than 1.
        """
        if filter_size <= 1:
            raise ValueError(f"filter_size must be greater than 1, got {filter_size}.")

        super().__init__(
            posterior_estimator=estimator,
            prior=prior,
            max_sampling_batch_size=max_sampling_batch_size,
            device=device,
            x_shape=x_shape,
            enable_transform=enable_transform,
        )

        self.filter_size = int(filter_size)
        self.filtering = filter_type
        self._full_context_input = full_context_input
        self._full_context_condition = full_context_condition
        # Embed the full context once up front; filters compare the embedded query
        # against this tensor on every call.
        self._full_context_condition_embedded = estimator.embed(full_context_condition)

    def _validate_filter_indices(self, indices: Tensor, num_context: int) -> Tensor:
        """Validate and normalize context indices returned by a filter.

        Args:
            indices: Indices produced by a filtering strategy. Anything accepted by
                `torch.as_tensor` (tensor, list, array) is supported.
            num_context: Number of available context entries; valid indices lie in
                `[0, num_context - 1]`.

        Returns:
            De-duplicated indices as a `torch.long` tensor on the device of the
            stored context inputs.

        Raises:
            ValueError: If fewer than two indices are returned, if any index lies
                outside `[0, num_context - 1]`, or if fewer than two distinct
                indices remain after removing duplicates.
        """
        # Custom filters may hand back lists or arrays; coerce before inspecting.
        indices = torch.as_tensor(
            indices, dtype=torch.long, device=self._full_context_input.device
        )

        if indices.numel() < 2:
            raise ValueError("Filtering function must return at least two indices.")

        # Reject out-of-range values here: negative indices would otherwise wrap
        # around silently and too-large ones would surface as an opaque IndexError.
        lowest, highest = int(indices.min()), int(indices.max())
        if lowest < 0 or highest >= num_context:
            raise ValueError(
                "Filtering function returned indices outside the valid range "
                f"[0, {num_context - 1}]: min={lowest}, max={highest}."
            )

        unique_indices = torch.unique(indices, sorted=False)
        # The size requirement applies to the context that is actually set, i.e.
        # to the indices that survive de-duplication.
        if unique_indices.numel() < 2:
            raise ValueError(
                "Filtering function must return at least two distinct indices."
            )

        if unique_indices.numel() < indices.numel():
            warnings.warn(
                "Filtering function returned duplicate indices. Duplicates were "
                "removed before setting context.",
                stacklevel=2,
            )

        return unique_indices

    def _select_context_indices(self, condition_embedded: Tensor) -> Tensor:
        """Select context indices according to the configured filtering strategy.

        Args:
            condition_embedded: Embedded query condition, as returned by the
                estimator's `embed`.

        Returns:
            Indices into the full context. When `filter_size` covers the whole
            context, all indices are returned without invoking any filter.

        Raises:
            RuntimeError: If `self.filtering` is a string other than `"knn"` or
                `"first"`.
            ValueError: If the selected indices fail validation.
        """
        num_context = self._full_context_condition_embedded.shape[0]
        k = min(self.filter_size, num_context)

        # Nothing to filter when the budget already covers every context entry.
        if k >= num_context:
            return torch.arange(num_context, device=self._full_context_input.device)

        if isinstance(self.filtering, str):
            if self.filtering == "knn":
                indices = _knn_filter_indices(
                    condition_embedded, self._full_context_condition_embedded, k
                )
            elif self.filtering == "first":
                indices = _first_filter_indices(k, self._full_context_input.device)
            else:
                raise RuntimeError(f"Unsupported filtering mode: {self.filtering}")
        else:
            # User-supplied callable following the `FilterFn` signature.
            indices = self.filtering(
                condition_embedded,
                self._full_context_input,
                self._full_context_condition_embedded,
                k,
            )

        return self._validate_filter_indices(indices, num_context)

    def _set_context_for_x_o(self, x_o: Tensor) -> None:
        """Filter and set estimator context for a single queried observation.

        Args:
            x_o: Observation for which the context should be selected.
        """
        condition_embedded = self.posterior_estimator.embed(x_o)
        unique_indices = self._select_context_indices(condition_embedded)

        # The estimator receives the raw (non-embedded) conditions; the embedding
        # is only used to pick which context rows to keep.
        self.posterior_estimator.set_context(
            self._full_context_input[unique_indices],
            self._full_context_condition[unique_indices],
        )

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Sample from the posterior after setting context for the queried `x`.

        Args:
            sample_shape: Shape of the returned sample batch.
            x: Observation to condition on. Uses the default observation if `None`.
            max_sampling_batch_size: Maximum internal sampling batch size.
            show_progress_bars: Whether to display progress bars.
            reject_outside_prior: Whether to reject samples outside prior support.
            max_sampling_time: Optional timeout in seconds.
            return_partial_on_timeout: Whether to return collected samples on timeout.

        Returns:
            Samples from the filtered direct posterior.
        """
        # Resolve the default observation first so the context matches the `x`
        # the parent class will condition on.
        x_for_context = self._x_else_default_x(x)
        self._set_context_for_x_o(x_for_context)
        return super().sample(
            sample_shape=sample_shape,
            x=x,
            max_sampling_batch_size=max_sampling_batch_size,
            show_progress_bars=show_progress_bars,
            reject_outside_prior=reject_outside_prior,
            max_sampling_time=max_sampling_time,
            return_partial_on_timeout=return_partial_on_timeout,
        )

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        """Batched sampling is not supported for observation-dependent filtering.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Filtering makes the context observation dependent. "
            "Batched inference requires sharing context, "
            "which is currently not supported."
        )

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        r"""Evaluate posterior log-probability after setting context for `x`.

        Args:
            theta: Parameters at which to evaluate log-probability.
            x: Observation to condition on. Uses the default observation if `None`.
            norm_posterior: Whether to include leakage correction normalization.
            track_gradients: Whether to evaluate with gradient tracking.
            leakage_correction_params: Optional parameters for leakage correction.

        Returns:
            Posterior log-probabilities for ``theta`` conditioned on ``x``.
        """
        x_for_context = self._x_else_default_x(x)
        self._set_context_for_x_o(x_for_context)
        return super().log_prob(
            theta=theta,
            x=x,
            norm_posterior=norm_posterior,
            track_gradients=track_gradients,
            leakage_correction_params=leakage_correction_params,
        )

    def log_prob_batched(
        self,
        theta: Tensor,
        x: Tensor,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        """Batched log-probability is unsupported with per-observation filtering.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Filtering makes the context observation dependent. "
            "Batched inference requires sharing context, "
            "which is currently not supported."
        )

    def map(
        self,
        x=None,
        num_iter=1000,
        num_to_optimize=100,
        learning_rate=0.01,
        init_method="posterior",
        num_init_samples=1000,
        save_best_every=10,
        show_progress_bars=False,
        force_update=False,
    ):
        """MAP is not supported because gradient-based optimization is unavailable.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Computing the MAP requires gradients, which are currently not supported "
            "for NPE-PFN."
        )
244+
245+
246+
def _knn_filter_indices(
    condition_embedded: Tensor,
    full_context_condition: Tensor,
    filter_size: int,
) -> Tensor:
    """Pick the nearest context rows for each query row and flatten the result.

    Args:
        condition_embedded: Query conditions, one per row.
        full_context_condition: Context conditions to search, one per row.
        filter_size: Number of neighbors to keep per query row.

    Returns:
        One-dimensional tensor holding, for every query row, the indices of the
        `filter_size` context rows with the smallest Euclidean (p=2) distance.
    """
    pairwise_l2 = torch.cdist(condition_embedded, full_context_condition, p=2)
    # `largest=False` makes topk return the smallest distances.
    _, nearest = pairwise_l2.topk(filter_size, dim=1, largest=False)
    return nearest.flatten()
255+
256+
257+
def _first_filter_indices(filter_size: int, device: torch.device) -> Tensor:
    """Build the index range ``0 .. filter_size - 1`` on the requested device.

    Args:
        filter_size: Number of leading context entries to select.
        device: Device on which the index tensor is created.

    Returns:
        Tensor of consecutive indices starting at zero.
    """
    leading = torch.arange(0, filter_size, device=device)
    return leading

sbi/inference/posteriors/posterior_parameters.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,44 @@ def validate(self):
131131
raise ValueError("max_sampling_batch_size must be greater than 0.")
132132

133133

134+
@dataclass(frozen=True)
class FilteredDirectPosteriorParameters(PosteriorParameters):
    """Configuration used to construct a `FilteredDirectPosterior`.

    Fields:
        max_sampling_batch_size: Batch size of samples drawn from the proposal
            at every iteration.
        enable_transform: Whether parameters are transformed to unconstrained
            space during MAP optimization. When False, an identity transform
            will be returned for `theta_transform`.
        filter_size: Number of context simulations kept after filtering.
        filter_type: Filtering strategy; one of `"knn"`, `"first"`, or a
            callable returning context indices.
    """

    max_sampling_batch_size: int = 10_000
    enable_transform: bool = True
    filter_size: int = 2048
    filter_type: Union[Literal["knn", "first"], Callable] = "knn"

    def validate(self):
        """Check each field and raise `ValueError` for the first invalid one."""
        if not is_positive_int(self.max_sampling_batch_size):
            raise ValueError("max_sampling_batch_size must be greater than 0.")

        # "filter_size > 1" is expressed as "filter_size - 1 is a positive int".
        if not is_positive_int(self.filter_size - 1):
            raise ValueError("filter_size must be greater than 1.")

        names_builtin_mode = isinstance(self.filter_type, str) and (
            self.filter_type in {"knn", "first"}
        )
        if not names_builtin_mode and not callable(self.filter_type):
            raise ValueError(
                "filter_type must be one of ['knn', 'first'] or a callable."
            )
170+
171+
134172
@dataclass(frozen=True)
135173
class ImportanceSamplingPosteriorParameters(PosteriorParameters):
136174
"""

sbi/inference/potentials/posterior_based_potential.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
mcmc_transform,
2121
within_support,
2222
)
23-
from sbi.utils.torchutils import ensure_theta_batched
23+
from sbi.utils.torchutils import ensure_theta_batched, infer_module_device
2424

2525

2626
def posterior_estimator_based_potential(
@@ -49,7 +49,7 @@ def posterior_estimator_based_potential(
4949
to unconstrained space.
5050
"""
5151

52-
device = str(next(posterior_estimator.parameters()).device)
52+
device = infer_module_device(posterior_estimator, fallback="cpu")
5353

5454
potential_fn = PosteriorBasedPotential(
5555
posterior_estimator, prior, x_o, device=device

0 commit comments

Comments
 (0)