@@ -233,6 +233,78 @@ def search(self, x):
233233 return [x for (x , _ ) in neighs ]
234234
235235
class _GridOverlap:
    """
    Internal helper that lays an overlapping cubical grid over a dataset.

    The bounding box of the fitted data is rescaled so that each coordinate
    spans ``n_intervals`` units, and a :class:`BallCover` is built on the
    data using the metric ``_Pullback(self._gamma_n, chebyshev())``.
    Both ``CubicalCover`` and ``StandardCover`` delegate to this class.

    :param n_intervals: Number of intervals per coordinate used by the
        rescaling. Defaults to 1.
    :param overlap_frac: Overlap fraction between adjacent cells, expected
        in ``(0.0, 0.5]``. When ``None`` it is derived at fit time from the
        data dimension via :meth:`_get_overlap_frac`. Defaults to None.
    :param kind: Forwarded unchanged to ``BallCover``. Defaults to 'flat'.
    :param leaf_capacity: Forwarded unchanged to ``BallCover``.
    :param leaf_radius: Forwarded unchanged to ``BallCover``.
    :param pivoting: Forwarded unchanged to ``BallCover``.
    """

    def __init__(
        self,
        n_intervals=1,
        overlap_frac=None,
        kind='flat',
        leaf_capacity=1,
        leaf_radius=None,
        pivoting=None,
    ):
        self.n_intervals = n_intervals
        self.overlap_frac = overlap_frac
        self.kind = kind
        self.leaf_capacity = leaf_capacity
        self.leaf_radius = leaf_radius
        self.pivoting = pivoting

    def fit(self, X):
        """
        Compute the grid parameters from ``X`` and fit the underlying cover.

        :param X: Dataset as a numpy array (``X.ndim`` and ``X.shape`` are
            read). A 1-D array is treated as one-dimensional data.
        :return: The object itself.
        :rtype: self
        """
        if self.overlap_frac is None:
            # Default: derive the per-axis overlap from a volume overlap
            # fraction of 0.5 in the data dimension.
            dim = 1 if X.ndim == 1 else X.shape[1]
            self.__overlap_frac = self._get_overlap_frac(dim, 0.5)
        else:
            self.__overlap_frac = self.overlap_frac
        self.__n_intervals = self.n_intervals
        if (self.__overlap_frac <= 0.0) or (self.__overlap_frac > 0.5):
            # Only a warning: fitting proceeds with the out-of-range value.
            warn_user('The parameter overlap_frac is expected within range (0.0, 0.5]')
        # NOTE: _get_bounds returns None for an empty X, in which case this
        # unpacking raises TypeError.
        self.__min, self.__max, self.__delta = self._get_bounds(X)
        radius = 1.0 / (2.0 - 2.0 * self.__overlap_frac)
        self.__cover = BallCover(
            radius,
            metric=_Pullback(self._gamma_n, chebyshev()),
            kind=self.kind,
            leaf_capacity=self.leaf_capacity,
            leaf_radius=self.leaf_radius,
            pivoting=self.pivoting,
        )
        self.__cover.fit(X)
        return self

    def get_center(self, x):
        """
        Return the grid cell index of ``x`` together with its center.

        :param x: A point (array-like, or a scalar for 1-D data).
        :return: A pair ``(cell, center)`` where ``cell`` is a hashable
            tuple and ``center`` is ``_gamma_n_inv(_rho(_gamma_n(x)))``.
        :rtype: tuple
        """
        cell = self._get_cell_index(x)
        # _rho is defined elsewhere in the module; presumably it maps a
        # rescaled point to the center of its unit cell -- TODO confirm.
        center = self._gamma_n_inv(_rho(self._gamma_n(x)))
        return cell, center

    def get_cell(self, x):
        """
        Return the result of searching the fitted cover at the center
        associated with ``x``.

        :param x: A point (array-like, or a scalar for 1-D data).
        :return: Whatever ``BallCover.search`` returns for that center.
        """
        _, center = self.get_center(x)
        return self.__cover.search(center)

    def _get_cell_index(self, x):
        """Return the floor-divided, rescaled coordinates of ``x`` as a tuple."""
        cell = self.__n_intervals * (x - self.__min) // self.__delta
        # atleast_1d keeps scalar points (1-D datasets) from breaking
        # tuple(), which cannot iterate over a 0-d value.
        return tuple(np.atleast_1d(cell))

    def _get_overlap_frac(self, dim, overlap_vol_frac):
        """
        Convert a volume overlap fraction into a per-axis overlap fraction.

        :param dim: Dimension of the data.
        :param overlap_vol_frac: Overlap fraction expressed in volume.
        :return: ``1 - 1 / (2 - (1 - overlap_vol_frac) ** (1 / dim))``.
        :rtype: float
        """
        beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
        return 1.0 - 1.0 / (2.0 - beta)

    def _gamma_n(self, x):
        """Rescale ``x`` so the fitted bounding box spans ``n_intervals`` per axis."""
        return self.__n_intervals * (x - self.__min) / self.__delta

    def _gamma_n_inv(self, x):
        """Inverse of :meth:`_gamma_n`: map rescaled coordinates back."""
        return self.__min + self.__delta * x / self.__n_intervals

    def _get_bounds(self, X):
        """
        Compute the coordinate-wise minimum, maximum and extent of ``X``.

        NaN results are replaced by ``-eps`` (minimum) and ``eps``
        (maximum), and the extent is clamped below by ``eps`` (float64
        machine epsilon) so it can safely be used as a divisor.

        :param X: Dataset, one sample per entry along the first axis.
        :return: ``(min, max, delta)``, or ``None`` when ``X`` is ``None``
            or empty.
        """
        if (X is None) or len(X) == 0:
            return None
        data = np.asarray(X)
        eps = np.finfo(np.float64).eps
        # np.min / np.max propagate NaN exactly like the pairwise
        # np.minimum / np.maximum reduction they replace.
        _min = np.nan_to_num(np.min(data, axis=0), nan=-float(eps))
        _max = np.nan_to_num(np.max(data, axis=0), nan=float(eps))
        _delta = np.maximum(eps, _max - _min)
        return _min, _max, _delta
