Skip to content

Commit ff23476

Browse files
committed
Refactored common code
1 parent 4c9adda commit ff23476

File tree

1 file changed

+88
-101
lines changed

1 file changed

+88
-101
lines changed

src/tdamapper/cover.py

Lines changed: 88 additions & 101 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,78 @@ def search(self, x):
233233
return [x for (x, _) in neighs]
234234

235235

236+
class _GridOverlap:
    """
    Helper shared by the cubical covers: rescales the data onto a grid with
    ``n_intervals`` cells per dimension and answers neighbourhood queries
    through a :class:`BallCover` built on the rescaled Chebyshev metric.

    :param n_intervals: Number of grid intervals along each dimension.
    :param overlap_frac: Fraction of overlap between adjacent cells. When
        ``None`` it is derived in :meth:`fit` from the data dimension via
        :meth:`_get_overlap_frac` with a volume fraction of ``0.5``.
    :param kind: Forwarded unchanged to :class:`BallCover`.
    :param leaf_capacity: Forwarded unchanged to :class:`BallCover`.
    :param leaf_radius: Forwarded unchanged to :class:`BallCover`.
    :param pivoting: Forwarded unchanged to :class:`BallCover`.
    """

    def __init__(
        self,
        n_intervals=1,
        overlap_frac=None,
        kind='flat',
        leaf_capacity=1,
        leaf_radius=None,
        pivoting=None,
    ):
        self.n_intervals = n_intervals
        self.overlap_frac = overlap_frac
        self.kind = kind
        self.leaf_capacity = leaf_capacity
        self.leaf_radius = leaf_radius
        self.pivoting = pivoting

    def fit(self, X):
        """
        Compute the grid parameters from ``X`` and fit the inner ball cover.

        :param X: Dataset exposing ``ndim``/``shape`` (callers in this file
            pass a NumPy array).
        :raises ValueError: If ``X`` is ``None`` or empty.
        """
        if self.overlap_frac is None:
            dim = 1 if X.ndim == 1 else X.shape[1]
            self.__overlap_frac = self._get_overlap_frac(dim, 0.5)
        else:
            self.__overlap_frac = self.overlap_frac
        self.__n_intervals = self.n_intervals
        if (self.__overlap_frac <= 0.0) or (self.__overlap_frac > 0.5):
            # Out-of-range values are only reported, not rejected.
            warn_user('The parameter overlap_frac is expected within range (0.0, 0.5]')
        self.__min, self.__max, self.__delta = self._get_bounds(X)
        # Radius is expressed in grid units because the metric is pulled
        # back through _gamma_n.
        radius = 1.0 / (2.0 - 2.0 * self.__overlap_frac)
        self.__cover = BallCover(
            radius,
            metric=_Pullback(self._gamma_n, chebyshev()),
            kind=self.kind,
            leaf_capacity=self.leaf_capacity,
            leaf_radius=self.leaf_radius,
            pivoting=self.pivoting,
        )
        self.__cover.fit(X)

    def get_center(self, x):
        """
        Return the grid cell key of ``x`` and the associated center point.

        :param x: A single sample; must be called after :meth:`fit`.
        :return: ``(cell, center)`` where ``cell`` is a hashable tuple of the
            floored grid coordinates and ``center`` is ``x`` mapped to grid
            units, passed through ``_rho`` and mapped back.
            NOTE(review): ``_rho`` is defined elsewhere in the module;
            presumably it snaps to the cell center -- confirm.
        """
        # atleast_1d keeps the key a tuple even for scalar samples, which
        # occur when the dataset is 1-D (a case fit() explicitly supports).
        cell = np.atleast_1d(
            self.__n_intervals * (x - self.__min) // self.__delta
        )
        center = self._gamma_n_inv(_rho(self._gamma_n(x)))
        return tuple(cell), center

    def get_cell(self, x):
        """
        Return the result of searching the inner ball cover around the
        center associated with ``x``.

        :param x: A single sample; must be called after :meth:`fit`.
        :return: Whatever ``BallCover.search`` returns for that center.
        """
        _, center = self.get_center(x)
        return self.__cover.search(center)

    def _get_overlap_frac(self, dim, overlap_vol_frac):
        """
        Return ``1 - 1 / (2 - beta)`` with
        ``beta = (1 - overlap_vol_frac) ** (1 / dim)``.
        """
        beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
        return 1.0 - 1.0 / (2.0 - beta)

    def _gamma_n(self, x):
        """Map ``x`` from data coordinates to grid coordinates."""
        return self.__n_intervals * (x - self.__min) / self.__delta

    def _gamma_n_inv(self, x):
        """Map ``x`` from grid coordinates back to data coordinates."""
        return self.__min + self.__delta * x / self.__n_intervals

    def _get_bounds(self, X):
        """
        Compute the component-wise minimum, maximum and extent of ``X``.

        :param X: Non-empty sequence of samples.
        :return: ``(min, max, delta)``. NaN bounds are replaced by ``-eps`` /
            ``+eps`` and ``delta`` is floored at ``eps`` so that the division
            in :meth:`_gamma_n` never hits zero.
        :raises ValueError: If ``X`` is ``None`` or empty.
        """
        if X is None or len(X) == 0:
            # Previously returned None, which surfaced in fit() as an opaque
            # "cannot unpack non-iterable NoneType" TypeError.
            raise ValueError('Cannot compute bounds of an empty dataset')
        minimum, maximum = X[0], X[0]
        eps = np.finfo(np.float64).eps
        for w in X:
            w_arr = np.array(w)
            minimum = np.minimum(minimum, w_arr)
            maximum = np.maximum(maximum, w_arr)
        _min = np.nan_to_num(minimum, nan=-float(eps))
        _max = np.nan_to_num(maximum, nan=float(eps))
        _delta = np.maximum(eps, _max - _min)
        return _min, _max, _delta
