Skip to content

Commit 0206af7

Browse files
committed
Tweaks to docstring and derivative function
- tweak to docstring for constant model - tweak to derivative function to simplify for constant model - tweak to test comments for error measures to reflect changes to the derivative function for the constant model
1 parent 27464e1 commit 0206af7

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

pints/tests/test_error_measures.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_mean_squared_error_multi(self):
154154
x = [1, 2]
155155

156156
# Model outputs are 3 times [1, 4]
157-
# Model derivatives are 3 times [[1, 0], [0, 2]]
157+
# Model derivatives are 3 times [[1, 0], [0, 1]]
158158
y, dy = p.evaluateS1(x)
159159
self.assertTrue(np.all(y == p.evaluate(x)))
160160
self.assertTrue(np.all(y[0, :] == [1, 4]))
@@ -183,7 +183,7 @@ def test_mean_squared_error_multi(self):
183183
self.assertEqual(dex.shape, (2, ))
184184

185185
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
186-
# Derivatives are: [[1, 0], [0, 2]]
186+
# Derivatives are: [[1, 0], [0, 1]]
187187
# dex1 is: (2 / nt / no) * (0 - 1 - 2) * 1 = (1 / 3) * -3 * 1 = -1
188188
# dex2 is: (2 / nt / no) * (0 - 3 - 6) * 1 = (1 / 3) * -9 * 1 = -3
189189
self.assertEqual(dex[0], -1)
@@ -218,7 +218,7 @@ def test_mean_squared_error_weighted(self):
218218
x = [1, 2]
219219

220220
# Model outputs are 3 times [1, 4]
221-
# Model derivatives are 3 times [[1, 0], [0, 2]]
221+
# Model derivatives are 3 times [[1, 0], [0, 1]]
222222
y, dy = p.evaluateS1(x)
223223
self.assertTrue(np.all(y == p.evaluate(x)))
224224
self.assertTrue(np.all(y[0, :] == [1, 4]))
@@ -344,7 +344,7 @@ def test_sum_of_squares_error_multi(self):
344344
x = [1, 2]
345345

346346
# Model outputs are 3 times [1,4]
347-
# Model derivatives are 3 times [[1, 0], [0, 2]]
347+
# Model derivatives are 3 times [[1, 0], [0, 1]]
348348
y, dy = p.evaluateS1(x)
349349
self.assertTrue(np.all(y == p.evaluate(x)))
350350
self.assertTrue(np.all(y[0, :] == [1, 4]))
@@ -373,7 +373,7 @@ def test_sum_of_squares_error_multi(self):
373373
self.assertEqual(dex.shape, (2, ))
374374

375375
# Residuals are: [[0, 0], [-1, -3], [-2, -6]]
376-
# Derivatives are: [[1, 0], [0, 2]]
376+
# Derivatives are: [[1, 0], [0, 1]]
377377
# dex1 is: 2 * (0 - 1 - 2) * 1 = 2 * -3 * 1 = -6
378378
# dex2 is: 2 * (0 - 3 - 6) * 2 = 2 * -9 * 1 = -18
379379
self.assertEqual(dex[0], -6)

pints/toy/_constant_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class ConstantModel(pints.ForwardModelS1):
3030
.. math::
3131
3232
\\frac{\partial{f_i(t)}}{dp_j} =
33-
\\begin{cases} i, i = j\\\\0, i \\neq j \end{cases}
33+
\\begin{cases} 1, i = j\\\\0, i \\neq j \end{cases}
3434
3535
Arguments:
3636
@@ -107,5 +107,5 @@ def simulateS1(self, parameters, times):
107107
# i.e.
108108
# [[1, 0],
109109
# [0, 1]]
110-
dy = np.tile(np.diag(np.ones(len(self._r))), (len(times), 1, 1))
110+
dy = np.tile(np.diag(np.ones(self._n)), (len(times), 1, 1))
111111
return (y, dy)

0 commit comments

Comments
 (0)