Skip to content

Commit 205c37b

Browse files
committed
Merge branch '52--interpolate-loss-instead-of-function' into 'master'
Interpolate loss between points instead of interpolating the value See merge request qt/adaptive!61
2 parents 65ca5f7 + d0ee785 commit 205c37b

1 file changed

Lines changed: 62 additions & 79 deletions

File tree

adaptive/learner/learner1D.py

Lines changed: 62 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..notebook_integration import ensure_holoviews
1111
from .base_learner import BaseLearner
1212

13+
1314
def uniform_loss(interval, scale, function_values):
1415
"""Loss function that samples the domain uniformly.
1516
@@ -93,7 +94,7 @@ def __init__(self, function, bounds, loss_per_interval=None):
9394
self.losses_combined = {}
9495

9596
self.data = sortedcontainers.SortedDict()
96-
self.data_interp = {}
97+
self.pending_points = set()
9798

9899
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
99100
# properties.
@@ -118,10 +119,6 @@ def __init__(self, function, bounds, loss_per_interval=None):
118119
def vdim(self):
119120
return 1 if self._vdim is None else self._vdim
120121

121-
@property
122-
def data_combined(self):
123-
return {**self.data, **self.data_interp}
124-
125122
@property
126123
def npoints(self):
127124
return len(self.data)
@@ -133,25 +130,49 @@ def loss(self, real=True):
133130
else:
134131
return max(losses.values())
135132

136-
def update_losses(self, x, data, neighbors, losses):
137-
x_lower, x_upper = neighbors[x]
138-
139-
def _update(interval):
140-
a, b = interval
141-
if abs(a - b) > self._dx_eps:
142-
losses[interval] = self.loss_per_interval(interval,
143-
self._scale, data)
133+
def update_interpolated_losses_in_interval(self, x_lower, x_upper):
134+
if x_lower is not None and x_upper is not None:
135+
dx = x_upper - x_lower
136+
loss = self.loss_per_interval((x_lower, x_upper), self._scale, self.data)
137+
self.losses[x_lower, x_upper] = loss if abs(dx) > self._dx_eps else 0
138+
139+
start = self.neighbors_combined.bisect_right(x_lower)
140+
end = self.neighbors_combined.bisect_left(x_upper)
141+
for i in range(start, end):
142+
a, b = self.neighbors_combined.iloc[i], self.neighbors_combined.iloc[i + 1]
143+
self.losses_combined[a, b] = (b - a) * self.losses[x_lower, x_upper] / dx
144+
if start == end:
145+
self.losses_combined[x_lower, x_upper] = self.losses[x_lower, x_upper]
146+
147+
def update_losses(self, x, real=True):
148+
if real:
149+
x_lower, x_upper = self.get_neighbors(x, self.neighbors)
150+
self.update_interpolated_losses_in_interval(x_lower, x)
151+
self.update_interpolated_losses_in_interval(x, x_upper)
152+
self.losses.pop((x_lower, x_upper), None)
153+
else:
154+
losses_combined = self.losses_combined
155+
x_lower, x_upper = self.get_neighbors(x, self.neighbors)
156+
a, b = self.get_neighbors(x, self.neighbors_combined)
157+
if x_lower is not None and x_upper is not None:
158+
dx = x_upper - x_lower
159+
loss = self.losses[x_lower, x_upper]
160+
losses_combined[a, x] = ((x - a) * loss / dx
161+
if abs(x - a) > self._dx_eps else 0)
162+
losses_combined[x, b] = ((b - x) * loss / dx
163+
if abs(b - x) > self._dx_eps else 0)
144164
else:
145-
losses[interval] = 0
165+
if a is not None:
166+
losses_combined[a, x] = float('inf')
167+
if b is not None:
168+
losses_combined[x, b] = float('inf')
146169

147-
if x_lower is not None:
148-
_update((x_lower, x))
149-
if x_upper is not None:
150-
_update((x, x_upper))
151-
try:
152-
del losses[x_lower, x_upper]
153-
except KeyError:
154-
pass
170+
losses_combined.pop((a, b), None)
171+
172+
def get_neighbors(self, x, neighbors):
173+
if x in neighbors:
174+
return neighbors[x]
175+
return self.find_neighbors(x, neighbors)
155176

156177
def find_neighbors(self, x, neighbors):
157178
pos = neighbors.bisect_left(x)
@@ -197,38 +218,19 @@ def add_point(self, x, y):
197218
real = y is not None
198219

199220
if real:
200-
# Add point to the real data dict and pop from the unfinished
201-
# data_interp dict.
221+
# Add point to the real data dict
202222
self.data[x] = y
203-
self.data_interp.pop(x, None)
223+
# remove from set of pending points
224+
self.pending_points.discard(x)
204225

205226
if self._vdim is None:
206227
try:
207228
self._vdim = len(np.squeeze(y))
208229
except TypeError:
209230
self._vdim = 1
210-
211-
# Invalidate interpolated neighbors of new point
212-
i = self.data.bisect_left(x)
213-
if i == 0:
214-
x_left = self.data.iloc[0]
215-
for _x in self.data_interp:
216-
if _x < x_left:
217-
self.data_interp[_x] = None
218-
elif i == len(self.data):
219-
x_right = self.data.iloc[-1]
220-
for _x in self.data_interp:
221-
if _x > x_right:
222-
self.data_interp[_x] = None
223-
else:
224-
x_left, x_right = self.data.iloc[i-1], self.data.iloc[i]
225-
for _x in self.data_interp:
226-
if x_left < _x < x_right:
227-
self.data_interp[_x] = None
228-
229231
else:
230-
# The keys of data_interp are the unknown points
231-
self.data_interp[x] = None
232+
# The keys of pending_points are the unknown points
233+
self.pending_points.add(x)
232234

233235
# Update the neighbors
234236
self.update_neighbors(x, self.neighbors_combined)
@@ -238,37 +240,17 @@ def add_point(self, x, y):
238240
# Update the scale
239241
self.update_scale(x, y)
240242

241-
# Interpolate
242-
for _x, _y in self.data_interp.items():
243-
if _y is None:
244-
if len(self.data) >=2:
245-
i = self.data.bisect_left(_x)
246-
if i == 0:
247-
i_left, i_right = (0, 1)
248-
elif i == len(self.data):
249-
i_left, i_right = (-2, -1)
250-
else:
251-
i_left, i_right = (i - 1, i)
252-
x_left, x_right = self.data.iloc[i_left], self.data.iloc[i_right]
253-
y_left, y_right = self.data[x_left], self.data[x_right]
254-
dx = x_right - x_left
255-
dy = y_right - y_left
256-
self.data_interp[_x] = (dy / dx) * (_x - x_left) + y_left
257-
258243
# Update the losses
259-
self.update_losses(x, self.data_combined, self.neighbors_combined,
260-
self.losses_combined)
261-
if real:
262-
self.update_losses(x, self.data, self.neighbors, self.losses)
244+
self.update_losses(x, real)
245+
246+
# If the scale has increased enough, recompute all losses.
247+
if self._scale[1] > self._oldscale[1] * 2:
248+
249+
for interval in self.losses:
250+
self.update_interpolated_losses_in_interval(*interval)
251+
252+
self._oldscale = deepcopy(self._scale)
263253

264-
# If the scale has doubled, recompute all losses.
265-
if self._scale > self._oldscale * 2:
266-
self.losses = {xs: self.loss_per_interval(xs, self._scale, self.data)
267-
for xs in self.losses}
268-
self.losses_combined = {x: self.loss_per_interval(x, self._scale,
269-
self.data_combined)
270-
for x in self.losses_combined}
271-
self._oldscale = self._scale
272254

273255
def choose_points(self, n, add_data=True):
274256
"""Return n points that are expected to maximally reduce the loss."""
@@ -285,7 +267,7 @@ def choose_points(self, n, add_data=True):
285267
# If the bounds have not been chosen yet, we choose them first.
286268
points = []
287269
for bound in self.bounds:
288-
if bound not in self.data and bound not in self.data_interp:
270+
if bound not in self.data and bound not in self.pending_points:
289271
points.append(bound)
290272

291273
if len(points) == 2:
@@ -305,8 +287,9 @@ def xs(x, n):
305287
return [x[0] + step * i for i in range(1, n)]
306288

307289
# Calculate how many points belong to each interval.
308-
quals = [(-loss, x_range, 1) for (x_range, loss) in
309-
self.losses_combined.items()]
290+
x_scale = self._scale[0]
291+
quals = [(-loss if not math.isinf(loss) else (x0 - x1) / x_scale, (x0, x1), 1)
292+
for ((x0, x1), loss) in self.losses_combined.items()]
310293

311294
heapq.heapify(quals)
312295

@@ -345,6 +328,6 @@ def plot(self):
345328

346329

347330
def remove_unfinished(self):
348-
self.data_interp = {}
331+
self.pending_points = set()
349332
self.losses_combined = deepcopy(self.losses)
350333
self.neighbors_combined = deepcopy(self.neighbors)

0 commit comments

Comments
 (0)