Skip to content

Commit 27464e1

Browse files
committed
Corrected tests
- constant model test was wrong previously - corrected incorrect derivative-based tests for the constant model
1 parent 9dc3ebc commit 27464e1

2 files changed

Lines changed: 29 additions & 29 deletions

File tree

pints/tests/test_error_measures.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def test_mean_squared_error_multi(self):
160160
self.assertTrue(np.all(y[0, :] == [1, 4]))
161161
self.assertTrue(np.all(y[1, :] == [1, 4]))
162162
self.assertTrue(np.all(y[2, :] == [1, 4]))
163-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
164-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
165-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
163+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
164+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
165+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
166166

167167
# Check residuals
168168
rx = y - np.array(values)
@@ -185,9 +185,9 @@ def test_mean_squared_error_multi(self):
185185
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
186186
# Derivatives are: [[1, 0], [0, 2]]
187187
# dex1 is: (2 / nt / no) * (0 - 1 - 2) * 1 = (1 / 3) * -3 * 1 = -1
188-
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 2 = (1 / 3) * -9 * 2 = -6
188+
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 1 = (1 / 3) * -9 * 1 = -3
189189
self.assertEqual(dex[0], -1)
190-
self.assertEqual(dex[1], -6)
190+
self.assertEqual(dex[1], -3)
191191

192192
def test_mean_squared_error_weighted(self):
193193
""" Tests :class:`pints.MeanSquaredError` with weighted outputs. """
@@ -224,9 +224,9 @@ def test_mean_squared_error_weighted(self):
224224
self.assertTrue(np.all(y[0, :] == [1, 4]))
225225
self.assertTrue(np.all(y[1, :] == [1, 4]))
226226
self.assertTrue(np.all(y[2, :] == [1, 4]))
227-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
228-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
229-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
227+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
228+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
229+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
230230

231231
# Check residuals
232232
rx = y - np.array(values)
@@ -251,11 +251,11 @@ def test_mean_squared_error_weighted(self):
251251
# dex1 is: (2 / nt / no) * (0 - 1 - 2) * 1 * 1
252252
# = (1 / 3) * -3 * 1 * 1
253253
# = -1
254-
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 2 * 2
255-
# = (1 / 3) * -9 * 2 * 2
256-
# = -12
254+
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 1 * 2
255+
# = (1 / 3) * -9 * 1 * 2
256+
# = -6
257257
self.assertEqual(dex[0], -1)
258-
self.assertEqual(dex[1], -12)
258+
self.assertEqual(dex[1], -6)
259259

260260
def test_probability_based_error(self):
261261
""" Tests :class:`pints.ProbabilityBasedError`. """
@@ -350,9 +350,9 @@ def test_sum_of_squares_error_multi(self):
350350
self.assertTrue(np.all(y[0, :] == [1, 4]))
351351
self.assertTrue(np.all(y[1, :] == [1, 4]))
352352
self.assertTrue(np.all(y[2, :] == [1, 4]))
353-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
354-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
355-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
353+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
354+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
355+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
356356

357357
# Check residuals
358358
rx = y - np.array(values)
@@ -375,9 +375,9 @@ def test_sum_of_squares_error_multi(self):
375375
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
376376
# Derivatives are: [[1, 0], [0, 2]]
377377
# dex1 is: 2 * (0 - 1 - 2) * 1 = 2 * -3 * 1 = -6
378-
# dex2 is: 2 * (0 - 3 - 6) * 2 = 2 * -9 * 2 = -36
378+
# dex2 is: 2 * (0 - 3 - 6) * 2 = 2 * -9 * 1 = -18
379379
self.assertEqual(dex[0], -6)
380-
self.assertEqual(dex[1], -36)
380+
self.assertEqual(dex[1], -18)
381381

382382
def test_sum_of_squares_error_weighted(self):
383383
""" Tests :class:`pints.MeanSquaredError` with weighted outputs. """
@@ -408,15 +408,15 @@ def test_sum_of_squares_error_weighted(self):
408408
x = [1, 2]
409409

410410
# Model outputs are 3 times [1, 4]
411-
# Model derivatives are 3 times [[1, 0], [0, 2]]
411+
# Model derivatives are 3 times [[1, 0], [0, 1]]
412412
y, dy = p.evaluateS1(x)
413413
self.assertTrue(np.all(y == p.evaluate(x)))
414414
self.assertTrue(np.all(y[0, :] == [1, 4]))
415415
self.assertTrue(np.all(y[1, :] == [1, 4]))
416416
self.assertTrue(np.all(y[2, :] == [1, 4]))
417-
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 2]]))
418-
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 2]]))
419-
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 2]]))
417+
self.assertTrue(np.all(dy[0, :] == [[1, 0], [0, 1]]))
418+
self.assertTrue(np.all(dy[1, :] == [[1, 0], [0, 1]]))
419+
self.assertTrue(np.all(dy[2, :] == [[1, 0], [0, 1]]))
420420

421421
# Check residuals
422422
rx = y - np.array(values)
@@ -437,15 +437,15 @@ def test_sum_of_squares_error_weighted(self):
437437
self.assertEqual(dex.shape, (2, ))
438438

439439
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
440-
# Derivatives are: [[1, 0], [0, 2]]
440+
# Derivatives are: [[1, 0], [0, 1]]
441441
# dex1 is: 2 * (0 - 1 - 2) * 1 * 1
442442
# = 2 * -3 * 1 * 1
443443
# = -6
444-
# dex2 is: 2 * (0 - 3 - 6) * 2 * 2
445-
# = 2 * -9 * 2 * 2
446-
# = -72
444+
# dex2 is: 2 * (0 - 3 - 6) * 2 * 1
445+
# = 2 * -9 * 2 * 1
446+
# = -36
447447
self.assertEqual(dex[0], -6)
448-
self.assertEqual(dex[1], -72)
448+
self.assertEqual(dex[1], -36)
449449

450450
def test_sum_of_errors(self):
451451
""" Tests :class:`pints.SumOfErrors`. """

pints/tests/test_toy_constant_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,11 @@ def test_derivatives(self):
174174
self.assertEqual(dy.shape, (3, 2, 2))
175175
dmx = np.array(
176176
[[[1, 0], # dx1/dp1 = 1, dx1/dp2 = 0
177-
[0, 2]], # dx2/dp2 = 0, dx2/dp2 = 2
177+
[0, 1]], # dx2/dp2 = 0, dx2/dp2 = 1
178178
[[1, 0],
179-
[0, 2]],
179+
[0, 1]],
180180
[[1, 0],
181-
[0, 2]]]
181+
[0, 1]]]
182182
)
183183
self.assertTrue(np.all(dy == dmx))
184184

0 commit comments

Comments
 (0)