Skip to content

Commit 5814aa0

Browse files
Copilotbasnijholt
andauthored
Fix Learner2D state rollback in BalancingLearner ask(tell_pending=False) (#492)
* Initial plan * Fix Learner2D restore snapshot mutation in balancing asks * Use shallow copies instead of deepcopy in Learner2D.__getstate__ Key-level mutations (stack pops/inserts) are all that restore() needs to roll back; values are never mutated in place. A deepcopy of the full data dict would add O(npoints) allocation to every ask(tell_pending=False) call and to pickling. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Bas Nijholt <bas@nijho.lt>
1 parent 15a81c7 commit 5814aa0

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

adaptive/learner/learner2D.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -897,8 +897,8 @@ def __getstate__(self):
897897
cloudpickle.dumps(self.function),
898898
self.bounds,
899899
self.loss_per_triangle,
900-
self._stack,
901-
self._get_data(),
900+
self._stack.copy(),
901+
self._get_data().copy(),
902902
)
903903

904904
def __setstate__(self, state):

adaptive/tests/test_balancing_learner.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from adaptive.learner import BalancingLearner, Learner1D
5+
from adaptive.learner import BalancingLearner, Learner1D, Learner2D
66
from adaptive.runner import simple
77

88
strategies = ["loss", "loss_improvements", "npoints", "cycle"]
@@ -51,6 +51,18 @@ def test_ask_0(strategy):
5151
assert len(points) == 0
5252

5353

54+
def test_ask_without_pending_restores_learner2d_state():
55+
learner = Learner2D(lambda xy: xy[0] + xy[1], bounds=((-1, 1), (-1, 1)))
56+
initial_stack = list(learner._stack.items())
57+
initial_data = learner.data.copy()
58+
59+
balancing_learner = BalancingLearner([learner])
60+
balancing_learner.ask(1, tell_pending=False)
61+
62+
assert list(learner._stack.items()) == initial_stack
63+
assert learner.data == initial_data
64+
65+
5466
@pytest.mark.parametrize(
5567
"strategy, goal_type, goal",
5668
[

0 commit comments

Comments
 (0)