1717 gaussian_mixture ,
1818 uniform_prior_gaussian_mixture ,
1919)
20+ from sbi .utils .sbiutils import seed_all_backends
2021
2122# =============================================================================
2223# Fixtures
@@ -55,6 +56,8 @@ def sim_setup() -> SimulatorSetup:
5556@pytest .fixture (scope = "session" )
5657def badly_trained_npe (sim_setup ):
5758 """A poorly trained NPE for testing LC2ST detection of bad posteriors."""
59+ # seed explicitly to keep session scope deterministic.
60+ seed_all_backends (1 )
5861 theta_train = sim_setup .prior .sample ((50 ,))
5962 x_train = sim_setup .simulator (theta_train )
6063
@@ -66,6 +69,8 @@ def badly_trained_npe(sim_setup):
6669@pytest .fixture (scope = "session" )
6770def well_trained_npe (sim_setup ):
6871 """A well-trained NPE for testing LC2ST false positive rate."""
72+ # seed explicitly to keep session scope deterministic.
73+ seed_all_backends (1 )
6974 theta_train = sim_setup .prior .sample ((5_000 ,))
7075 x_train = sim_setup .simulator (theta_train )
7176
@@ -77,6 +82,8 @@ def well_trained_npe(sim_setup):
7782@pytest .fixture (scope = "session" )
7883def cal_data (sim_setup , badly_trained_npe ) -> CalibrationData :
7984 """Calibration data for LC2ST tests."""
85+ # seed explicitly to keep session scope deterministic.
86+ seed_all_backends (1 )
8087 num_cal = 100
8188 thetas = sim_setup .prior .sample ((num_cal ,))
8289 xs = sim_setup .simulator (thetas )
@@ -129,7 +136,7 @@ def test_lc2st_methods(method, cal_data, badly_trained_npe, theta_o, x_o):
129136 else :
130137 npe = badly_trained_npe
131138 kwargs_init = {
132- "flow_inverse_transform" : lambda t , x : npe .net . _transform ( t , context = x )[ 0 ] ,
139+ "flow_inverse_transform" : npe .inverse_transform ,
133140 "flow_base_dist" : torch .distributions .MultivariateNormal (
134141 torch .zeros (2 ), torch .eye (2 )
135142 ),
@@ -180,7 +187,7 @@ def test_lc2st_parameter_combinations(
180187 else :
181188 npe = badly_trained_npe
182189 kwargs_init = {
183- "flow_inverse_transform" : lambda t , x : npe .net . _transform ( t , context = x )[ 0 ] ,
190+ "flow_inverse_transform" : npe .inverse_transform ,
184191 "flow_base_dist" : torch .distributions .MultivariateNormal (
185192 torch .zeros (2 ), torch .eye (2 )
186193 ),
@@ -294,7 +301,7 @@ def test_lc2st_nf_with_pretrained_null_is_ready_after_observed_training(
294301 """
295302 npe = badly_trained_npe
296303 kwargs_init = {
297- "flow_inverse_transform" : lambda t , x : npe .net . _transform ( t , context = x )[ 0 ] ,
304+ "flow_inverse_transform" : npe .inverse_transform ,
298305 "flow_base_dist" : torch .distributions .MultivariateNormal (
299306 torch .zeros (2 ), torch .eye (2 )
300307 ),
@@ -612,7 +619,7 @@ def test_lc2st_true_positive_rate(method, sim_setup, badly_trained_npe):
612619 else :
613620 npe = badly_trained_npe
614621 kwargs_init = {
615- "flow_inverse_transform" : lambda t , x : npe .net . _transform ( t , context = x )[ 0 ] ,
622+ "flow_inverse_transform" : npe .inverse_transform ,
616623 "flow_base_dist" : torch .distributions .MultivariateNormal (
617624 torch .zeros (2 ), torch .eye (2 )
618625 ),
@@ -665,7 +672,7 @@ def test_lc2st_false_positive_rate(method, sim_setup, well_trained_npe, set_seed
665672 else :
666673 npe = well_trained_npe
667674 kwargs_init = {
668- "flow_inverse_transform" : lambda t , x : npe .net . _transform ( t , context = x )[ 0 ] ,
675+ "flow_inverse_transform" : npe .inverse_transform ,
669676 "flow_base_dist" : torch .distributions .MultivariateNormal (
670677 torch .zeros (2 ), torch .eye (2 )
671678 ),
0 commit comments