
Commit e7fbec0

janfb and claude authored
refactor: LC2ST module refactoring with states (#1727)
* fix flaky test
* fix: correct typos and improve error handling.
* feat: enhance LC2ST with new state management and structured score return type
* refactor: streamline normalization process and enhance classifier defaults
* refactor: enhance LC2ST test structure with dataclasses for setup and calibration
* fix: correct typos and enhance documentation in LC2ST class
* refactor tests
* refactor: LC2ST class and related tests for improved clarity and functionality
  - Updated parameter names in LC2ST initialization for consistency (thetas -> prior_samples).
  - Modified get_scores and get_statistics_under_null_hypothesis methods to return LC2STScores objects, encapsulating probabilities and scores.
  - Adjusted usage of get_scores in the tutorial and tests to reflect the new return type.
  - input validation in LC2ST to prevent indexing errors.
  - Updated tests to assert the presence of scores in the returned null statistics.
* fix: address review remarks on LC2STstate and nb
  Fixes three merge blockers from review:
  - LC2ST_NF(trained_clfs_null=...) was deadlocked: __init__ left _state at INITIALIZED when pretrained null classifiers were passed, so train_on_observed_data advanced to OBSERVED_TRAINED (not READY) and p_value raised. Now advances to NULL_TRAINED when pretrained classifiers are supplied.
  - train_on_observed_data downgraded READY -> OBSERVED_TRAINED on retrain, breaking the documented loop-over-seeds workflow from the tutorial. READY is now preserved.
  - Advanced-tutorial notebook called np.quantile / axes.hist on the new LC2STScores return value. Both failing cells now extract .scores.
* style: ruff format new LC2ST regression tests
* fix: address LC2ST review remarks on API and types
  - Narrow get_scores / get_statistics_under_null_hypothesis return type to Union[LC2STScores, Tuple[np.ndarray, np.ndarray]]; the bare np.ndarray branch was never returned.
  - Clarify return_probs deprecation warning to mention the (probs, scores) tuple order and the eventual removal.
  - Guard z-score normalization against constant feature dimensions: std == 0 is replaced by 1.0 so constant columns become pass-through (mean-centered) instead of producing NaN/Inf.
  - Rewrite the error message raised when re-entering train_under_null_hypothesis so it applies cleanly to both LC2ST (permutation, data-dependent) and LC2ST_NF (analytical, reusable).
  - Fix docstring: verbosity in get_statistics_under_null_hypothesis defaults to 0, not 1.
  - Add regression tests for single-normalization in null training, constant-dim normalization robustness, and document the lc2st_instance fixture scope.
* fix: use FutureWarning for user-facing LC2ST deprecations

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
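The state fixes described in the message can be illustrated with a toy model. This is a minimal sketch, not the sbi implementation: `SketchState`, `StateSketch` and `require_ready` are hypothetical names, only the four state names come from the commit message, and the rule that re-entering null training raises is an assumption inferred from the mention of that error message.

```python
from enum import Enum, auto


class SketchState(Enum):
    """Hypothetical stand-in for the four states named in the commit message."""

    INITIALIZED = auto()
    OBSERVED_TRAINED = auto()
    NULL_TRAINED = auto()
    READY = auto()


class StateSketch:
    """Tracks training state only; no classifiers are actually trained."""

    def __init__(self, trained_clfs_null=None):
        # Pretrained null classifiers mean null training is already done,
        # so start at NULL_TRAINED instead of INITIALIZED.
        if trained_clfs_null is not None:
            self.state = SketchState.NULL_TRAINED
        else:
            self.state = SketchState.INITIALIZED

    def train_on_observed_data(self):
        # Retraining must never downgrade: NULL_TRAINED and READY both end in READY.
        if self.state in (SketchState.NULL_TRAINED, SketchState.READY):
            self.state = SketchState.READY
        else:
            self.state = SketchState.OBSERVED_TRAINED

    def train_under_null_hypothesis(self):
        # Assumption: training the null classifiers a second time is an error.
        if self.state in (SketchState.NULL_TRAINED, SketchState.READY):
            raise RuntimeError("null classifiers are already trained")
        if self.state is SketchState.OBSERVED_TRAINED:
            self.state = SketchState.READY
        else:
            self.state = SketchState.NULL_TRAINED

    def require_ready(self):
        # Stand-in for the check a p-value computation would perform.
        if self.state is not SketchState.READY:
            raise RuntimeError(f"not ready, current state is {self.state.name}")


# Loop-over-seeds workflow: repeated retraining on observed data stays READY.
nf = StateSketch(trained_clfs_null=[object()])
for _seed in range(3):
    nf.train_on_observed_data()
    nf.require_ready()
```

With the pre-fix behaviour described in the message, the first `require_ready()` call in the loop would have raised, because the object would have been in `OBSERVED_TRAINED` rather than `READY`.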
1 parent 8facb74 commit e7fbec0

5 files changed

Lines changed: 1242 additions & 436 deletions
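The z-score guard mentioned in the commit message (std == 0 replaced by 1.0) can be sketched in a few lines of NumPy. The function name `zscore_safe` is hypothetical and this is an illustration of the idea, not the code changed in this commit.

```python
import numpy as np


def zscore_safe(x):
    """Z-score each column; constant columns are mean-centered instead of divided by 0."""
    x = np.asarray(x, dtype=float)
    mean = x.mean(axis=0)
    std = x.std(axis=0)
    # A constant column has std == 0; dividing by it would give NaN (0/0).
    # Replacing the zero with 1.0 leaves that column as (x - mean), i.e. all zeros.
    std = np.where(std == 0, 1.0, std)
    return (x - mean) / std


data = np.array([[1.0, 5.0], [3.0, 5.0]])  # second column is constant
normalized = zscore_safe(data)
print(normalized)
```

Without the guard, the second column would be `0 / 0`, which NumPy evaluates to NaN with a runtime warning, and the NaN would then propagate into classifier training.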


docs/advanced_tutorials/13_diagnostics_lc2st.ipynb

Lines changed: 25 additions & 29 deletions
@@ -252,7 +252,7 @@
 "source": [
 "# set up the LC2ST: train the classifiers\n",
 "lc2st = LC2ST(\n",
-" thetas=theta_cal,\n",
+" prior_samples=theta_cal,\n",
 " xs=x_cal,\n",
 " posterior_samples=post_samples_cal,\n",
 " classifier=\"mlp\",\n",
@@ -301,20 +301,21 @@
 "source": [
 "fig, axes = plt.subplots(1,len(thetas_star), figsize=(12,3))\n",
 "for i in range(len(thetas_star)):\n",
-" probs, scores = lc2st.get_scores(\n",
+" lc2st_scores = lc2st.get_scores(\n",
 " theta_o=post_samples_star[i],\n",
 " x_o=xs_star[i],\n",
-" return_probs=True,\n",
 " trained_clfs=lc2st.trained_clfs\n",
 " )\n",
+" probs = lc2st_scores.probabilities\n",
+" scores = lc2st_scores.scores\n",
 " T_data = lc2st.get_statistic_on_observed_data(\n",
 " theta_o=post_samples_star[i],\n",
 " x_o=xs_star[i]\n",
 " )\n",
 " T_null = lc2st.get_statistics_under_null_hypothesis(\n",
 " theta_o=post_samples_star[i],\n",
 " x_o=xs_star[i]\n",
-" )\n",
+" ).scores\n",
 " p_value = lc2st.p_value(post_samples_star[i], xs_star[i])\n",
 " reject = lc2st.reject_test(post_samples_star[i], xs_star[i], alpha=conf_alpha)\n",
 "\n",
@@ -383,17 +384,15 @@
 "\n",
 "fig, axes = plt.subplots(1,len(thetas_star), figsize=(12,3))\n",
 "for i in range(len(thetas_star)):\n",
-" probs_data, _ = lc2st.get_scores(\n",
+" probs_data = lc2st.get_scores(\n",
 " theta_o=post_samples_star[i],\n",
 " x_o=xs_star[i],\n",
-" return_probs=True,\n",
 " trained_clfs=lc2st.trained_clfs\n",
-" )\n",
-" probs_null, _ = lc2st.get_statistics_under_null_hypothesis(\n",
+" ).probabilities\n",
+" probs_null = lc2st.get_statistics_under_null_hypothesis(\n",
 " theta_o=post_samples_star[i],\n",
-" x_o=xs_star[i],\n",
-" return_probs=True\n",
-" )\n",
+" x_o=xs_star[i]\n",
+" ).probabilities\n",
 "\n",
 " pp_plot_lc2st(\n",
 " probs=[probs_data],\n",
@@ -451,12 +450,11 @@
 "\n",
 "fig, axes = plt.subplots(len(thetas_star), 3, figsize=(9,6), constrained_layout=True)\n",
 "for i in range(len(thetas_star)):\n",
-" probs_data, _ = lc2st.get_scores(\n",
+" probs_data = lc2st.get_scores(\n",
 " theta_o=post_samples_star[i][:1000],\n",
 " x_o=xs_star[i],\n",
-" return_probs=True,\n",
 " trained_clfs=lc2st.trained_clfs\n",
-" )\n",
+" ).probabilities\n",
 " dict_probs_marginals = get_probs_per_marginal(\n",
 " probs_data[0],\n",
 " post_samples_star[i][:1000].numpy()\n",
@@ -543,7 +541,7 @@
 "For different classifier architectures, you should choose the one with the smallest variance. \n",
 "\n",
 "### Number of calibration samples\n",
-"A similar check can also be performed via cross-validation: set the `num_folds` parameter of your `LC2ST` object, train on observed data and call `lc2st.get_scores(theta_o, x_o, lc2st.trained_clfs)`. This outputs the test statistics obtained for each cv-fold. You should choose the smallest calibration set size that gives you a small enough variance over the test statistics. \n",
+"A similar check can also be performed via cross-validation: set the `num_folds` parameter of your `LC2ST` object, train on observed data and call `lc2st.get_scores(theta_o, x_o, lc2st.trained_clfs)`. This returns an `LC2STScores` object with the test statistics (`.scores`) for each cv-fold. You should choose the smallest calibration set size that gives you a small enough variance over the test statistics. \n",
 "\n",
 "> Note: Ideally, these checks should be performed in a **separable data setting**, i.e. for a dataset `theta_o, x_o` coming from a sub-optimal estimator: the classifier is supposed to be able to discriminate between the two classes; the test is supposed to be rejected; the variance is supposed to be small. In other words, we are ensuring a **high statistical power** (our true positive rate) of our test. If you want to be really rigurous, you should also check the type I error (or false positive rate), that should be controlled by the significance level of your test (cf. Figure 2 in [[Linhart et al., 2023]](https://arxiv.org/abs/2306.03580)).\n",
 "\n",
@@ -612,7 +610,7 @@
 ") # same as npe.net._distribution\n",
 "\n",
 "lc2st_nf = LC2ST_NF(\n",
-" thetas=theta_cal,\n",
+" prior_samples=theta_cal,\n",
 " xs=x_cal,\n",
 " posterior_samples=post_samples_cal,\n",
 " flow_inverse_transform=flow_inverse_transform,\n",
@@ -660,13 +658,14 @@
 "source": [
 "fig, axes = plt.subplots(1,len(thetas_star), figsize=(12,3))\n",
 "for i in range(len(thetas_star)):\n",
-" probs, scores = lc2st_nf.get_scores(\n",
+" lc2st_scores = lc2st_nf.get_scores(\n",
 " x_o=xs_star[i],\n",
-" return_probs=True,\n",
 " trained_clfs=lc2st_nf.trained_clfs\n",
 " )\n",
+" probs = lc2st_scores.probabilities\n",
+" scores = lc2st_scores.scores\n",
 " T_data = lc2st_nf.get_statistic_on_observed_data(x_o=xs_star[i])\n",
-" T_null = lc2st_nf.get_statistics_under_null_hypothesis(x_o=xs_star[i])\n",
+" T_null = lc2st_nf.get_statistics_under_null_hypothesis(x_o=xs_star[i]).scores\n",
 " p_value = lc2st_nf.p_value(xs_star[i])\n",
 " reject = lc2st_nf.reject_test(xs_star[i], alpha=conf_alpha)\n",
 "\n",
@@ -731,15 +730,13 @@
 "\n",
 "fig, axes = plt.subplots(1,len(thetas_star), figsize=(12,3))\n",
 "for i in range(len(thetas_star)):\n",
-" probs_data, _ = lc2st_nf.get_scores(\n",
+" probs_data = lc2st_nf.get_scores(\n",
 " x_o=xs_star[i],\n",
-" return_probs=True,\n",
 " trained_clfs=lc2st_nf.trained_clfs\n",
-" )\n",
-" probs_null, _ = lc2st_nf.get_statistics_under_null_hypothesis(\n",
-" x_o=xs_star[i],\n",
-" return_probs=True\n",
-" )\n",
+" ).probabilities\n",
+" probs_null = lc2st_nf.get_statistics_under_null_hypothesis(\n",
+" x_o=xs_star[i]\n",
+" ).probabilities\n",
 "\n",
 " pp_plot_lc2st(\n",
 " probs=[probs_data],\n",
@@ -791,11 +788,10 @@
 " inv_ref_samples = lc2st_nf.flow_inverse_transform(\n",
 " ref_samples_star[i], xs_star[i]\n",
 " ).detach()\n",
-" probs_data, _ = lc2st_nf.get_scores(\n",
+" probs_data = lc2st_nf.get_scores(\n",
 " x_o=xs_star[i],\n",
-" return_probs=True,\n",
 " trained_clfs=lc2st_nf.trained_clfs\n",
-" )\n",
+" ).probabilities\n",
 " marginal_probs = get_probs_per_marginal(\n",
 " probs_data[0],\n",
 " lc2st_nf.theta_o.numpy()\n",

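The notebook changes above all follow one pattern: tuple unpacking with `return_probs=True` is replaced by attribute access on a single returned object. The sketch below mimics that migration with stand-ins. `ScoresSketch` and `get_scores_sketch` are hypothetical names, not sbi's `LC2STScores` or `get_scores`, and the statistic used (mean squared distance of the class probabilities from 0.5) is an assumption chosen for illustration.

```python
import warnings
from dataclasses import dataclass

import numpy as np


@dataclass
class ScoresSketch:
    """Hypothetical container with the two fields used in the notebook diff."""

    probabilities: np.ndarray  # shape (n_folds, n_samples)
    scores: np.ndarray  # shape (n_folds,)


def get_scores_sketch(probs, return_probs=None):
    """Return a ScoresSketch, or a legacy (probs, scores) tuple with a warning."""
    probs = np.asarray(probs, dtype=float)
    # Assumed statistic: one value per fold, averaged over samples.
    scores = ((probs - 0.5) ** 2).mean(axis=1)
    result = ScoresSketch(probabilities=probs, scores=scores)
    if return_probs:
        # Legacy path: keep the old tuple order, but warn the caller visibly.
        warnings.warn(
            "return_probs is deprecated; the (probs, scores) tuple will be "
            "removed. Use .probabilities and .scores instead.",
            FutureWarning,
            stacklevel=2,
        )
        return result.probabilities, result.scores
    return result


fold_probs = [[0.5, 0.5], [1.0, 0.0]]  # two folds, two samples each
res = get_scores_sketch(fold_probs)
print(res.scores)  # per-fold statistics
print(res.scores.var())  # spread over folds, as in the cross-validation check
```

`FutureWarning` is shown by default to end users, whereas `DeprecationWarning` is hidden by default outside of `__main__` and test runners, which is the usual reason to prefer it for user-facing deprecations.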
docs/how_to_guide/13_diagnostics_lc2st.ipynb

Lines changed: 8 additions & 6 deletions
@@ -39,7 +39,7 @@
 "\n",
 "# Train the L-C2ST classifier.\n",
 "lc2st = LC2ST(\n",
-" thetas=prior_samples,\n",
+" prior_samples=prior_samples,\n",
 " xs=prior_predictives,\n",
 " posterior_samples=post_samples_cal,\n",
 " classifier=\"mlp\",\n",
@@ -50,17 +50,19 @@
 "\n",
 "# Note: x_o must have a batch-dimension. I.e. `x_o.shape == (1, observation_shape)`.\n",
 "post_samples_star = posterior.sample((10_000,), x=x_o)\n",
-"probs_data, scores_data = lc2st.get_scores(\n",
+"scores_data = lc2st.get_scores(\n",
 " theta_o=post_samples_star,\n",
 " x_o=x_o,\n",
-" return_probs=True,\n",
 " trained_clfs=lc2st.trained_clfs\n",
 ")\n",
-"probs_null, scores_null = lc2st.get_statistics_under_null_hypothesis(\n",
+"probs_data = scores_data.probabilities\n",
+"scores_data = scores_data.scores\n",
+"scores_null = lc2st.get_statistics_under_null_hypothesis(\n",
 " theta_o=post_samples_star,\n",
 " x_o=x_o,\n",
-" return_probs=True,\n",
 ")\n",
+"probs_null = scores_null.probabilities\n",
+"scores_null = scores_null.scores\n",
 "\n",
 "conf_alpha = 0.05\n",
 "p_value = lc2st.p_value(post_samples_star, torch.as_tensor(x_o).unsqueeze(0))\n",
@@ -94,7 +96,7 @@
 "source": [
 "If the red line is outside of the two dotted black lines (as above), then L-C2ST rejects the null-hypothesis that the approximate posterior matches the true posterior (i.e., your posterior is likely wrong).\n",
 "\n",
-"If the posterior is wrong, then you can get insights into whether the posterior is under- or over-confident as follows:"
+"If the posterior is wrong, then you can get insights into whether the posterior is under- or over-confident as follows. The call above returns an `LC2STScores` object; use `.probabilities` for the classifier probabilities and `.scores` for the test statistics."
 ]
 },
 {
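The how-to cell compares a statistic on observed data against statistics computed under the null hypothesis. As a rough illustration of how such a comparison yields a p-value and a reject decision, here is a generic empirical p-value. The helper names are hypothetical, and the `>=` convention is an assumption; the comparison used inside `lc2st.p_value` may differ (for example in strictness).

```python
import numpy as np


def empirical_p_value(t_data, t_null):
    """Fraction of null statistics at least as large as the observed one."""
    t_null = np.asarray(t_null, dtype=float)
    return float(np.mean(t_null >= t_data))


def reject_null(t_data, t_null, alpha=0.05):
    """Reject when the observed statistic is extreme relative to the null draws."""
    return empirical_p_value(t_data, t_null) < alpha


# Made-up numbers purely for illustration.
null_stats = [0.1, 0.2, 0.3, 0.4]
print(empirical_p_value(0.35, null_stats))  # one of four null draws is >= 0.35
print(reject_null(1.0, null_stats))  # no null draw reaches 1.0
```

A small p-value means few null statistics reach the observed one, which matches the plot reading in the notebook text: the observed line falls outside the band spanned by the null distribution.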

sbi/diagnostics/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -1,7 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-from sbi.diagnostics.lc2st import LC2ST
+from sbi.diagnostics.lc2st import LC2ST, LC2ST_NF, LC2STScores, LC2STState
 from sbi.diagnostics.misspecification import (
     calc_misspecification_logprob,
     calc_misspecification_mmd,
@@ -16,6 +16,9 @@
     "check_tarp",
     "run_tarp",
     "LC2ST",
+    "LC2ST_NF",
+    "LC2STScores",
+    "LC2STState",
     "calc_misspecification_logprob",
     "calc_misspecification_mmd",
 ]
