2121from tdamapper .utils .vptree import PivotingStrategy , VPTree , VPTreeKind
2222
2323T = TypeVar ("T" )
24- T_contra = TypeVar ("T_contra" , contravariant = True )
2524S = TypeVar ("S" )
25+
26+ T_contra = TypeVar ("T_contra" , contravariant = True )
2627S_contra = TypeVar ("S_contra" , contravariant = True )
2728
29+ CubicalAlgorithm = Literal ["standard" , "proximity" ]
2830
29- class _Pullback (Generic [S_contra , T_contra ]):
31+
32+ class Pullback (Generic [S_contra , T_contra ]):
33+ """
34+ Pullback pseudo-metric function.
35+
36+ This class is used to adapt a metric function that operates on a
37+ transformed space to work with the original space. It applies a function
38+ to the input data before computing the distance, effectively pulling back
39+ the metric to the original space.
40+
41+ Given a function :math:`f: X \\to Y` and a metric
42+ :math:`d: Y \\times Y \\to \\mathbb{R}`,
43+ this class defines a new pseudo-metric
44+ :math:`d': X \\times X \\to \\mathbb{R}` such that:
45+ :math:`d'(x_1, x_2) = d(f(x_1), f(x_2))`.
46+
47+ When :math:`f` is injective, this pseudo-metric :math:`d'` is a true
48+ metric. If :math:`f` is not injective, it is only a pseudo-metric: it may
49+ violate the identity of indiscernibles, i.e. two distinct objects
50+ :math:`x_1`, :math:`x_2` may have :math:`d'(x_1, x_2) = 0`.
51+
52+ :param fun: A function that transforms the input data.
53+ :param dist: A metric function that operates on the transformed data.
54+ """
3055
3156 def __init__ (
3257 self , fun : Callable [[S_contra ], T_contra ], dist : Metric [T_contra ]
@@ -35,10 +60,24 @@ def __init__(
3560 self .dist = dist
3661
3762 def __call__ (self , x : S_contra , y : S_contra ) -> float :
63+ """
64+ Compute the distance between two points in the original space
65+ using the pullback metric.
66+
67+ This method applies the transformation function to both points and
68+ then computes the distance using the provided metric function.
69+
70+ :param x: A point in the original space.
71+ :param y: Another point in the original space.
72+ :return: The distance between the transformed points in the metric space.
73+ """
3874 return self .dist (self .fun (x ), self .fun (y ))
3975
4076
4177def _snd (x : tuple [T , ...]) -> T :
78+ """
79+ Extract the second element from a tuple.
80+ """
4281 return x [1 ]
4382
4483
@@ -108,7 +147,7 @@ def fit(self, X: ArrayRead[T_contra]) -> BallCover[T_contra]:
108147 self ._data = list (enumerate (X ))
109148 self ._vptree = VPTree (
110149 self ._data ,
111- metric = _Pullback (_snd , metric ),
150+ metric = Pullback (_snd , metric ),
112151 kind = self .kind ,
113152 leaf_capacity = self .leaf_capacity ,
114153 leaf_radius = self .leaf_radius or self .radius ,
@@ -214,7 +253,7 @@ def fit(self, X: ArrayRead[T_contra]) -> KNNCover[T_contra]:
214253 self ._data = list (enumerate (X ))
215254 self ._vptree = VPTree (
216255 self ._data ,
217- metric = _Pullback (_snd , metric ),
256+ metric = Pullback (_snd , metric ),
218257 kind = self .kind ,
219258 leaf_capacity = self .leaf_capacity or self .neighbors ,
220259 leaf_radius = self .leaf_radius ,
@@ -252,6 +291,35 @@ def apply(self, X: ArrayRead[T_contra]) -> Iterator[list[int]]:
252291
253292
254293class BaseCubicalCover :
294+ """
295+ Base class for cubical cover algorithms, which cover data with open
296+ hypercubes of uniform size and overlap. This class provides the basic
297+ functionality for cubical covers, including the initialization of parameters
298+ and the methods for computing the center of a hypercube and its overlap.
299+
300+ A hypercube is a multidimensional generalization of a square or a cube.
301+ The size and overlap of the hypercubes are determined by the number of
302+ intervals and the overlap fraction parameters. This class maps each point
303+ to the hypercube with the nearest center.
304+
305+ :param n_intervals: The number of intervals to use for each dimension.
306+ Must be positive and less than or equal to the length of the dataset.
307+ Defaults to 1.
308+ :param overlap_frac: The fraction of overlap between adjacent intervals on
309+ each dimension, must be in the range (0.0, 0.5]. If not specified, the
310+ overlap_frac is computed such that the volume of the overlap within
311+ each hypercube is half the total volume. Defaults to None.
312+ :param kind: Specifies whether to use a flat or a hierarchical vantage
313+ point tree. Acceptable values are 'flat' or 'hierarchical'. Defaults to
314+ 'flat'.
315+ :param leaf_capacity: The maximum number of points in a leaf node of the
316+ vantage point tree. Must be a positive value. Defaults to 1.
317+ :param leaf_radius: The radius of the leaf nodes. If not specified, it
318+ defaults to the radius of the underlying ball cover, which is derived
319+ from `overlap_frac`. Must be a positive value. Defaults to None.
320+ :param pivoting: The method used for pivoting in the vantage point tree.
321+ Acceptable values are None, 'random', or 'furthest'. Defaults to None.
322+ """
255323
256324 _overlap_frac : float
257325 _n_intervals : int
@@ -338,7 +406,7 @@ def fit(self, X: ArrayRead[NDArray[np.float_]]) -> BaseCubicalCover:
338406 radius = 1.0 / (2.0 - 2.0 * self ._overlap_frac )
339407 self ._cover = BallCover (
340408 radius ,
341- metric = _Pullback (self ._gamma_n , chebyshev ()),
409+ metric = Pullback (self ._gamma_n , chebyshev ()),
342410 kind = self .kind ,
343411 leaf_capacity = self .leaf_capacity ,
344412 leaf_radius = self .leaf_radius ,
@@ -506,9 +574,6 @@ def apply(self, X: ArrayRead[NDArray[np.float_]]) -> Iterator[list[int]]:
506574 yield neigh_ids
507575
508576
509- CubicalAlgorithm = Literal ["standard" , "proximity" ]
510-
511-
512577class CubicalCover (ParamsMixin ):
513578 """
514579 Wrapper class for cubical cover algorithms, which cover data with open
0 commit comments