Skip to content

Commit a8342d1

Browse files
committed
Improved class hierarchies, reduced boilerplate
1 parent ff23476 commit a8342d1

File tree

1 file changed

+45
-58
lines changed

1 file changed

+45
-58
lines changed

src/tdamapper/cover.py

Lines changed: 45 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,7 @@ def __init__(
252252
self.pivoting = pivoting
253253

254254
def fit(self, X):
255+
X = np.asarray(X).reshape(len(X), -1).astype(float)
255256
if self.overlap_frac is None:
256257
dim = 1 if X.ndim == 1 else X.shape[1]
257258
self.__overlap_frac = self._get_overlap_frac(dim, 0.5)
@@ -271,16 +272,25 @@ def fit(self, X):
271272
pivoting=self.pivoting,
272273
)
273274
self.__cover.fit(X)
275+
return self
276+
277+
def search(self, x):
278+
center = self._gamma_n_inv(_rho(self._gamma_n(x)))
279+
return self.__cover.search(center)
274280

275-
def get_center(self, x):
281+
def landmarks(self, X):
282+
lmrks = {}
283+
for x in X:
284+
lmrk, center = self._get_center(x)
285+
if lmrk not in lmrks:
286+
lmrks[lmrk] = x
287+
return lmrks
288+
289+
def _get_center(self, x):
276290
cell = self.__n_intervals * (x - self.__min) // self.__delta
277291
center = self._gamma_n_inv(_rho(self._gamma_n(x)))
278292
return tuple(cell), center
279293

280-
def get_cell(self, x):
281-
_, center = self.get_center(x)
282-
return self.__cover.search(center)
283-
284294
def _get_overlap_frac(self, dim, overlap_vol_frac):
285295
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
286296
return 1.0 - 1.0 / (2.0 - beta)
@@ -294,18 +304,16 @@ def _gamma_n_inv(self, x):
294304
def _get_bounds(self, X):
295305
if (X is None) or len(X) == 0:
296306
return
297-
minimum, maximum = X[0], X[0]
307+
_min, _max = X[0], X[0]
298308
eps = np.finfo(np.float64).eps
299-
for w in X:
300-
minimum = np.minimum(minimum, np.array(w))
301-
maximum = np.maximum(maximum, np.array(w))
302-
_min = np.nan_to_num(minimum, nan=-float(eps))
303-
_max = np.nan_to_num(maximum, nan=float(eps))
304-
_delta = np.maximum(eps, _max - _min)
309+
_min = np.min(X, axis=0)
310+
_max = np.max(X, axis=0)
311+
_delta = _max - _min
312+
_delta[(_delta >= -eps) & (_delta <= eps)] = self.__n_intervals
305313
return _min, _max, _delta
306314

307315

308-
class CubicalCover(Proximity):
316+
class CubicalCover(_GridOverlap, Proximity):
309317
"""
310318
Cover algorithm based on the `cubical proximity function`, covering data
311319
with open hypercubes of uniform size and overlap.
@@ -356,15 +364,14 @@ def __init__(
356364
leaf_radius=None,
357365
pivoting=None,
358366
):
359-
self.n_intervals = n_intervals
360-
self.overlap_frac = overlap_frac
361-
self.kind = kind
362-
self.leaf_capacity = leaf_capacity
363-
self.leaf_radius = leaf_radius
364-
self.pivoting = pivoting
365-
366-
def _convert(self, X):
367-
return np.asarray(X).reshape(len(X), -1).astype(float)
367+
super().__init__(
368+
n_intervals=n_intervals,
369+
overlap_frac=overlap_frac,
370+
kind=kind,
371+
leaf_capacity=leaf_capacity,
372+
leaf_radius=leaf_radius,
373+
pivoting=pivoting,
374+
)
368375

369376
def fit(self, X):
370377
"""
@@ -378,17 +385,8 @@ def fit(self, X):
378385
:return: The object itself.
379386
:rtype: self
380387
"""
381-
XX = self._convert(X)
382-
self.__grid_overlap = _GridOverlap(
383-
n_intervals=self.n_intervals,
384-
overlap_frac=self.overlap_frac,
385-
kind=self.kind,
386-
leaf_capacity=self.leaf_capacity,
387-
leaf_radius=self.leaf_radius,
388-
pivoting=self.pivoting,
389-
)
390-
self.__grid_overlap.fit(XX)
391-
return self
388+
#X = np.asarray(X).reshape(len(X), -1).astype(float)
389+
return super().fit(X)
392390

393391
def search(self, x):
394392
"""
@@ -402,10 +400,10 @@ def search(self, x):
402400
:return: The indices of the neighbors contained in the dataset.
403401
:rtype: list[int]
404402
"""
405-
return self.__grid_overlap.get_cell(x)
403+
return super().search(x)
406404

407405

408-
class StandardCover(Cover):
406+
class StandardCover(_GridOverlap, Cover):
409407

410408
def __init__(
411409
self,
@@ -416,32 +414,21 @@ def __init__(
416414
leaf_radius=None,
417415
pivoting=None,
418416
):
419-
self.n_intervals = n_intervals
420-
self.overlap_frac = overlap_frac
421-
self.kind = kind
422-
self.leaf_capacity = leaf_capacity
423-
self.leaf_radius = leaf_radius
424-
self.pivoting = pivoting
425-
426-
def apply(self, X):
427-
X = np.asarray(X)
428-
_grid_overlap = _GridOverlap(
429-
n_intervals=self.n_intervals,
430-
overlap_frac=self.overlap_frac,
431-
kind=self.kind,
432-
leaf_capacity=self.leaf_capacity,
433-
leaf_radius=self.leaf_radius,
434-
pivoting=self.pivoting,
417+
super().__init__(
418+
n_intervals=n_intervals,
419+
overlap_frac=overlap_frac,
420+
kind=kind,
421+
leaf_capacity=leaf_capacity,
422+
leaf_radius=leaf_radius,
423+
pivoting=pivoting,
435424
)
436-
_grid_overlap.fit(X)
437425

438-
lmrks_to_cover = {}
439-
for x in X:
440-
lmrk, center = _grid_overlap.get_center(x)
441-
if lmrk not in lmrks_to_cover:
442-
lmrks_to_cover[lmrk] = x
426+
def apply(self, X):
427+
#X = np.asarray(X).reshape(len(X), -1).astype(float)
428+
super().fit(X)
429+
lmrks_to_cover = super().landmarks(X)
443430
while lmrks_to_cover:
444431
_, x = lmrks_to_cover.popitem()
445-
neigh_ids = _grid_overlap.get_cell(x)
432+
neigh_ids = self.search(x)
446433
if neigh_ids:
447434
yield neigh_ids

0 commit comments

Comments
 (0)