Skip to content

Commit 9076791

Browse files
authored
Added roundtrip test for precision estimation (#161)
1 parent 5df69bb commit 9076791

1 file changed

Lines changed: 39 additions & 0 deletions

File tree

tests/test_precision_estimation.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,45 @@ def test_snapshot_fit_precision_cholesky_approximate():
6666
np.testing.assert_allclose(entries_at_zero[::9], desired, atol=1e-8)
6767

6868

69+
@pytest.mark.parametrize("seed", range(99))
70+
def test_precision_cholesky_roundtrip(seed):
71+
"""Starting from a known, sparse precision matrix, we generate data,
72+
then try to infer the known values from the samples."""
73+
74+
# Create sparse, pos.def precision matrix
75+
rng = np.random.default_rng(seed)
76+
n = 25 # Size
77+
density = 0.1
78+
79+
# Create sparse pos def precision matrix
80+
F = rng.normal(size=(n, n))
81+
F[rng.uniform(size=(n, n)) > density] = 0
82+
Prec = F.T @ F + np.eye(n)
83+
assert np.all(np.linalg.svd(Prec).S > 0), "Pos def"
84+
85+
G_matrix = (~np.isclose(Prec, 0.0)).astype(int)
86+
Graph_u = nx.from_scipy_sparse_array(sp.sparse.csc_array(G_matrix))
87+
88+
Cov = np.linalg.inv(Prec)
89+
U = rng.multivariate_normal(mean=np.zeros(n), cov=Cov, size=99)
90+
91+
# Estimate precision using known structure
92+
Prec_est, *_ = precest.fit_precision_cholesky(
93+
U=U, Graph_u=Graph_u, ordering_method="amd"
94+
)
95+
Prec_est = Prec_est.todense()
96+
97+
RMSE = np.sqrt(np.mean((Prec - Prec_est) ** 2))
98+
99+
# Estimate the naive way - invert the empirical covariance
100+
Prec_naive = np.linalg.inv(np.cov(U, rowvar=False))
101+
RMSE_naive = np.sqrt(np.mean((Prec - Prec_naive) ** 2))
102+
103+
# Here 0.77 was chosen to make all tests pass, to easier catch
104+
# regressions. Nothing special about the number. Main idea: beat naive!
105+
assert RMSE_naive * 0.77 > RMSE
106+
107+
69108
def test_objective_twice():
70109
# A regression test: ensure that two calls return the same result.
71110
rng = np.random.default_rng(42)

0 commit comments

Comments
 (0)