Skip to content

Commit 38f43ca

Browse files
authored
fix: nflows inverse transform and retarget lc2st (#1865)
* tests: add seed to lc2st session fixtures * ruff * fix: typo in inverse transform, fix usage
1 parent f7ef8b9 commit 38f43ca

3 files changed

Lines changed: 15 additions & 8 deletions

File tree

docs/advanced_tutorials/13_diagnostics_lc2st.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@
604604
"source": [
605605
"from sbi.diagnostics.lc2st import LC2ST_NF\n",
606606
"\n",
607-
"flow_inverse_transform = lambda theta, x: npe.net._transform(theta, context=x)[0]\n",
607+
"flow_inverse_transform = npe.inverse_transform\n",
608608
"flow_base_dist = torch.distributions.MultivariateNormal(\n",
609609
" torch.zeros(2), torch.eye(2)\n",
610610
") # same as npe.net._distribution\n",

sbi/neural_nets/estimators/nflows_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def inverse_transform(self, input: Tensor, condition: Tensor) -> Tensor:
7070
input = input.reshape(-1, input.shape[-1])
7171
condition = condition.reshape(-1, *self.condition_shape)
7272

73-
noise, _ = self.net._transorm(input, context=condition)
74-
noise = noise.reshape(batch_shape)
73+
noise, _ = self.net._transform(input, context=condition)
74+
noise = noise.reshape(batch_shape + (noise.shape[-1],))
7575
return noise
7676

7777
def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:

tests/lc2st_test.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
gaussian_mixture,
1818
uniform_prior_gaussian_mixture,
1919
)
20+
from sbi.utils.sbiutils import seed_all_backends
2021

2122
# =============================================================================
2223
# Fixtures
@@ -55,6 +56,8 @@ def sim_setup() -> SimulatorSetup:
5556
@pytest.fixture(scope="session")
5657
def badly_trained_npe(sim_setup):
5758
"""A poorly trained NPE for testing LC2ST detection of bad posteriors."""
59+
# seed explicitly to keep session scope deterministic.
60+
seed_all_backends(1)
5861
theta_train = sim_setup.prior.sample((50,))
5962
x_train = sim_setup.simulator(theta_train)
6063

@@ -66,6 +69,8 @@ def badly_trained_npe(sim_setup):
6669
@pytest.fixture(scope="session")
6770
def well_trained_npe(sim_setup):
6871
"""A well-trained NPE for testing LC2ST false positive rate."""
72+
# seed explicitly to keep session scope deterministic.
73+
seed_all_backends(1)
6974
theta_train = sim_setup.prior.sample((5_000,))
7075
x_train = sim_setup.simulator(theta_train)
7176

@@ -77,6 +82,8 @@ def well_trained_npe(sim_setup):
7782
@pytest.fixture(scope="session")
7883
def cal_data(sim_setup, badly_trained_npe) -> CalibrationData:
7984
"""Calibration data for LC2ST tests."""
85+
# seed explicitly to keep session scope deterministic.
86+
seed_all_backends(1)
8087
num_cal = 100
8188
thetas = sim_setup.prior.sample((num_cal,))
8289
xs = sim_setup.simulator(thetas)
@@ -129,7 +136,7 @@ def test_lc2st_methods(method, cal_data, badly_trained_npe, theta_o, x_o):
129136
else:
130137
npe = badly_trained_npe
131138
kwargs_init = {
132-
"flow_inverse_transform": lambda t, x: npe.net._transform(t, context=x)[0],
139+
"flow_inverse_transform": npe.inverse_transform,
133140
"flow_base_dist": torch.distributions.MultivariateNormal(
134141
torch.zeros(2), torch.eye(2)
135142
),
@@ -180,7 +187,7 @@ def test_lc2st_parameter_combinations(
180187
else:
181188
npe = badly_trained_npe
182189
kwargs_init = {
183-
"flow_inverse_transform": lambda t, x: npe.net._transform(t, context=x)[0],
190+
"flow_inverse_transform": npe.inverse_transform,
184191
"flow_base_dist": torch.distributions.MultivariateNormal(
185192
torch.zeros(2), torch.eye(2)
186193
),
@@ -294,7 +301,7 @@ def test_lc2st_nf_with_pretrained_null_is_ready_after_observed_training(
294301
"""
295302
npe = badly_trained_npe
296303
kwargs_init = {
297-
"flow_inverse_transform": lambda t, x: npe.net._transform(t, context=x)[0],
304+
"flow_inverse_transform": npe.inverse_transform,
298305
"flow_base_dist": torch.distributions.MultivariateNormal(
299306
torch.zeros(2), torch.eye(2)
300307
),
@@ -612,7 +619,7 @@ def test_lc2st_true_positive_rate(method, sim_setup, badly_trained_npe):
612619
else:
613620
npe = badly_trained_npe
614621
kwargs_init = {
615-
"flow_inverse_transform": lambda t, x: npe.net._transform(t, context=x)[0],
622+
"flow_inverse_transform": npe.inverse_transform,
616623
"flow_base_dist": torch.distributions.MultivariateNormal(
617624
torch.zeros(2), torch.eye(2)
618625
),
@@ -665,7 +672,7 @@ def test_lc2st_false_positive_rate(method, sim_setup, well_trained_npe, set_seed
665672
else:
666673
npe = well_trained_npe
667674
kwargs_init = {
668-
"flow_inverse_transform": lambda t, x: npe.net._transform(t, context=x)[0],
675+
"flow_inverse_transform": npe.inverse_transform,
669676
"flow_base_dist": torch.distributions.MultivariateNormal(
670677
torch.zeros(2), torch.eye(2)
671678
),

0 commit comments

Comments
 (0)