99from __future__ import annotations
1010
1111import math
12- from typing import Any , Callable , Dict , List , Optional , Union
12+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1313
1414import numpy as np
15+ from numpy .typing import NDArray
1516
1617from tdamapper ._common import ParamsMixin , warn_user
1718from tdamapper .core import ArrayLike , PointLike
@@ -62,9 +63,8 @@ class BallSearch(ParamsMixin):
6263 Acceptable values are None, 'random', or 'furthest'. Defaults to None.
6364 """
6465
65- _radius : float
66- _data : List [tuple [int , Any ]]
6766 _vptree : VPTree
67+ _radius : float
6868
6969 def __init__ (
7070 self ,
@@ -96,15 +96,17 @@ def fit(self, X: ArrayLike) -> BallSearch:
9696 :return: The object itself.
9797 """
9898 metric = get_metric (self .metric , ** (self .metric_params or {}))
99+ metric_pullback = _Pullback (_snd , metric )
100+ data = list (enumerate (X ))
101+ leaf_radius = self .leaf_radius or self .radius
99102 self ._radius = self .radius
100- self ._data = list (enumerate (X ))
101103 self ._vptree = VPTree (
102- self . _data ,
103- metric = _Pullback ( _snd , metric ) ,
104+ data ,
105+ metric = metric_pullback ,
104106 metric_params = None ,
105107 kind = self .kind ,
106108 leaf_capacity = self .leaf_capacity ,
107- leaf_radius = self . leaf_radius or self . radius ,
109+ leaf_radius = leaf_radius ,
108110 pivoting = self .pivoting ,
109111 )
110112 return self
@@ -160,7 +162,6 @@ class KNNSearch(ParamsMixin):
160162 """
161163
162164 _neighbors : int
163- _data : List [tuple [int , Any ]]
164165 _vptree : VPTree
165166
166167 def __init__ (
@@ -193,14 +194,16 @@ def fit(self, X: ArrayLike) -> KNNSearch:
193194 :return: The object itself.
194195 """
195196 metric = get_metric (self .metric , ** (self .metric_params or {}))
197+ metric_pullback = _Pullback (_snd , metric )
198+ data = list (enumerate (X ))
199+ leaf_capacity = self .leaf_capacity or self .neighbors
196200 self ._neighbors = self .neighbors
197- self ._data = list (enumerate (X ))
198201 self ._vptree = VPTree (
199- self . _data ,
200- metric = _Pullback ( _snd , metric ) ,
202+ data ,
203+ metric = metric_pullback ,
201204 metric_params = None ,
202205 kind = self .kind ,
203- leaf_capacity = self . leaf_capacity or self . neighbors ,
206+ leaf_capacity = leaf_capacity ,
204207 leaf_radius = self .leaf_radius ,
205208 pivoting = self .pivoting ,
206209 )
@@ -275,19 +278,22 @@ def __init__(
275278 self .leaf_radius = leaf_radius
276279 self .pivoting = pivoting
277280
278- def _get_center (self , x ):
281+ def _get_center (
282+ self ,
283+ x : NDArray [np .float64 ],
284+ ) -> Tuple [Tuple [float ], NDArray [np .float64 ]]:
279285 offset = self ._offset (x )
280286 center = self ._phi (x )
281287 return tuple (offset ), center
282288
283- def _get_overlap_frac (self , dim , overlap_vol_frac ) :
289+ def _get_overlap_frac (self , dim : int , overlap_vol_frac : float ) -> float :
284290 beta = math .pow (1.0 - overlap_vol_frac , 1.0 / dim )
285291 return 1.0 - 1.0 / (2.0 - beta )
286292
287- def _offset (self , x ) :
293+ def _offset (self , x : NDArray [ np . float64 ]) -> NDArray [ np . float64 ] :
288294 return np .minimum (self ._n_intervals - 1 , np .floor (self ._gamma_n (x )))
289295
290- def _phi (self , x ) :
296+ def _phi (self , x : NDArray [ np . float64 ]) -> NDArray [ np . float64 ] :
291297 offset = self ._offset (x )
292298 return self ._gamma_n_inv (0.5 + offset )
293299
@@ -297,9 +303,11 @@ def _gamma_n(self, x):
297303 def _gamma_n_inv (self , x ):
298304 return self ._min + self ._delta * x / self ._n_intervals
299305
300- def _get_bounds (self , X ):
306+ def _get_bounds (
307+ self , X : NDArray [np .float64 ]
308+ ) -> Optional [Tuple [NDArray [np .float64 ], NDArray [np .float64 ], NDArray [np .float64 ]]]:
301309 if (X is None ) or len (X ) == 0 :
302- return
310+ return None
303311 _min , _max = X [0 ], X [0 ]
304312 eps = np .finfo (np .float64 ).eps
305313 _min = np .min (X , axis = 0 )
@@ -308,7 +316,7 @@ def _get_bounds(self, X):
308316 _delta [(_delta >= - eps ) & (_delta <= eps )] = self ._n_intervals
309317 return _min , _max , _delta
310318
311- def fit (self , X : ArrayLike ) -> CubicalSearch :
319+ def fit (self , X : NDArray [ np . float64 ] ) -> CubicalSearch :
312320 """
313321 Train internal parameters.
314322
@@ -318,18 +326,22 @@ def fit(self, X: ArrayLike) -> CubicalSearch:
318326 :param X: A dataset of n points.
319327 :return: The object itself.
320328 """
329+ if self .overlap_frac is not None and self .overlap_frac <= 0.0 :
330+ raise ValueError ("The parameter overlap_frac is expected to be > 0.0" )
331+ if self .overlap_frac is not None and self .overlap_frac > 0.5 :
332+ warn_user ("The parameter overlap_frac is expected to be <= 0.5" )
321333 X = np .asarray (X ).reshape (len (X ), - 1 ).astype (float )
322- if self .overlap_frac is None :
323- dim = 1 if X .ndim == 1 else X .shape [1 ]
324- self ._overlap_frac = self ._get_overlap_frac (dim , 0.5 )
325- else :
326- self ._overlap_frac = self .overlap_frac
334+ dim = 1 if X .ndim == 1 else X .shape [1 ]
327335 self ._n_intervals = self .n_intervals
328- if self ._overlap_frac <= 0.0 :
329- raise ValueError ("The parameter overlap_frac is expected to be " "> 0.0" )
330- if self ._overlap_frac > 0.5 :
331- warn_user ("The parameter overlap_frac is expected to be <= 0.5" )
332- self ._min , self ._max , self ._delta = self ._get_bounds (X )
336+ self ._overlap_frac = (
337+ self .overlap_frac
338+ if self .overlap_frac is not None
339+ else self ._get_overlap_frac (dim , 0.5 )
340+ )
341+ bounds = self ._get_bounds (X )
342+ if bounds is None :
343+ raise ValueError ("The dataset is empty or not properly defined." )
344+ self ._min , self ._max , self ._delta = bounds
333345 radius = 1.0 / (2.0 - 2.0 * self ._overlap_frac )
334346 self ._ball_search = BallSearch (
335347 radius ,
@@ -342,7 +354,7 @@ def fit(self, X: ArrayLike) -> CubicalSearch:
342354 self ._ball_search .fit (X )
343355 return self
344356
345- def search (self , x : PointLike ) -> List [int ]:
357+ def search (self , x : NDArray [ np . float64 ] ) -> List [int ]:
346358 """
347359 Return a list of neighbors for the query point.
348360
@@ -402,7 +414,9 @@ def __init__(
402414 pivoting = pivoting ,
403415 )
404416
405- def landmarks (self , X : ArrayLike ) -> Dict :
417+ def landmarks (
418+ self , X : NDArray [np .float64 ]
419+ ) -> Dict [Tuple [float ], NDArray [np .float64 ]]:
406420 """
407421 Identify unique hypercubes based on the centers of the hypercubes that
408422 intersect the dataset.
0 commit comments