
Add optional check_finite flag to process_x #1828

Open
patelshivani2283-lab wants to merge 2 commits into sbi-dev:main from patelshivani2283-lab:fix-nan-padding-check

Conversation

@patelshivani2283-lab
Contributor

This PR adds an optional check_finite argument to process_x.

By default it is set to True, so the existing behavior stays the same.
If set to False, it skips the NaN/Inf check (assert_all_finite).

This can be useful in cases where the data is already validated and the extra check is not needed.
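For reference, a minimal sketch of what the proposed flag could look like. The internals are simplified stand-ins (the real process_x in sbi/utils/user_input_checks.py also validates against x_shape, and assert_all_finite here mimics sbi's helper rather than importing it):

```python
import torch


def assert_all_finite(x: torch.Tensor, msg: str = "Input contains NaNs or Infs.") -> None:
    """Raise if the tensor contains any NaN or Inf entries."""
    if not torch.isfinite(x).all():
        raise AssertionError(msg)


def process_x(x, check_finite: bool = True) -> torch.Tensor:
    """Simplified stand-in for sbi's process_x: cast to a float tensor,
    ensure a batch dimension, and optionally validate finiteness."""
    x = torch.as_tensor(x, dtype=torch.float32)
    if x.ndim == 1:
        x = x.unsqueeze(0)  # add a leading batch dimension
    if check_finite:
        assert_all_finite(x, "Observed data contains NaNs or Infs.")
    return x
```

With this, `process_x([1.0, float("nan")], check_finite=False)` returns the padded tensor unchanged, while the default call raises.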

@patelshivani2283-lab
Contributor Author

This is an initial step toward #1717. I can extend this to propagate the flag through posterior classes if needed.

@codecov

codecov Bot commented Mar 31, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 86.31%. Comparing base (937efc2) to head (dd290ee).
⚠️ Report is 43 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1828      +/-   ##
==========================================
- Coverage   88.54%   86.31%   -2.23%     
==========================================
  Files         137      143       +6     
  Lines       11515    17753    +6238     
==========================================
+ Hits        10196    15324    +5128     
- Misses       1319     2429    +1110     
Flag Coverage Δ
fast 82.77% <100.00%> (?)

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
sbi/utils/user_input_checks.py 76.80% <100.00%> (+0.12%) ⬆️

... and 77 files with indirect coverage changes

@patelshivani2283-lab
Contributor Author

Hi! Just wanted to follow up on this PR.

I added the check_finite flag to process_x to make NaN/Inf validation optional while keeping the default behavior unchanged.

I was also thinking about whether this should be propagated further (e.g., through posterior classes), but I wasn’t sure if that’s the intended direction or if a more minimal change is preferred.

Happy to adjust based on your feedback.

@janfb
Contributor

janfb commented May 8, 2026

Hi @patelshivani2283-lab, thanks for picking this up and sorry for the slow response on my side!

You correctly identified process_x as the relevant function; that's what I would have done as well. But after looking into the call graph a bit more, I think the cleanest fix is actually different from what I sketched in #1717. The issue: process_x is called twice on the same tensor in the non-direct sampling path (once in NeuralPosterior._x_else_default_x, then again in BasePotential.set_x), and process_x's real job is dtype/shape normalization, not finiteness validation.

PR #1701 added assert_all_finite inside process_x, but I think that's actually the wrong place, because we would then have to pass the new kwarg through every call site. We can avoid all that by moving the assertion up to the layer that actually owns the user-facing data contract: the posterior class.

I suggest the following:

  1. In sbi/utils/user_input_checks.py — drop the check_finite parameter you added, and remove the assert_all_finite(x, ...) call from process_x entirely.

  2. In sbi/inference/posteriors/base_posterior.py — have NeuralPosterior.__init__() accept a new check_finite_x: bool = True parameter, stored as self._check_finite_x. Then in set_default_x() (line 183) and _x_else_default_x() (line 191), call assert_all_finite(x, "Observed data x_o contains NaNs or Infs.") after process_x, conditional on self._check_finite_x.

  3. In sbi/inference/posteriors/direct_posterior.py and sbi/inference/posteriors/vector_field_posterior.py — accept check_finite_x in __init__() and forward to super().__init__(). Both DirectPosterior (NPE) and VectorFieldPosterior (FMPE, NPSE) are realistic combinations with PermutationInvariantEmbedding, so both need the kwarg.

  4. In sbi/inference/posteriors/posterior_parameters.py — add check_finite_x: bool = True to the base PosteriorParameters dataclass. It then auto-flows to all *PosteriorParameters subclasses (DirectPosteriorParameters, VectorFieldPosteriorParameters, etc.).

Note: MCMC / rejection / VI / importance subclass __init__s don't need to accept the kwarg yet, those backends are rarely combined with PermutationInvariantEmbedding (they're post-hoc samplers on top of NPE/NLE/NRE).

  5. In tests/user_input_checks_test.py — the two @pytest.mark.xfail cases at lines 215-228 (one NaN, one Inf) document the assertion contract on process_x. Since we're removing that contract from process_x, please drop those xfail markers.

  6. In tests/embedding_net_test.py — drop @pytest.mark.xfail at line 448 and build the posterior with posterior_parameters=DirectPosteriorParameters(check_finite_x=False).
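Taken together, steps 2-4 could look roughly like this. This is a condensed sketch, not the actual sbi classes: process_x is replaced by a plain tensor cast, and assert_all_finite is reimplemented inline so the snippet runs standalone:

```python
from dataclasses import dataclass

import torch


def assert_all_finite(x: torch.Tensor, msg: str) -> None:
    if not torch.isfinite(x).all():
        raise AssertionError(msg)


class NeuralPosterior:
    """Sketch of the base class: the finiteness check lives here,
    after dtype/shape processing, gated by check_finite_x (step 2)."""

    def __init__(self, check_finite_x: bool = True):
        self._check_finite_x = check_finite_x
        self.default_x = None

    def set_default_x(self, x) -> "NeuralPosterior":
        x = torch.as_tensor(x, dtype=torch.float32)  # stand-in for process_x
        if self._check_finite_x:
            assert_all_finite(x, "Observed data x_o contains NaNs or Infs.")
        self.default_x = x
        return self


class DirectPosterior(NeuralPosterior):
    """Step 3: subclasses just forward the kwarg to the base class."""

    def __init__(self, check_finite_x: bool = True):
        super().__init__(check_finite_x=check_finite_x)


@dataclass
class PosteriorParameters:
    """Step 4: defining the flag on the base dataclass lets it auto-flow
    to all *PosteriorParameters subclasses."""

    check_finite_x: bool = True
```

With this layout, process_x stays a pure dtype/shape normalizer and the finiteness contract is enforced exactly once, at the posterior boundary.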

Your other questions:

  • Show a warning when check_finite_x=False? No, keep it silent. It's an explicit opt-in for a documented use case (NaN padding from PermutationInvariantEmbedding), and a warning on every sample call would be noisy.
  • Please add a short test in tests/user_input_checks_test.py confirming NeuralPosterior(...).set_default_x(x_with_nan) raises by default and works with check_finite_x=False.

Sorry for the back-and-forth on the design! Your initial PR pointed at the right function and that's what made the cleaner approach visible. Thanks for sticking with this 🙏
Jan
