Skip to content

Commit 8fd4938

Browse files
committed
Slight tweaks to banana distribution.
1 parent 0f6cd3b commit 8fd4938

2 files changed

Lines changed: 14 additions & 16 deletions

File tree

pints/tests/test_toy_twisted_gaussian_logpdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_sampling_and_kl_divergence(self):
3737
Test TwistedGaussianLogPDF.kl_divergence() and .sample().
3838
"""
3939
# Ensure consistent output
40-
#np.random.seed(1)
40+
np.random.seed(1)
4141

4242
# Create banana LogPDFs
4343
d = 6

pints/toy/_twisted_gaussian_banana.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,8 @@ def __init__(self, dimension=10, b=0.1, V=100):
5454
self._V = float(V)
5555

5656
# Create phi
57-
self._mean = np.zeros(self._dimension)
58-
self._cov = np.eye(self._dimension)
59-
self._phi = scipy.stats.multivariate_normal(self._mean, self._cov)
57+
self._phi = scipy.stats.multivariate_normal(
58+
np.zeros(self._dimension), np.eye(self._dimension))
6059

6160
def __call__(self, x):
6261
y = np.array(x, copy=True)
@@ -98,17 +97,18 @@ def kl_divergence(self, samples):
9897
# - k
9998
# )
10099
#
101-
# using s1 = real sigma, as this needs to be inverted and the real one
102-
# is more likely to be invertible than the sample one
100+
# For this distribution, s1 is the identify matrix, and m1 is zero,
101+
# so it simplifies to
102+
#
103+
# dkl = 0.5 * (trace(s0) + m0.dot(m0) - log(det(s0)) - k))
104+
#
103105
m0 = np.mean(y, axis=0)
104-
m1 = self._mean
105106
s0 = np.cov(y.T)
106-
s1 = self._cov
107-
cov_inv = np.linalg.inv(s1)
108-
dkl1 = np.trace(cov_inv.dot(s0))
109-
dkl2 = np.dot((m1 - m0).T, cov_inv).dot(m1 - m0)
110-
dkl3 = np.log(np.linalg.det(s1) / np.linalg.det(s0))
111-
return 0.5 * (dkl1 + dkl2 + dkl3 - self._dimension)
107+
s1 = np.eye(self._dimension)
108+
cov_inv = s1
109+
return 0.5 * (
110+
np.trace(s0) + m0.dot(m0)
111+
- np.log(np.linalg.det(s0)) - self._dimension)
112112

113113
def n_parameters(self):
114114
""" See :meth:`pints.LogPDF.n_parameters()`. """
@@ -121,10 +121,8 @@ def sample(self, n):
121121
if n < 0:
122122
raise ValueError('Number of samples cannot be negative.')
123123

124-
x = np.random.randn(n, 2)
124+
x = self._phi.rvs(n)
125125
x[:, 0] *= np.sqrt(self._V)
126126
x[:, 1] -= self._b * (x[:, 0] ** 2 - self._V)
127-
if self._dimension > 2:
128-
x = np.hstack((x, np.random.randn(n, self._dimension - 2)))
129127
return x
130128

0 commit comments

Comments
 (0)