Skip to content

Commit b9021aa

Browse files
committed
Set default overlap_frac as None
1 parent 46c95b1 commit b9021aa

1 file changed

Lines changed: 16 additions & 4 deletions

File tree

src/tdamapper/cover.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Indeed, the overlaps of the open subsets define the edges of the Mapper graph.
77
"""
88

9+
import math
910
import numpy as np
1011

1112
from tdamapper.core import Proximity
@@ -247,7 +248,9 @@ class CubicalCover(Proximity):
247248
Defaults to 1.
248249
:type n_intervals: int
249250
:param overlap_frac: The fraction of overlap between adjacent intervals on
250-
each dimension, must be in the range (0.0, 1.0). Defaults to 0.5.
251+
each dimension, must be in the range (0.0, 0.5]. If not specified, the
252+
overlap_frac is computed such that the volume of the overlap within
253+
each hypercube is half the total volume. Defaults to None.
251254
:type overlap_frac: float
252255
:param metric: The metric used to define the distance between points.
253256
Accepts any value compatible with `tdamapper.utils.metrics.get_metric`.
@@ -287,8 +290,6 @@ def __init__(
287290
self.leaf_capacity = leaf_capacity
288291
self.leaf_radius = leaf_radius
289292
self.pivoting = pivoting
290-
if (self.overlap_frac <= 0.0) or (self.overlap_frac > 0.5):
291-
warn_user('The parameter overlap_frac is expected within range (0.0, 0.5]')
292293

293294
def _gamma_n(self, x):
294295
return self.__n_intervals * (x - self.__min) / self.__delta
@@ -315,6 +316,10 @@ def _get_bounds(self, data):
315316
def _convert(self, X):
316317
return np.asarray(X).reshape(len(X), -1).astype(float)
317318

319+
def _get_overlap_frac(self, dim, overlap_vol_frac):
320+
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
321+
return 1.0 - 1.0 / (2.0 - beta)
322+
318323
def fit(self, X):
319324
"""
320325
Train internal parameters.
@@ -327,7 +332,14 @@ def fit(self, X):
327332
:return: The object itself.
328333
:rtype: self
329334
"""
330-
self.__overlap_frac = self.overlap_frac
335+
X = np.asarray(X)
336+
if self.overlap_frac is None:
337+
dim = 1 if X.ndim == 1 else X.shape[1]
338+
self.__overlap_frac = self._get_overlap_frac(dim, 0.5)
339+
else:
340+
self.__overlap_frac = self.overlap_frac
341+
if (self.__overlap_frac <= 0.0) or (self.__overlap_frac > 0.5):
342+
warn_user('The parameter overlap_frac is expected within range (0.0, 0.5]')
331343
self.__n_intervals = self.n_intervals
332344
self.__radius = 1.0 / (2.0 - 2.0 * self.__overlap_frac)
333345
XX = self._convert(X)

0 commit comments

Comments
 (0)