|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import functools as ft |
| 4 | +import math |
| 5 | +import random |
| 6 | + |
| 7 | +import numpy as np |
3 | 8 | import pytest |
4 | 9 |
|
5 | | -from adaptive.learner import BalancingLearner, Learner1D |
| 10 | +from adaptive.learner import BalancingLearner, Learner1D, Learner2D |
6 | 11 | from adaptive.runner import simple |
7 | 12 |
|
8 | 13 | strategies = ["loss", "loss_improvements", "npoints", "cycle"] |
9 | 14 |
|
10 | 15 |
|
| 16 | +def ring_of_fire(xy, d): |
| 17 | + a = 0.2 |
| 18 | + x, y = xy |
| 19 | + return x + math.exp(-((x**2 + y**2 - d**2) ** 2) / a**4) |
| 20 | + |
| 21 | + |
11 | 22 | def test_balancing_learner_loss_cache(): |
12 | 23 | learner = Learner1D(lambda x: x, bounds=(-1, 1)) |
13 | 24 | learner.tell(-1, -1) |
@@ -64,3 +75,34 @@ def test_strategies(strategy, goal_type, goal): |
64 | 75 | learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)] |
65 | 76 | learner = BalancingLearner(learners, strategy=strategy) |
66 | 77 | simple(learner, **{goal_type: goal}) |
| 78 | + |
| 79 | + |
| 80 | +def test_loss_improvements_strategy_with_tell_pending_false_reserves_child_points(): |
| 81 | + random.seed(3104322362) |
| 82 | + np.random.seed(3104322362 % 2**32) |
| 83 | + |
| 84 | + learners = [ |
| 85 | + Learner2D( |
| 86 | + ft.partial(ring_of_fire, d=random.uniform(0.2, 1)), |
| 87 | + bounds=((-1, 1), (-1, 1)), |
| 88 | + ) |
| 89 | + for _ in range(4) |
| 90 | + ] |
| 91 | + learner = BalancingLearner(learners, strategy="loss_improvements") |
| 92 | + |
| 93 | + stash = [] |
| 94 | + for n, m in [(1, 1), (4, 4), (2, 0), (4, 4), (8, 6)]: |
| 95 | + xs, _ = learner.ask(n, tell_pending=False) |
| 96 | + random.shuffle(xs) |
| 97 | + for _ in range(m): |
| 98 | + stash.append(xs.pop()) |
| 99 | + |
| 100 | + for x in xs: |
| 101 | + learner.tell(x, learner.function(x)) |
| 102 | + |
| 103 | + random.shuffle(stash) |
| 104 | + for _ in range(m): |
| 105 | + x = stash.pop() |
| 106 | + learner.tell(x, learner.function(x)) |
| 107 | + |
| 108 | + assert all(not child.pending_points for child in learners) |
0 commit comments