Skip to content

Commit 4cded72

Browse files
authored
Merge pull request #1715 from pints-team/1706-access-models-etc
Adding methods to allow access to problems, models, etc underlying errors, priors, etc
2 parents e8dd222 + 96fe8fd commit 4cded72

8 files changed

Lines changed: 49 additions & 4 deletions

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file.
66
## Unreleased
77

88
### Added
9+
- [#1715](https://github.com/pints-team/pints/pull/1715) Added methods `ProblemErrorMeasure.problem()`, `ProblemLogLikelihood.problem()`, `SingleOutputProblem.model()` and `MultiOutputProblem.model()`.
910
### Changed
1011
- [#1713](https://github.com/pints-team/pints/pull/1713) PINTS now requires matplotlib 2.2 or newer.
1112
### Deprecated

pints/_core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,12 @@ def evaluateS1(self, parameters):
170170
np.asarray(dy).reshape((self._n_times, self._n_parameters))
171171
)
172172

173+
def model(self):
174+
""" Returns the :class:`ForwardModel` underlying this problem. """
175+
return self._model
176+
173177
def n_outputs(self):
174-
"""
175-
Returns the number of outputs for this problem (always 1).
176-
"""
178+
""" Returns the number of outputs for this problem (always 1). """
177179
return 1
178180

179181
def n_parameters(self):
@@ -281,6 +283,10 @@ def evaluateS1(self, parameters):
281283
self._n_times, self._n_outputs, self._n_parameters)
282284
)
283285

286+
def model(self):
287+
""" Returns the :class:`ForwardModel` underlying this problem. """
288+
return self._model
289+
284290
def n_outputs(self):
285291
"""
286292
Returns the number of outputs for this problem.

pints/_error_measures.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class ProblemErrorMeasure(ErrorMeasure):
4949
:class:`single<pints.SingleOutputProblem>` or
5050
:class:`multi-output<pints.MultiOutputProblem>` problems.
5151
"""
52-
def __init__(self, problem=None):
52+
def __init__(self, problem):
5353
super(ProblemErrorMeasure, self).__init__()
5454
self._problem = problem
5555
self._times = problem.times()
@@ -62,6 +62,10 @@ def n_parameters(self):
6262
""" See :meth:`ErrorMeasure.n_parameters()`. """
6363
return self._n_parameters
6464

65+
def problem(self):
66+
""" Returns the problem this error measure was defined on. """
67+
return self._problem
68+
6569

6670
class MeanSquaredError(ProblemErrorMeasure):
6771
r"""

pints/_log_pdfs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ class ProblemLogLikelihood(LogPDF):
321321
----------
322322
problem
323323
The time-series problem this log-likelihood is defined for.
324+
324325
"""
325326
def __init__(self, problem):
326327
super(ProblemLogLikelihood, self).__init__()
@@ -334,6 +335,9 @@ def n_parameters(self):
334335
""" See :meth:`LogPDF.n_parameters()`. """
335336
return self._n_parameters
336337

338+
def problem(self):
339+
return self._problem
340+
337341

338342
class LogPosterior(LogPDF):
339343
"""

pints/tests/test_error_measures.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,18 @@ def test_bad_constructor(self):
496496
pints.ProbabilityBasedError, MiniProblem())
497497

498498

499+
class TestProblemErrorMeasure(unittest.TestCase):
500+
""" Tests shared methods of the ProblemErrorMeasure abstract class. """
501+
502+
def test_shared(self):
503+
# Test underlying ProblemErrorMeasure method problem()
504+
505+
problem = MiniProblem()
506+
error = pints.MeanSquaredError(problem)
507+
self.assertIs(error.problem(), problem)
508+
self.assertEqual(error.n_parameters(), problem.n_parameters())
509+
510+
499511
class TestRootMeanSquaredError(unittest.TestCase):
500512

501513
@classmethod

pints/tests/test_log_likelihoods.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,6 +2572,22 @@ def test_negative_sd(self):
25722572
self.assertEqual(log_likelihood([1, 1, 0]), -np.inf)
25732573

25742574

2575+
class TestProblemLogLikelihood(unittest.TestCase):
2576+
""" Test shared ProblemLogLikelihood methods. """
2577+
2578+
@classmethod
2579+
def setUpClass(cls):
2580+
cls.model = pints.toy.ConstantModel(1)
2581+
cls.times = np.array([1, 2, 3, 4])
2582+
cls.data = np.asarray([1.9, 2.1, 1.8, 2.2])
2583+
2584+
def test_shared_methods(self):
2585+
problem = pints.SingleOutputProblem(self.model, self.times, self.data)
2586+
log_likelihood = pints.GaussianKnownSigmaLogLikelihood(problem, 1.5)
2587+
self.assertEqual(log_likelihood.n_parameters(), problem.n_parameters())
2588+
self.assertIs(log_likelihood.problem(), problem)
2589+
2590+
25752591
class TestScaledLogLikelihood(unittest.TestCase):
25762592

25772593
@classmethod

pints/tests/test_multi_output_problem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def test_basics(self):
3535
self.assertEqual(problem.n_parameters(), model.n_parameters(), 2)
3636
self.assertEqual(problem.n_outputs(), model.n_outputs(), 3)
3737
self.assertEqual(problem.n_times(), len(times))
38+
self.assertIs(problem.model(), model)
3839

3940
# Test errors
4041
times[0] = -2

pints/tests/test_single_output_problem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_basics(self):
3232
self.assertEqual(problem.n_parameters(), model.n_parameters(), 2)
3333
self.assertEqual(problem.n_outputs(), model.n_outputs(), 1)
3434
self.assertEqual(problem.n_times(), len(times))
35+
self.assertIs(problem.model(), model)
3536

3637
# Test errors
3738
times[0] = -2

0 commit comments

Comments
 (0)