Skip to content

Commit 9c73c55

Browse files
authored
Merge pull request #728 from pints-team/i719-log-likelihood-tests
Value-based tests for log-likelihoods
2 parents c1218dd + 0206af7 commit 9c73c55

5 files changed

Lines changed: 217 additions & 60 deletions

File tree

pints/_log_likelihoods.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,19 @@ class AR1LogLikelihood(pints.ProblemLogLikelihood):
1818
Calculates a log-likelihood assuming AR1 noise model
1919
2020
.. math::
21-
\log{L(\\theta, \sigma|\\boldsymbol{x})} =
21+
\log{L(\\theta, \sigma'|\\boldsymbol{x})} =
2222
-\\frac{N}{2}\log{2\pi}
23-
-N\log{\sigma}
24-
-\\frac{1}{2\sigma^2}
25-
\sum_{i=1}^N{(\\epsilon_i x_i - \\rho \\epsilon_{i-1} )^2}
23+
-N\log{\sigma'}
24+
-\\frac{1}{2\sigma'^2}
25+
\sum_{i=2}^N{(\\epsilon_i x_i - \\rho \\epsilon_{i-1} )^2}
2626
2727
where
2828
2929
.. math::
3030
\\epsilon_i = x_i - f_i(\\theta)
3131
32+
and :math:`sigma' = \\frac{sigma} \\sqrt{1-\\rho^2}`.
33+
3234
3335
Arguments:
3436
@@ -54,8 +56,11 @@ def __init__(self, problem):
5456
self._logn = 0.5 * (self._nt) * np.log(2 * np.pi)
5557

5658
def __call__(self, x):
57-
rho = np.asarray(x[-2 * self._no:-self._no])
58-
sigma = np.asarray(x[-self._no:]) * np.sqrt(1 - rho**2)
59+
m = 2 * self._no
60+
parameters = x[-m:]
61+
rho = np.asarray(parameters[0::2])
62+
sigma = np.asarray(parameters[1::2])
63+
sigma = np.asarray(sigma) * np.sqrt(1 - rho**2)
5964
error = self._values - self._problem.evaluate(x[:-2 * self._no])
6065
autocorr_error = error[1:] - rho * error[:-1]
6166
return np.sum(- self._logn - self._nt * np.log(sigma)
@@ -64,21 +69,29 @@ def __call__(self, x):
6469

6570
class ARMA11LogLikelihood(pints.ProblemLogLikelihood):
6671
"""
67-
Calculates a log-likelihood assuming AR1 noise model
72+
Calculates a log-likelihood assuming ARMA(1,1) noise model.
6873
6974
.. math::
7075
\log{L(\\theta, \sigma|\\boldsymbol{x})} =
7176
-\\frac{N}{2}\log{2\pi}
7277
-N\log{\sigma}
7378
-\\frac{1}{2\sigma^2}
74-
\sum_{i=1}^N{(\\epsilon_i x_i - \\rho \\epsilon_{i-1} -
75-
\\phi \\nu(t-1))^2}
79+
\sum_{i=3}^N{(\\nu_i - \\phi \\nu_{i-1})^2}
7680
7781
where
7882
7983
.. math::
84+
\\nu_i = \\epsilon_i - \\rho \\epsilon_{i-1}
85+
86+
and
87+
88+
..math::
8089
\\epsilon_i = x_i - f_i(\\theta)
8190
91+
and
92+
93+
.. math::
94+
\\sigma = \\sigma\\sqrt{\\frac{1-\\rho^2}{1 + 2\\phi\\rho + \\phi^2}}`
8295
8396
Arguments:
8497
@@ -104,13 +117,16 @@ def __init__(self, problem):
104117
self._logn = 0.5 * (self._nt) * np.log(2 * np.pi)
105118

106119
def __call__(self, x):
107-
rho = np.asarray(x[-3 * self._no:-2 * self._no])
108-
phi = np.asarray(x[-2 * self._no:-self._no])
120+
m = 3 * self._no
121+
parameters = x[-m:]
122+
rho = np.asarray(parameters[0::3])
123+
phi = np.asarray(parameters[1::3])
124+
sigma = np.asarray(parameters[2::3])
109125
sigma = (
110-
np.asarray(x[-self._no:]) *
126+
sigma *
111127
np.sqrt((1.0 - rho**2) / (1.0 + 2.0 * phi * rho + phi**2))
112128
)
113-
error = self._values - self._problem.evaluate(x[:-3 * self._no])
129+
error = self._values - self._problem.evaluate(x[:-m])
114130
v = error[1:] - rho * error[:-1]
115131
autocorr_error = v[1:] - phi * v[:-1]
116132
return np.sum(- self._logn - self._nt * np.log(sigma)
@@ -320,29 +336,27 @@ def __call__(self, x):
320336

321337
def evaluateS1(self, x):
322338
""" See :meth:`LogPDF.evaluateS1()`. """
323-
sigma = float(np.asarray(x[-self._no:]))
339+
sigma = np.asarray(x[-self._no:])
324340

325341
# Evaluate, and get residuals
326342
y, dy = self._problem.evaluateS1(x[:-self._no])
327343

328344
# Reshape dy, in case we're working with a single-output problem
329-
dy = dy.reshape(self._nt, self._no, self._n_parameters - 1)
345+
dy = dy.reshape(self._nt, self._no, self._n_parameters - self._no)
330346

331347
# Note: Must be (data - simulation), sign now matters!
332348
r = self._values - y
333349

334350
# Calculate log-likelihood
335-
L = np.sum(-self._logn - self._nt * np.log(sigma)
336-
- (1.0 / (2 * sigma**2)) * np.sum(r**2, axis=0))
351+
L = self.__call__(x)
337352

338353
# Calculate derivatives in the model parameters
339354
dL = np.sum(
340355
(sigma**(-2.0) * np.sum((r.T * dy.T).T, axis=0).T).T, axis=0)
341356

342357
# Calculate derivative wrt sigma
343-
dsigma = np.sum(-self._nt / sigma +
344-
sigma**(-3.0) * np.sum(r**2, axis=0))
345-
dL = np.concatenate((dL, np.array([dsigma])))
358+
dsigma = -self._nt / sigma + sigma**(-3.0) * np.sum(r**2, axis=0)
359+
dL = np.concatenate((dL, np.array(list(dsigma))))
346360

347361
# Return
348362
return L, dL
@@ -490,4 +504,3 @@ def __init__(self, problem):
490504
'The class `pints.KnownNoiseLogLikelihood` is deprecated.'
491505
' Please use `pints.GaussianLogLikelihood` instead.')
492506
super(UnknownNoiseLogLikelihood, self).__init__(problem)
493-

pints/tests/test_error_measures.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,15 @@ def test_mean_squared_error_multi(self):
154154
x = [1, 2]
155155

156156
# Model outputs are 3 times [1, 4]
157-
# Model derivatives are 3 times [[1, 0], [0, 2]]
157+
# Model derivatives are 3 times [[1, 0], [0, 1]]
158158
y, dy = p.evaluateS1(x)
159159
self.assertTrue(np.all(y == p.evaluate(x)))
160160
self.assertTrue(np.all(y[0, :] == [1, 4]))
161161
self.assertTrue(np.all(y[1, :] == [1, 4]))
162162
self.assertTrue(np.all(y[2, :] == [1, 4]))
163-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
164-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
165-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
163+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
164+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
165+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
166166

167167
# Check residuals
168168
rx = y - np.array(values)
@@ -183,11 +183,11 @@ def test_mean_squared_error_multi(self):
183183
self.assertEqual(dex.shape, (2, ))
184184

185185
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
186-
# Derivatives are: [[1, 0], [0, 2]]
186+
# Derivatives are: [[1, 0], [0, 1]]
187187
# dex1 is: (2 / nt / no) * (0 - 1 - 2) * 1 = (1 / 3) * -3 * 1 = -1
188-
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 2 = (1 / 3) * -9 * 2 = -6
188+
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 1 = (1 / 3) * -9 * 1 = -3
189189
self.assertEqual(dex[0], -1)
190-
self.assertEqual(dex[1], -6)
190+
self.assertEqual(dex[1], -3)
191191

192192
def test_mean_squared_error_weighted(self):
193193
""" Tests :class:`pints.MeanSquaredError` with weighted outputs. """
@@ -218,15 +218,15 @@ def test_mean_squared_error_weighted(self):
218218
x = [1, 2]
219219

220220
# Model outputs are 3 times [1, 4]
221-
# Model derivatives are 3 times [[1, 0], [0, 2]]
221+
# Model derivatives are 3 times [[1, 0], [0, 1]]
222222
y, dy = p.evaluateS1(x)
223223
self.assertTrue(np.all(y == p.evaluate(x)))
224224
self.assertTrue(np.all(y[0, :] == [1, 4]))
225225
self.assertTrue(np.all(y[1, :] == [1, 4]))
226226
self.assertTrue(np.all(y[2, :] == [1, 4]))
227-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
228-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
229-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
227+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
228+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
229+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
230230

231231
# Check residuals
232232
rx = y - np.array(values)
@@ -251,11 +251,11 @@ def test_mean_squared_error_weighted(self):
251251
# dex1 is: (2 / nt / no) * (0 - 1 - 2) * 1 * 1
252252
# = (1 / 3) * -3 * 1 * 1
253253
# = -1
254-
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 2 * 2
255-
# = (1 / 3) * -9 * 2 * 2
256-
# = -12
254+
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 1 * 2
255+
# = (1 / 3) * -9 * 1 * 2
256+
# = -6
257257
self.assertEqual(dex[0], -1)
258-
self.assertEqual(dex[1], -12)
258+
self.assertEqual(dex[1], -6)
259259

260260
def test_probability_based_error(self):
261261
""" Tests :class:`pints.ProbabilityBasedError`. """
@@ -344,15 +344,15 @@ def test_sum_of_squares_error_multi(self):
344344
x = [1, 2]
345345

346346
# Model outputs are 3 times [1,4]
347-
# Model derivatives are 3 times [[1, 0], [0, 2]]
347+
# Model derivatives are 3 times [[1, 0], [0, 1]]
348348
y, dy = p.evaluateS1(x)
349349
self.assertTrue(np.all(y == p.evaluate(x)))
350350
self.assertTrue(np.all(y[0, :] == [1, 4]))
351351
self.assertTrue(np.all(y[1, :] == [1, 4]))
352352
self.assertTrue(np.all(y[2, :] == [1, 4]))
353-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
354-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
355-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
353+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
354+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
355+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
356356

357357
# Check residuals
358358
rx = y - np.array(values)
@@ -373,11 +373,11 @@ def test_sum_of_squares_error_multi(self):
373373
self.assertEqual(dex.shape, (2, ))
374374

375375
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
376-
# Derivatives are: [[1, 0], [0, 2]]
376+
# Derivatives are: [[1, 0], [0, 1]]
377377
# dex1 is: 2 * (0 - 1 - 2) * 1 = 2 * -3 * 1 = -6
378-
# dex2 is: 2 * (0 - 3 - 6) * 2 = 2 * -9 * 2 = -36
378+
# dex2 is: 2 * (0 - 3 - 6) * 2 = 2 * -9 * 1 = -18
379379
self.assertEqual(dex[0], -6)
380-
self.assertEqual(dex[1], -36)
380+
self.assertEqual(dex[1], -18)
381381

382382
def test_sum_of_squares_error_weighted(self):
383383
""" Tests :class:`pints.MeanSquaredError` with weighted outputs. """
@@ -408,15 +408,15 @@ def test_sum_of_squares_error_weighted(self):
408408
x = [1, 2]
409409

410410
# Model outputs are 3 times [1, 4]
411-
# Model derivatives are 3 times [[1, 0], [0, 2]]
411+
# Model derivatives are 3 times [[1, 0], [0, 1]]
412412
y, dy = p.evaluateS1(x)
413413
self.assertTrue(np.all(y == p.evaluate(x)))
414414
self.assertTrue(np.all(y[0, :] == [1, 4]))
415415
self.assertTrue(np.all(y[1, :] == [1, 4]))
416416
self.assertTrue(np.all(y[2, :] == [1, 4]))
417-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
418-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
419-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
417+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
418+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
419+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
420420

421421
# Check residuals
422422
rx = y - np.array(values)
@@ -437,15 +437,15 @@ def test_sum_of_squares_error_weighted(self):
437437
self.assertEqual(dex.shape, (2, ))
438438

439439
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
440-
# Derivatives are: [[1, 0], [0, 2]]
440+
# Derivatives are: [[1, 0], [0, 1]]
441441
# dex1 is: 2 * (0 - 1 - 2) * 1 * 1
442442
# = 2 * -3 * 1 * 1
443443
# = -6
444-
# dex2 is: 2 * (0 - 3 - 6) * 2 * 2
445-
# = 2 * -9 * 2 * 2
446-
# = -72
444+
# dex2 is: 2 * (0 - 3 - 6) * 2 * 1
445+
# = 2 * -9 * 2 * 1
446+
# = -36
447447
self.assertEqual(dex[0], -6)
448-
self.assertEqual(dex[1], -72)
448+
self.assertEqual(dex[1], -36)
449449

450450
def test_sum_of_errors(self):
451451
""" Tests :class:`pints.SumOfErrors`. """

0 commit comments

Comments
 (0)