307+
236308class CubicalCover (Proximity ):
237309 """
238310 Cover algorithm based on the `cubical proximity function`, covering data
@@ -291,35 +363,9 @@ def __init__(
291363 self .leaf_radius = leaf_radius
292364 self .pivoting = pivoting
293365
294- def _gamma_n (self , x ):
295- return self .__n_intervals * (x - self .__min ) / self .__delta
296-
297- def _gamma_n_inv (self , x ):
298- return self .__min + self .__delta * x / self .__n_intervals
299-
300- def _phi (self , x ):
301- return self ._gamma_n_inv (_rho (self ._gamma_n (x )))
302-
303- def _get_bounds (self , data ):
304- if (data is None ) or len (data ) == 0 :
305- return
306- minimum , maximum = data [0 ], data [0 ]
307- eps = np .finfo (np .float64 ).eps
308- for w in data :
309- minimum = np .minimum (minimum , np .array (w ))
310- maximum = np .maximum (maximum , np .array (w ))
311- _min = np .nan_to_num (minimum , nan = - float (eps ))
312- _max = np .nan_to_num (maximum , nan = float (eps ))
313- _delta = np .maximum (eps , _max - _min )
314- return _min , _max , _delta
315-
def _convert(self, X):
    """
    Return ``X`` as a float array with one row per sample.

    Each of the ``len(X)`` entries is flattened into a single row.
    """
    as_array = np.asarray(X)
    n_samples = len(X)
    one_row_per_sample = as_array.reshape(n_samples, -1)
    return one_row_per_sample.astype(float)
318368
319- def _get_overlap_frac (self , dim , overlap_vol_frac ):
320- beta = math .pow (1.0 - overlap_vol_frac , 1.0 / dim )
321- return 1.0 - 1.0 / (2.0 - beta )
322-
323369 def fit (self , X ):
324370 """
325371 Train internal parameters.
@@ -332,27 +378,16 @@ def fit(self, X):
332378 :return: The object itself.
333379 :rtype: self
334380 """
335- X = np .asarray (X )
336- if self .overlap_frac is None :
337- dim = 1 if X .ndim == 1 else X .shape [1 ]
338- self .__overlap_frac = self ._get_overlap_frac (dim , 0.5 )
339- else :
340- self .__overlap_frac = self .overlap_frac
341- if (self .__overlap_frac <= 0.0 ) or (self .__overlap_frac > 0.5 ):
342- warn_user ('The parameter overlap_frac is expected within range (0.0, 0.5]' )
343- self .__n_intervals = self .n_intervals
344- self .__radius = 1.0 / (2.0 - 2.0 * self .__overlap_frac )
345381 XX = self ._convert (X )
346- self .__ball_proximity = BallCover (
347- self .__radius ,
348- metric = _Pullback ( self ._gamma_n , chebyshev ()) ,
382+ self .__grid_overlap = _GridOverlap (
383+ n_intervals = self .n_intervals ,
384+ overlap_frac = self .overlap_frac ,
349385 kind = self .kind ,
350386 leaf_capacity = self .leaf_capacity ,
351387 leaf_radius = self .leaf_radius ,
352388 pivoting = self .pivoting ,
353389 )
354- self .__min , self .__max , self .__delta = self ._get_bounds (XX )
355- self .__ball_proximity .fit (XX )
390+ self .__grid_overlap .fit (XX )
356391 return self
357392
358393 def search (self , x ):
@@ -367,7 +402,7 @@ def search(self, x):
367402 :return: The indices of the neighbors contained in the dataset.
368403 :rtype: list[int]
369404 """
370- return self .__ball_proximity . search ( self . _phi ( x ) )
405+ return self .__grid_overlap . get_cell ( x )
371406
372407
373408class StandardCover (Cover ):
@@ -388,73 +423,25 @@ def __init__(
388423 self .leaf_radius = leaf_radius
389424 self .pivoting = pivoting
390425
391- def _gamma_n (self , x ):
392- return self .n_intervals * (x - self .__min ) / self .__delta
393-
394- def _gamma_n_inv (self , x ):
395- return self .__min + self .__delta * x / self .n_intervals
396-
397- def _phi (self , x ):
398- return self ._gamma_n_inv (_rho (self ._gamma_n (x )))
399-
400- def _get_bounds (self , data ):
401- if (data is None ) or len (data ) == 0 :
402- return
403- minimum , maximum = data [0 ], data [0 ]
404- eps = np .finfo (np .float64 ).eps
405- for w in data :
406- minimum = np .minimum (minimum , np .array (w ))
407- maximum = np .maximum (maximum , np .array (w ))
408- _min = np .nan_to_num (minimum , nan = - float (eps ))
409- _max = np .nan_to_num (maximum , nan = float (eps ))
410- _delta = np .maximum (eps , _max - _min )
411- return _min , _max , _delta
412-
413- def _convert (self , X ):
414- return np .asarray (X ).reshape (len (X ), - 1 ).astype (float )
415-
416- def _get_overlap_frac (self , dim , overlap_vol_frac ):
417- beta = math .pow (1.0 - overlap_vol_frac , 1.0 / dim )
418- return 1.0 - 1.0 / (2.0 - beta )
419-
420- def _cubical_landmarks (self , X ):
421- lmrks = {}
422- for x in X :
423- lmrk = self ._landmark (x )
424- if lmrk not in lmrks :
425- lmrks [lmrk ] = x
426- return lmrks
427-
428- def _landmark (self , x ):
429- cell = self .n_intervals * (x - self .__min ) // self .__delta
430- return tuple (cell )
431-
def apply(self, X):
    """
    Lazily yield the cover sets of ``X`` induced by the overlapping grid.

    A ``_GridOverlap`` is fitted on ``X``; one representative point is
    kept per distinct grid cell (the first one encountered), and for each
    representative the result of ``get_cell`` is yielded when non-empty.

    :param X: Dataset, converted with ``np.asarray``.
    :return: A generator of the non-empty results of
        ``_GridOverlap.get_cell``.
    """
    X = np.asarray(X)
    _grid_overlap = _GridOverlap(
        n_intervals=self.n_intervals,
        overlap_frac=self.overlap_frac,
        kind=self.kind,
        leaf_capacity=self.leaf_capacity,
        leaf_radius=self.leaf_radius,
        pivoting=self.pivoting,
    )
    _grid_overlap.fit(X)

    # Keep the first point seen in each grid cell as its representative.
    lmrks_to_cover = {}
    for x in X:
        lmrk, _ = _grid_overlap.get_center(x)
        lmrks_to_cover.setdefault(lmrk, x)
    # popitem() consumes the representatives in LIFO order.
    while lmrks_to_cover:
        _, x = lmrks_to_cover.popitem()
        neigh_ids = _grid_overlap.get_cell(x)
        if neigh_ids:
            yield neigh_ids
0 commit comments