Skip to content

Commit 5bc216f

Browse files
committed
Fix simulated learner state restoration
1 parent 3c7a3b7 commit 5bc216f

5 files changed

Lines changed: 88 additions & 6 deletions

File tree

adaptive/learner/balancing_learner.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,13 @@ def ask(
261261
return [], []
262262

263263
if not tell_pending:
264-
with restore(*self.learners):
265-
return self._ask_and_tell(n)
264+
try:
265+
with restore(*self.learners):
266+
return self._ask_and_tell(n)
267+
finally:
268+
self._ask_cache.clear()
269+
self._loss.clear()
270+
self._pending_loss.clear()
266271
else:
267272
return self._ask_and_tell(n)
268273

adaptive/learner/learnerND.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ def _update_range(self, new_output):
862862
# this is the first point, nothing to do, just set the range
863863
self._min_value = np.min(new_output)
864864
self._max_value = np.max(new_output)
865-
self._old_scale = self._scale or 1
865+
self._old_scale = self._scale
866866
return False
867867

868868
# if range in one or more directions is doubled, then update all losses
@@ -885,7 +885,10 @@ def _update_range(self, new_output):
885885

886886
self._output_multiplier = scale_multiplier
887887

888-
scale_factor = self._scale / self._old_scale
888+
if self._old_scale == 0:
889+
scale_factor = math.inf if self._scale > 0 else 1
890+
else:
891+
scale_factor = self._scale / self._old_scale
889892
if scale_factor > self._recompute_losses_factor:
890893
self._old_scale = self._scale
891894
self._recompute_all_losses()

adaptive/tests/test_balancing_learner.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
from __future__ import annotations
22

3+
import functools as ft
4+
import math
5+
import random
6+
7+
import numpy as np
38
import pytest
49

5-
from adaptive.learner import BalancingLearner, Learner1D
10+
from adaptive.learner import BalancingLearner, Learner1D, Learner2D
611
from adaptive.runner import simple
712

813
strategies = ["loss", "loss_improvements", "npoints", "cycle"]
914

1015

16+
def ring_of_fire(xy, d):
17+
a = 0.2
18+
x, y = xy
19+
return x + math.exp(-((x**2 + y**2 - d**2) ** 2) / a**4)
20+
21+
1122
def test_balancing_learner_loss_cache():
1223
learner = Learner1D(lambda x: x, bounds=(-1, 1))
1324
learner.tell(-1, -1)
@@ -64,3 +75,34 @@ def test_strategies(strategy, goal_type, goal):
6475
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
6576
learner = BalancingLearner(learners, strategy=strategy)
6677
simple(learner, **{goal_type: goal})
78+
79+
80+
def test_loss_improvements_strategy_with_tell_pending_false_reserves_child_points():
81+
random.seed(3104322362)
82+
np.random.seed(3104322362 % 2**32)
83+
84+
learners = [
85+
Learner2D(
86+
ft.partial(ring_of_fire, d=random.uniform(0.2, 1)),
87+
bounds=((-1, 1), (-1, 1)),
88+
)
89+
for _ in range(4)
90+
]
91+
learner = BalancingLearner(learners, strategy="loss_improvements")
92+
93+
stash = []
94+
for n, m in [(1, 1), (4, 4), (2, 0), (4, 4), (8, 6)]:
95+
xs, _ = learner.ask(n, tell_pending=False)
96+
random.shuffle(xs)
97+
for _ in range(m):
98+
stash.append(xs.pop())
99+
100+
for x in xs:
101+
learner.tell(x, learner.function(x))
102+
103+
random.shuffle(stash)
104+
for _ in range(m):
105+
x = stash.pop()
106+
learner.tell(x, learner.function(x))
107+
108+
assert all(not child.pending_points for child in learners)

adaptive/tests/unit/test_learnernd.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,37 @@ def f_vec(x):
153153
assert np.isclose(result[1], 0.0)
154154

155155

156+
def test_learnerND_recomputes_losses_for_small_scale_updates():
157+
learner = make_1d_learner()
158+
learner._recompute_losses_factor = 1
159+
160+
for point, value in [((-1,), 0.0), ((0.0,), 0.45), ((1.0,), 0.45)]:
161+
learner.tell(point, value)
162+
163+
simplex = next(
164+
simplex
165+
for simplex in learner.tri.simplices
166+
if {tuple(vertex) for vertex in learner.tri.get_vertices(simplex)}
167+
== {(-1.0,), (0.0,)}
168+
)
169+
cached_before = learner._losses[simplex]
170+
assert np.isclose(cached_before, learner._compute_loss(simplex))
171+
172+
learner.tell((0.5,), 0.67)
173+
174+
simplex = next(
175+
simplex
176+
for simplex in learner.tri.simplices
177+
if {tuple(vertex) for vertex in learner.tri.get_vertices(simplex)}
178+
== {(-1.0,), (0.0,)}
179+
)
180+
cached_after = learner._losses[simplex]
181+
182+
assert learner._old_scale == pytest.approx(0.67)
183+
assert np.isclose(cached_after, learner._compute_loss(simplex))
184+
assert not np.isclose(cached_after, cached_before)
185+
186+
156187
def test_learnerND_1d_plot_requires_holoviews(monkeypatch):
157188
"""Test that plotting fails with a clear error without holoviews."""
158189

adaptive/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import concurrent.futures as concurrent
4+
import copy
45
import functools
56
import gzip
67
import inspect
@@ -27,7 +28,7 @@ def named_product(**items: Sequence[Any]):
2728

2829
@contextmanager
2930
def restore(*learners) -> Iterator[None]:
30-
states = [learner.__getstate__() for learner in learners]
31+
states = [copy.deepcopy(learner.__getstate__()) for learner in learners]
3132
try:
3233
yield
3334
finally:

0 commit comments

Comments (0)