306+
307+
236308
class CubicalCover(Proximity):
237309
"""
238310
Cover algorithm based on the `cubical proximity function`, covering data
@@ -291,35 +363,9 @@ def __init__(
291363
self.leaf_radius = leaf_radius
292364
self.pivoting = pivoting
293365

294-
def _gamma_n(self, x):
295-
return self.__n_intervals * (x - self.__min) / self.__delta
296-
297-
def _gamma_n_inv(self, x):
298-
return self.__min + self.__delta * x / self.__n_intervals
299-
300-
def _phi(self, x):
301-
return self._gamma_n_inv(_rho(self._gamma_n(x)))
302-
303-
def _get_bounds(self, data):
304-
if (data is None) or len(data) == 0:
305-
return
306-
minimum, maximum = data[0], data[0]
307-
eps = np.finfo(np.float64).eps
308-
for w in data:
309-
minimum = np.minimum(minimum, np.array(w))
310-
maximum = np.maximum(maximum, np.array(w))
311-
_min = np.nan_to_num(minimum, nan=-float(eps))
312-
_max = np.nan_to_num(maximum, nan=float(eps))
313-
_delta = np.maximum(eps, _max - _min)
314-
return _min, _max, _delta
315-
316366
def _convert(self, X):
317367
return np.asarray(X).reshape(len(X), -1).astype(float)
318368

319-
def _get_overlap_frac(self, dim, overlap_vol_frac):
320-
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
321-
return 1.0 - 1.0 / (2.0 - beta)
322-
323369
def fit(self, X):
324370
"""
325371
Train internal parameters.
@@ -332,27 +378,16 @@ def fit(self, X):
332378
:return: The object itself.
333379
:rtype: self
334380
"""
335-
X = np.asarray(X)
336-
if self.overlap_frac is None:
337-
dim = 1 if X.ndim == 1 else X.shape[1]
338-
self.__overlap_frac = self._get_overlap_frac(dim, 0.5)
339-
else:
340-
self.__overlap_frac = self.overlap_frac
341-
if (self.__overlap_frac <= 0.0) or (self.__overlap_frac > 0.5):
342-
warn_user('The parameter overlap_frac is expected within range (0.0, 0.5]')
343-
self.__n_intervals = self.n_intervals
344-
self.__radius = 1.0 / (2.0 - 2.0 * self.__overlap_frac)
345381
XX = self._convert(X)
346-
self.__ball_proximity = BallCover(
347-
self.__radius,
348-
metric=_Pullback(self._gamma_n, chebyshev()),
382+
self.__grid_overlap = _GridOverlap(
383+
n_intervals=self.n_intervals,
384+
overlap_frac=self.overlap_frac,
349385
kind=self.kind,
350386
leaf_capacity=self.leaf_capacity,
351387
leaf_radius=self.leaf_radius,
352388
pivoting=self.pivoting,
353389
)
354-
self.__min, self.__max, self.__delta = self._get_bounds(XX)
355-
self.__ball_proximity.fit(XX)
390+
self.__grid_overlap.fit(XX)
356391
return self
357392

358393
def search(self, x):
@@ -367,7 +402,7 @@ def search(self, x):
367402
:return: The indices of the neighbors contained in the dataset.
368403
:rtype: list[int]
369404
"""
370-
return self.__ball_proximity.search(self._phi(x))
405+
return self.__grid_overlap.get_cell(x)
371406

