Skip to content

Commit 96fe8fd

Browse files
committed
Added Problem.model() methods
1 parent 0e27f1e commit 96fe8fd

4 files changed

Lines changed: 12 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ All notable changes to this project will be documented in this file.
66
## Unreleased
77

88
### Added
9-
- [#1715](https://github.com/pints-team/pints/pull/1715) Added `ProblemErrorMeasure.problem()`, `ProblemLogLikelihood.problem()`
9+
- [#1715](https://github.com/pints-team/pints/pull/1715) Added methods `ProblemErrorMeasure.problem()`, `ProblemLogLikelihood.problem()`, `SingleOutputProblem.model()` and `MultiOutputProblem.model()`.
1010
### Changed
1111
- [#1713](https://github.com/pints-team/pints/pull/1713) PINTS now requires matplotlib 2.2 or newer.
1212
### Deprecated

pints/_core.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,12 @@ def evaluateS1(self, parameters):
170170
np.asarray(dy).reshape((self._n_times, self._n_parameters))
171171
)
172172

173+
def model(self):
174+
""" Returns the :class:`ForwardModel` underlying this problem. """
175+
return self._model
176+
173177
def n_outputs(self):
174-
"""
175-
Returns the number of outputs for this problem (always 1).
176-
"""
178+
""" Returns the number of outputs for this problem (always 1). """
177179
return 1
178180

179181
def n_parameters(self):
@@ -281,6 +283,10 @@ def evaluateS1(self, parameters):
281283
self._n_times, self._n_outputs, self._n_parameters)
282284
)
283285

286+
def model(self):
287+
""" Returns the :class:`ForwardModel` underlying this problem. """
288+
return self._model
289+
284290
def n_outputs(self):
285291
"""
286292
Returns the number of outputs for this problem.

pints/tests/test_multi_output_problem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def test_basics(self):
3535
self.assertEqual(problem.n_parameters(), model.n_parameters(), 2)
3636
self.assertEqual(problem.n_outputs(), model.n_outputs(), 3)
3737
self.assertEqual(problem.n_times(), len(times))
38+
self.assertIs(problem.model(), model)
3839

3940
# Test errors
4041
times[0] = -2

pints/tests/test_single_output_problem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_basics(self):
3232
self.assertEqual(problem.n_parameters(), model.n_parameters(), 2)
3333
self.assertEqual(problem.n_outputs(), model.n_outputs(), 1)
3434
self.assertEqual(problem.n_times(), len(times))
35+
self.assertIs(problem.model(), model)
3536

3637
# Test errors
3738
times[0] = -2

0 commit comments

Comments
 (0)