Skip to content

Commit 273e1fc

Browse files
authored
Merge pull request #447 from pints-team/i444-eggbox
Simple egg box toy logpdf for functional testing
2 parents 9de8c13 + 3671408 commit 273e1fc

7 files changed

Lines changed: 498 additions & 0 deletions

File tree

docs/source/toy/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ examples.
2323
repressilator_model
2424
rosenbrock
2525
sir_model
26+
simple_egg_box_logpdf
2627
twisted_gaussian_logpdf
2728

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
***************************
2+
Simple Egg Box Distribution
3+
***************************
4+
5+
.. module:: pints.toy
6+
7+
.. autoclass:: SimpleEggBoxLogPDF
8+

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,6 @@ relevant code.
6969

7070
- [Multimodal normal distribution](./toy-distribution-multimodal-normal.ipynb)
7171
- [Rosenbrock function](./toy-distribution-rosenbrock.ipynb)
72+
- [Simple Egg Box](./toy-distribution-simple-egg-box.ipynb)
7273
- [Twisted Gaussian Banana](./toy-distribution-twisted-gaussian.ipynb)
7374

examples/toy-distribution-simple-egg-box.ipynb

Lines changed: 268 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#!/usr/bin/env python3
2+
#
3+
# Tests the simple egg box toy LogPDF.
4+
#
5+
# This file is part of PINTS.
6+
# Copyright (c) 2017-2018, University of Oxford.
7+
# For licensing information, see the LICENSE file distributed with the PINTS
8+
# software package.
9+
#
10+
import pints
11+
import pints.toy
12+
import unittest
13+
import numpy as np
14+
15+
16+
class TestSimpleEggBoxLogPDF(unittest.TestCase):
17+
"""
18+
Tests the simple egg box logpdf toy distribution.
19+
"""
20+
def test_simple_egg_box_logpdf(self):
21+
# Test basics
22+
f = pints.toy.SimpleEggBoxLogPDF()
23+
self.assertEqual(f.n_parameters(), 2)
24+
self.assertTrue(np.isscalar(f(np.zeros(2))))
25+
26+
# Test construction errors
27+
self.assertRaises(
28+
ValueError, pints.toy.SimpleEggBoxLogPDF, sigma=0)
29+
self.assertRaises(
30+
ValueError, pints.toy.SimpleEggBoxLogPDF, r=0)
31+
32+
def test_sampling_and_divergence(self):
33+
"""
34+
Tests :meth:`SimpleEggBoxLogPDF.kl_score()`.
35+
"""
36+
# Ensure consistent output
37+
np.random.seed(1)
38+
39+
# Create some log pdfs
40+
log_pdf1 = pints.toy.SimpleEggBoxLogPDF(2, 4)
41+
log_pdf2 = pints.toy.SimpleEggBoxLogPDF(3, 6)
42+
43+
# Generate samples from each
44+
n = 100
45+
samples1 = log_pdf1.sample(n)
46+
samples2 = log_pdf2.sample(n)
47+
48+
# Test divergence scores
49+
s11 = log_pdf1.kl_score(samples1)
50+
s12 = log_pdf1.kl_score(samples2)
51+
self.assertLess(s11, s12)
52+
s21 = log_pdf2.kl_score(samples1)
53+
s22 = log_pdf2.kl_score(samples2)
54+
self.assertLess(s22, s21)
55+
56+
# Test penalising if a mode is missing
57+
samples3 = np.vstack((
58+
samples2[samples2[:, 0] > 0], # Top half
59+
samples2[samples2[:, 1] < 0], # Left half
60+
))
61+
s23 = log_pdf2.kl_score(samples3)
62+
self.assertLess(s22, s23)
63+
self.assertGreater(s23 / s22, 100)
64+
65+
# Test sample arguments
66+
self.assertRaises(ValueError, log_pdf1.sample, -1)
67+
68+
# Test shape testing
69+
self.assertEqual(samples1.shape, (n, 2))
70+
x = np.ones((n, 3))
71+
self.assertRaises(ValueError, log_pdf1.kl_score, x)
72+
x = np.ones((n, 2, 2))
73+
self.assertRaises(ValueError, log_pdf1.kl_score, x)
74+
75+
76+
if __name__ == '__main__':
77+
print('Add -v for more debug output')
78+
import sys
79+
if '-v' in sys.argv:
80+
debug = True
81+
unittest.main()
82+

pints/toy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ._parabola import ParabolicError # noqa
2424
from ._repressilator_model import RepressilatorModel # noqa
2525
from ._rosenbrock import RosenbrockError, RosenbrockLogPDF # noqa
26+
from ._simple_egg_box import SimpleEggBoxLogPDF # noqa
2627
from ._sir_model import SIRModel # noqa
2728
from ._twisted_gaussian_banana import TwistedGaussianLogPDF # noqa
2829