372407

373408
class StandardCover(Cover):
@@ -388,73 +423,25 @@ def __init__(
388423
self.leaf_radius = leaf_radius
389424
self.pivoting = pivoting
390425

391-
def _gamma_n(self, x):
392-
return self.n_intervals * (x - self.__min) / self.__delta
393-
394-
def _gamma_n_inv(self, x):
395-
return self.__min + self.__delta * x / self.n_intervals
396-
397-
def _phi(self, x):
398-
return self._gamma_n_inv(_rho(self._gamma_n(x)))
399-
400-
def _get_bounds(self, data):
401-
if (data is None) or len(data) == 0:
402-
return
403-
minimum, maximum = data[0], data[0]
404-
eps = np.finfo(np.float64).eps
405-
for w in data:
406-
minimum = np.minimum(minimum, np.array(w))
407-
maximum = np.maximum(maximum, np.array(w))
408-
_min = np.nan_to_num(minimum, nan=-float(eps))
409-
_max = np.nan_to_num(maximum, nan=float(eps))
410-
_delta = np.maximum(eps, _max - _min)
411-
return _min, _max, _delta
412-
413-
def _convert(self, X):
414-
return np.asarray(X).reshape(len(X), -1).astype(float)
415-
416-
def _get_overlap_frac(self, dim, overlap_vol_frac):
417-
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
418-
return 1.0 - 1.0 / (2.0 - beta)
419-
420-
def _cubical_landmarks(self, X):
421-
lmrks = {}
422-
for x in X:
423-
lmrk = self._landmark(x)
424-
if lmrk not in lmrks:
425-
lmrks[lmrk] = x
426-
return lmrks
427-
428-
def _landmark(self, x):
429-
cell = self.n_intervals * (x - self.__min) // self.__delta
430-
return tuple(cell)
431-
432426
def apply(self, X):
433427
X = np.asarray(X)
434-
if self.overlap_frac is None:
435-
dim = 1 if X.ndim == 1 else X.shape[1]
436-
_overlap_frac = self._get_overlap_frac(dim, 0.5)
437-
else:
438-
_overlap_frac = self.overlap_frac
439-
if (_overlap_frac <= 0.0) or (_overlap_frac > 0.5):
440-
warn_user('The parameter overlap_frac is expected within range (0.0, 0.5]')
441-
self.__min, self.__max, self.__delta = self._get_bounds(X)
442-
443-
_radius = 1.0 / (2.0 - 2.0 * _overlap_frac)
444-
445-
_ball_cover = BallCover(
446-
_radius,
447-
metric=_Pullback(self._gamma_n, chebyshev()),
428+
_grid_overlap = _GridOverlap(
429+
n_intervals=self.n_intervals,
430+
overlap_frac=self.overlap_frac,
448431
kind=self.kind,
449432
leaf_capacity=self.leaf_capacity,
450433
leaf_radius=self.leaf_radius,
451434
pivoting=self.pivoting,
452435
)
453-
_ball_cover.fit(X)
436+
_grid_overlap.fit(X)
454437

455-
lmrks_to_cover = self._cubical_landmarks(X)
438+
lmrks_to_cover = {}
439+
for x in X:
440+
lmrk, center = _grid_overlap.get_center(x)
441+
if lmrk not in lmrks_to_cover:
442+
lmrks_to_cover[lmrk] = x
456443
while lmrks_to_cover:
457-
_, lmrk = lmrks_to_cover.popitem()
458-
neigh_ids = _ball_cover.search(self._phi(lmrk))
444+
_, x = lmrks_to_cover.popitem()
445+
neigh_ids = _grid_overlap.get_cell(x)
459446
if neigh_ids:
460447
yield neigh_ids

0 commit comments

Comments
 (0)