diff --git a/CHANGELOG.md b/CHANGELOG.md index ae3e5321d..0614b78f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file. ### Removed ### Fixed - [#1713](https://github.com/pints-team/pints/pull/1713) Fixed Numpy 2.4.1 compatibility issues. +- [#1690][https://github.com/pints-team/pints/pull/1690) Fixed bug in optimisation controller if population size left at `None`. ## [0.5.1] - 2025-09-26 diff --git a/pints/_optimisers/__init__.py b/pints/_optimisers/__init__.py index 17c99d86a..55bd9a45a 100644 --- a/pints/_optimisers/__init__.py +++ b/pints/_optimisers/__init__.py @@ -622,6 +622,10 @@ def run(self): pop_size = 1 if isinstance(self._optimiser, PopulationBasedOptimiser): pop_size = self._optimiser.population_size() + if pop_size is None: + pop_size = self._optimiser.suggested_population_size( + n_workers if self._parallel else None) + self._optimiser.set_population_size(pop_size) if self._log_to_screen: print('Population size: ' + str(pop_size)) diff --git a/pints/tests/test_opt_controller.py b/pints/tests/test_opt_controller.py index 2313039bb..ca5e10c2e 100755 --- a/pints/tests/test_opt_controller.py +++ b/pints/tests/test_opt_controller.py @@ -364,31 +364,17 @@ def test_stopping_no_criterion(self): opt.set_max_unchanged_iterations(None) self.assertRaises(ValueError, opt.run) - def test_set_population_size(self): - # Tests the set_population_size method for this optimiser. + def test_population_size_not_set(self): + # Population size can be None: then suggested should be used - r = pints.toy.RosenbrockError() - x = np.array([1.01, 1.01]) - opt = pints.OptimisationController(r, x, method=method) - m = opt.optimiser() - n = m.population_size() - m.set_population_size(n + 1) - self.assertEqual(m.population_size(), n + 1) - - # Test invalid size - self.assertRaisesRegex( - ValueError, 'at least 1', m.set_population_size, 0) - - # test hyper parameter interface - self.assertEqual(m.n_hyper_parameters(), 1) - m.set_hyper_parameters([n + 2]) - self.assertEqual(m.population_size(), n + 2) - self.assertRaisesRegex( - ValueError, 'at least 1', m.set_hyper_parameters, [0]) - - # Test changing during run - m.ask() - self.assertRaises(Exception, m.set_population_size, 2) + model = pints.toy.ParabolicError() + opt = pints.OptimisationController(model, [1, 1], method=pints.CMAES) + opt.optimiser().set_population_size(None) + opt.set_log_to_screen(True) + opt.set_max_iterations(3) + with StreamCapture() as c: + opt.run() + self.assertIn('Population size: 6', c.text()) def test_parallel(self): # Test parallelised running. diff --git a/pints/tests/test_opt_population_based.py b/pints/tests/test_opt_population_based.py new file mode 100755 index 000000000..a5c6d0caf --- /dev/null +++ b/pints/tests/test_opt_population_based.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# +# Tests the shared methods of the PopulationBasedOptimiser +# +# This file is part of PINTS (https://github.com/pints-team/pints/) which is +# released under the BSD 3-clause license. See accompanying LICENSE.md for +# copyright notice and full license details. +# +import unittest + +import numpy as np + +import pints +import pints.toy + + +class TestPopulationBasedOptimiser(unittest.TestCase): + """ + Tests the shared methods of the PopulationBasedOptimiser. + """ + def setUp(self): + """ Called before every test """ + np.random.seed(1) + + def test_population_size(self): + + r = pints.toy.RosenbrockError() + x = np.array([1.01, 1.01]) + opt = pints.OptimisationController(r, x, method=pints.XNES) + m = opt.optimiser() + n = m.population_size() + m.set_population_size(n + 1) + self.assertEqual(m.population_size(), n + 1) + + # Test invalid size + self.assertRaisesRegex( + ValueError, 'at least 1', m.set_population_size, 0) + + # test hyper parameter interface + self.assertEqual(m.n_hyper_parameters(), 1) + m.set_hyper_parameters([n + 2]) + self.assertEqual(m.population_size(), n + 2) + self.assertRaisesRegex( + ValueError, 'at least 1', m.set_hyper_parameters, [0]) + + # Test changing during run + m.ask() + self.assertRaises(Exception, m.set_population_size, 2) + + +if __name__ == '__main__': + unittest.main() +