Add NPE-PFN #1778
**Codecov Report** ❌ Patch coverage is 31.04%. Your patch check has failed because the patch coverage (31.04%) is below the target coverage (50.00%). You can increase the patch coverage or adjust the target coverage.

```
@@ Coverage Diff @@
##             main    #1778      +/-   ##
==========================================
- Coverage   87.83%   86.29%    -1.54%
==========================================
  Files         140      143       +3
  Lines       12958    13313     +355
==========================================
+ Hits        11382    11489     +107
- Misses       1576     1824     +248
==========================================
```

Flags with carried forward coverage won't be shown.
Looks overall great already! I added comments on a few minor issues.
What is currently still missing is adding some tests, i.e.:
Keeping the context size small should allow these tests to run relatively fast, no? Otherwise we might mark most of them as slow.
I addressed your comments @manuelgloeckler. Some changes:
Thanks. Looks good overall.
We just need to add a few more tests and fix a few problems that mostly came from merging in current main (and its corresponding changes). These problems only surfaced on running mini-sbibm, but there should be tests in place that would catch them.
So I think what is mostly left to do is to add whole-workflow tests like those for other methods, e.g. in linearGaussian_snpe_test.py
manuelgloeckler left a comment
Looks good.
I added one small full-workflow test. There is one issue with the device identification on TabPFN (see comment above), but I don't see a good way to avoid it for now.
```python
except StopIteration:
    try:
        return str(next(module.buffers()).device)
    except StopIteration:
```
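The fallback chain under discussion can be sketched as a small helper. The name `infer_device` and the `default` argument are illustrative, not the PR's exact API; the point is the parameters-then-buffers-then-default cascade:

```python
def infer_device(module, default="cpu"):
    """Return the device of a module's first parameter, falling back to
    its first buffer, and finally to a default for modules that have
    neither (e.g. a parameter-free TabPFN-based flow)."""
    try:
        return str(next(module.parameters()).device)
    except StopIteration:
        try:
            return str(next(module.buffers()).device)
        except StopIteration:
            # Parameter- and buffer-free module: nothing to inspect,
            # so fall back to the default instead of raising.
            return default
```

One way to avoid the spammy warning for modules that are parameter-free by design would be to emit it only when the fallback is reached unexpectedly, e.g. gated on a module attribute.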
There was a problem hiding this comment.
Not entirely sure, but will this be triggered every time a TabPFNFlow is built (as they do not have parameters, no)?
If so, this warning would be a bit spammy, since being parameter-free is by design there.
Not sure what the best solution is here, however.
Thanks @manuelgloeckler. I did already add a full workflow test in tests/posterior_nn_test.py. Let me know whether that should be removed in favor of the new CPU test that you added.
The goal of this PR is to add NPE-PFN to SBI, as discussed in #1682.
The implementation consists mostly of three new components, which I will briefly describe below.
Happy to discuss all of this, as the existing assumptions encoded through base classes like `NeuralInference` or `ConditionalDensityEstimator` sometimes make more and sometimes less sense for NPE-PFN. There are three key files that implement the method:
1. `tabpfn_flow.py` implements the in-context `ConditionalDensityEstimator` based on the autoregressive use of TabPFN. It behaves exactly like other estimators and, given some context dataset, provides sampling and log-prob functionality.
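To make the estimator contract concrete, here is a toy stand-in: instead of TabPFN it just fits a 1-D Gaussian to the context targets, but it exposes the same kind of `log_prob`/`sample` surface the description above refers to. Class and method names are illustrative, not the PR's API:

```python
import math
import random

class ToyInContextEstimator:
    """Toy stand-in for an in-context density estimator: it is built
    from a context dataset (no training loop) and then offers
    log_prob and sample, like other estimators."""

    def __init__(self, context):
        # "Fitting" here is just computing moments of the context.
        n = len(context)
        self.mu = sum(context) / n
        var = sum((c - self.mu) ** 2 for c in context) / n
        self.sigma = math.sqrt(var) if var > 0 else 1e-6

    def log_prob(self, theta):
        # Log density of a univariate Gaussian N(mu, sigma^2).
        z = (theta - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma * math.sqrt(2 * math.pi))

    def sample(self, num_samples, rng=random):
        return [rng.gauss(self.mu, self.sigma) for _ in range(num_samples)]
```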
2. `npe_pfn.py` implements the `NPE_PFN` class, which inherits from `NeuralInference` and implements the basic logic used across the package (`append_simulations`, `train`, `build_posterior`, etc.). Since NPE-PFN is training-free, the `train` method is a stub, and most functionality is handled directly by `build_posterior`. This allows users to call `train` without breaking any previous workflow, but they can also "forget" about it, as a training-free method would suggest. Since the TabPFN-based flow behaves like any other flow, NPE-PFN supports many different types of posteriors out of the box (direct, rejection, importance sampling; more could be added, but inference is too slow for MCMC). However, a crucial feature of NPE-PFN is filtering, where the context dataset is selected based on a given observation. To support this functionality, a new posterior class is required.
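The training-free pattern described above (stub `train`, real work in `build_posterior`) can be sketched as follows. All names are illustrative; this is the shape of the workflow, not the PR's implementation:

```python
class TrainingFreeInference:
    """Sketch of the training-free pattern: train() is a no-op kept
    for API compatibility, build_posterior() does the actual work
    using the stored context set."""

    def __init__(self):
        self._theta, self._x = [], []

    def append_simulations(self, theta, x):
        self._theta.extend(theta)
        self._x.extend(x)
        return self

    def train(self):
        # Training-free: nothing happens here, but existing workflows
        # that call train() keep working unchanged.
        return self

    def build_posterior(self):
        # In the real method this would wrap the in-context estimator
        # around the context set; here we just expose its size.
        return {"context_size": len(self._theta)}
```

Users can thus run the familiar `append_simulations(...).train().build_posterior()` chain, or skip `train()` entirely.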
3. `filtered_direct_posterior.py` implements this posterior (inheriting from `DirectPosterior`), which allows filtering with different filters (usually KNN, but users can also provide a custom callable).

There are many other smaller changes (builders, dataclasses, etc.) and so far no tests.
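The KNN filtering idea, selecting the context pairs whose simulated outputs are closest to the observation, can be sketched in a few lines. The helper name and plain-list interface are illustrative only:

```python
def knn_filter(theta, x, x_o, k):
    """Keep the k context pairs (theta_i, x_i) whose x_i are closest
    to the observation x_o, by squared Euclidean distance."""
    dists = [
        sum((xi - oi) ** 2 for xi, oi in zip(row, x_o))
        for row in x
    ]
    # Indices of the k nearest context points, closest first.
    keep = sorted(range(len(x)), key=dists.__getitem__)[:k]
    return [theta[i] for i in keep], [x[i] for i in keep]
```

A custom callable filter would simply replace the distance-based index selection with any user-defined rule mapping `(theta, x, x_o)` to a subset of the context.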
Also, this PR contains only the core functionality for amortized inference. More advanced features like sequential inference, or even support for finetuning etc. (which we did not even do in the paper), are not added.
It probably makes sense to discuss this approach first, before I add fine-grained tests or possibly more functionality.
Here are results for the mini benchmark:
