1010from ..notebook_integration import ensure_holoviews
1111from .base_learner import BaseLearner
1212
13+
1314def uniform_loss (interval , scale , function_values ):
1415 """Loss function that samples the domain uniformly.
1516
@@ -93,7 +94,7 @@ def __init__(self, function, bounds, loss_per_interval=None):
9394 self .losses_combined = {}
9495
9596 self .data = sortedcontainers .SortedDict ()
96- self .data_interp = {}
97+ self .pending_points = set ()
9798
9899 # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
99100 # properties.
@@ -118,10 +119,6 @@ def __init__(self, function, bounds, loss_per_interval=None):
118119 def vdim (self ):
119120 return 1 if self ._vdim is None else self ._vdim
120121
121- @property
122- def data_combined (self ):
123- return {** self .data , ** self .data_interp }
124-
125122 @property
126123 def npoints (self ):
127124 return len (self .data )
@@ -133,25 +130,49 @@ def loss(self, real=True):
133130 else :
134131 return max (losses .values ())
135132
136- def update_losses (self , x , data , neighbors , losses ):
137- x_lower , x_upper = neighbors [x ]
138-
139- def _update (interval ):
140- a , b = interval
141- if abs (a - b ) > self ._dx_eps :
142- losses [interval ] = self .loss_per_interval (interval ,
143- self ._scale , data )
133+ def update_interpolated_losses_in_interval (self , x_lower , x_upper ):
134+ if x_lower is not None and x_upper is not None :
135+ dx = x_upper - x_lower
136+ loss = self .loss_per_interval ((x_lower , x_upper ), self ._scale , self .data )
137+ self .losses [x_lower , x_upper ] = loss if abs (dx ) > self ._dx_eps else 0
138+
139+ start = self .neighbors_combined .bisect_right (x_lower )
140+ end = self .neighbors_combined .bisect_left (x_upper )
141+ for i in range (start , end ):
142+ a , b = self .neighbors_combined .iloc [i ], self .neighbors_combined .iloc [i + 1 ]
143+ self .losses_combined [a , b ] = (b - a ) * self .losses [x_lower , x_upper ] / dx
144+ if start == end :
145+ self .losses_combined [x_lower , x_upper ] = self .losses [x_lower , x_upper ]
146+
147+ def update_losses (self , x , real = True ):
148+ if real :
149+ x_lower , x_upper = self .get_neighbors (x , self .neighbors )
150+ self .update_interpolated_losses_in_interval (x_lower , x )
151+ self .update_interpolated_losses_in_interval (x , x_upper )
152+ self .losses .pop ((x_lower , x_upper ), None )
153+ else :
154+ losses_combined = self .losses_combined
155+ x_lower , x_upper = self .get_neighbors (x , self .neighbors )
156+ a , b = self .get_neighbors (x , self .neighbors_combined )
157+ if x_lower is not None and x_upper is not None :
158+ dx = x_upper - x_lower
159+ loss = self .losses [x_lower , x_upper ]
160+ losses_combined [a , x ] = ((x - a ) * loss / dx
161+ if abs (x - a ) > self ._dx_eps else 0 )
162+ losses_combined [x , b ] = ((b - x ) * loss / dx
163+ if abs (b - x ) > self ._dx_eps else 0 )
144164 else :
145- losses [interval ] = 0
165+ if a is not None :
166+ losses_combined [a , x ] = float ('inf' )
167+ if b is not None :
168+ losses_combined [x , b ] = float ('inf' )
146169
147- if x_lower is not None :
148- _update ((x_lower , x ))
149- if x_upper is not None :
150- _update ((x , x_upper ))
151- try :
152- del losses [x_lower , x_upper ]
153- except KeyError :
154- pass
170+ losses_combined .pop ((a , b ), None )
171+
172+ def get_neighbors (self , x , neighbors ):
173+ if x in neighbors :
174+ return neighbors [x ]
175+ return self .find_neighbors (x , neighbors )
155176
156177 def find_neighbors (self , x , neighbors ):
157178 pos = neighbors .bisect_left (x )
@@ -197,38 +218,19 @@ def add_point(self, x, y):
197218 real = y is not None
198219
199220 if real :
200- # Add point to the real data dict and pop from the unfinished
201- # data_interp dict.
221+ # Add point to the real data dict
202222 self .data [x ] = y
203- self .data_interp .pop (x , None )
223+ # remove from set of pending points
224+ self .pending_points .discard (x )
204225
205226 if self ._vdim is None :
206227 try :
207228 self ._vdim = len (np .squeeze (y ))
208229 except TypeError :
209230 self ._vdim = 1
210-
211- # Invalidate interpolated neighbors of new point
212- i = self .data .bisect_left (x )
213- if i == 0 :
214- x_left = self .data .iloc [0 ]
215- for _x in self .data_interp :
216- if _x < x_left :
217- self .data_interp [_x ] = None
218- elif i == len (self .data ):
219- x_right = self .data .iloc [- 1 ]
220- for _x in self .data_interp :
221- if _x > x_right :
222- self .data_interp [_x ] = None
223- else :
224- x_left , x_right = self .data .iloc [i - 1 ], self .data .iloc [i ]
225- for _x in self .data_interp :
226- if x_left < _x < x_right :
227- self .data_interp [_x ] = None
228-
229231 else :
230- # The keys of data_interp are the unknown points
231- self .data_interp [ x ] = None
232+ # The keys of pending_points are the unknown points
233+ self .pending_points . add ( x )
232234
233235 # Update the neighbors
234236 self .update_neighbors (x , self .neighbors_combined )
@@ -238,37 +240,17 @@ def add_point(self, x, y):
238240 # Update the scale
239241 self .update_scale (x , y )
240242
241- # Interpolate
242- for _x , _y in self .data_interp .items ():
243- if _y is None :
244- if len (self .data ) >= 2 :
245- i = self .data .bisect_left (_x )
246- if i == 0 :
247- i_left , i_right = (0 , 1 )
248- elif i == len (self .data ):
249- i_left , i_right = (- 2 , - 1 )
250- else :
251- i_left , i_right = (i - 1 , i )
252- x_left , x_right = self .data .iloc [i_left ], self .data .iloc [i_right ]
253- y_left , y_right = self .data [x_left ], self .data [x_right ]
254- dx = x_right - x_left
255- dy = y_right - y_left
256- self .data_interp [_x ] = (dy / dx ) * (_x - x_left ) + y_left
257-
258243 # Update the losses
259- self .update_losses (x , self .data_combined , self .neighbors_combined ,
260- self .losses_combined )
261- if real :
262- self .update_losses (x , self .data , self .neighbors , self .losses )
244+ self .update_losses (x , real )
245+
246+ # If the scale has increased enough, recompute all losses.
247+ if self ._scale [1 ] > self ._oldscale [1 ] * 2 :
248+
249+ for interval in self .losses :
250+ self .update_interpolated_losses_in_interval (* interval )
251+
252+ self ._oldscale = deepcopy (self ._scale )
263253
264- # If the scale has doubled, recompute all losses.
265- if self ._scale > self ._oldscale * 2 :
266- self .losses = {xs : self .loss_per_interval (xs , self ._scale , self .data )
267- for xs in self .losses }
268- self .losses_combined = {x : self .loss_per_interval (x , self ._scale ,
269- self .data_combined )
270- for x in self .losses_combined }
271- self ._oldscale = self ._scale
272254
273255 def choose_points (self , n , add_data = True ):
274256 """Return n points that are expected to maximally reduce the loss."""
@@ -285,7 +267,7 @@ def choose_points(self, n, add_data=True):
285267 # If the bounds have not been chosen yet, we choose them first.
286268 points = []
287269 for bound in self .bounds :
288- if bound not in self .data and bound not in self .data_interp :
270+ if bound not in self .data and bound not in self .pending_points :
289271 points .append (bound )
290272
291273 if len (points ) == 2 :
@@ -305,8 +287,9 @@ def xs(x, n):
305287 return [x [0 ] + step * i for i in range (1 , n )]
306288
307289 # Calculate how many points belong to each interval.
308- quals = [(- loss , x_range , 1 ) for (x_range , loss ) in
309- self .losses_combined .items ()]
290+ x_scale = self ._scale [0 ]
291+ quals = [(- loss if not math .isinf (loss ) else (x0 - x1 ) / x_scale , (x0 , x1 ), 1 )
292+ for ((x0 , x1 ), loss ) in self .losses_combined .items ()]
310293
311294 heapq .heapify (quals )
312295
@@ -345,6 +328,6 @@ def plot(self):
345328
346329
347330 def remove_unfinished (self ):
348- self .data_interp = {}
331+ self .pending_points = set ()
349332 self .losses_combined = deepcopy (self .losses )
350333 self .neighbors_combined = deepcopy (self .neighbors )
0 commit comments