pints/toy/_simple_egg_box.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#
2+
# Simple egg-box LogPDF
3+
#
4+
# This file is part of PINTS.
5+
# Copyright (c) 2017, University of Oxford.
6+
# For licensing information, see the LICENSE file distributed with the PINTS
7+
# software package.
8+
#
9+
from __future__ import absolute_import, division
10+
from __future__ import print_function, unicode_literals
11+
import pints
12+
import numpy as np
13+
import scipy.stats
14+
15+
16+
class SimpleEggBoxLogPDF(pints.LogPDF):
17+
"""
18+
Two-dimensional multimodal Normal distribution, with four more-or-less
19+
independent modes, each centered in a different quadrant.
20+
21+
Arguments:
22+
23+
``sigma``
24+
The variance of each mode.
25+
``r``
26+
The first mode will be located at ``(d, d)``, ``(-d, d)``, (-d, -d)``,
27+
and ``(d, -d)``, where ``d = r * sigma``.
28+
29+
*Extends:* :class:`pints.LogPDF`.
30+
"""
31+
def __init__(self, sigma=2, r=4):
32+
33+
# Sigma for every mode
34+
self._sigma = float(sigma)
35+
if self._sigma <= 0:
36+
raise ValueError('Sigma must be greater than zero.')
37+
38+
# Set modes
39+
r = float(r)
40+
if r <= 0:
41+
raise ValueError('Argument r must be greater than zero.')
42+
d = r * self._sigma
43+
self._modes = [
44+
[d, d],
45+
[-d, d],
46+
[-d, -d],
47+
[d, -d],
48+
]
49+
50+
# Set covariances
51+
self._covs = [np.eye(2) * sigma] * 4
52+
53+
# Create scipy 'random variables'
54+
self._vars = [
55+
scipy.stats.multivariate_normal(mode, self._covs[i])
56+
for i, mode in enumerate(self._modes)]
57+
58+
def __call__(self, x):
59+
f = np.sum([var.pdf(x) for var in self._vars])
60+
return -float('inf') if f == 0 else np.log(f)
61+
62+
def n_parameters(self):
63+
""" See :meth:`pints.LogPDF.n_parameters()`. """
64+
return 2
65+
66+
def kl_score(self, samples):
67+
"""
68+
Calculates a heuristic score for how well a given set of samples
69+
matches this LogPDF's underlying distribution, based on
70+
Kullback-Leibler divergence of the individual modes. This only works
71+
well if the modes are nicely separated, i.e. for larger values of
72+
``r``.
73+
"""
74+
dimension = 2
75+
76+
# Check size of input
77+
if not len(samples.shape) == 2:
78+
raise ValueError('Given samples list must be 2x2.')
79+
if samples.shape[1] != dimension:
80+
raise ValueError(
81+
'Given samples must have length ' + str(dimension))
82+
83+
# Separate samples into quadrants
84+
q12 = samples[samples[:, 1] >= 0]
85+
q34 = samples[samples[:, 1] < 0]
86+
q1 = q12[q12[:, 0] >= 0]
87+
q2 = q12[q12[:, 0] < 0]
88+
q3 = q34[q34[:, 0] < 0]
89+
q4 = q34[q34[:, 0] >= 0]
90+
qs = [q1, q2, q3, q4]
91+
92+
# Calculate kullback-leibler for each quadrant-mode pair
93+
dkls = np.array([0, 0, 0, 0], dtype=float)
94+
for i, q in enumerate(qs):
95+
if len(q) == 0:
96+
continue
97+
m0 = np.mean(q, axis=0)
98+
s0 = np.cov(q.T)
99+
m1 = self._modes[i]
100+
s1 = self._covs[i]
101+
cov_inv = np.linalg.inv(s1)
102+
dkl1 = np.trace(cov_inv.dot(s0))
103+
dkl2 = np.dot((m1 - m0).T, cov_inv).dot(m1 - m0)
104+
dkl3 = np.log(np.linalg.det(s1) / np.linalg.det(s0))
105+
dkls[i] = 0.5 * (dkl1 + dkl2 + dkl3 - dimension)
106+
107+
# No samples in a given quadrant? Then use 100 times max divergence
108+
penalty1 = 100 * np.max(dkls)
109+
dkls[dkls == 0] = penalty1
110+
111+
# Sum divergences together
112+
score = np.sum(dkls)
113+
114+
# Penalise unequal distribution of the points, and return
115+
ns = [len(q) for q in qs]
116+
penalty2 = np.max(ns) / max(1, np.min(ns))
117+
return score * penalty2
118+
119+
def sample(self, n):
120+
"""
121+
Returns ``n`` samples from the underlying distribution.
122+
"""
123+
if n < 0:
124+
raise ValueError('Number of samples cannot be negative.')
125+
126+
# Calculate number of samples from each distribution
127+
weights = [0.25] * 4
128+
ns = np.sum(scipy.stats.multinomial.rvs(1, weights, n), axis=0)
129+
130+
# Draw samples from each distribution, then join them together
131+
x = [v.rvs(ns[i]) for i, v in enumerate(self._vars)]
132+
x = np.vstack(x)
133+
134+
# Shuffle the samples and return
135+
np.random.shuffle(x)
136+
return x
137+

0 commit comments

Comments
 (0)