1414import numpy as np
1515from numpy .typing import NDArray
1616
17- from tdamapper ._common import ArrayLike , ParamsMixin , warn_user
17+ from tdamapper ._common import ArrayLike , ParamsMixin , PointLike , warn_user
1818from tdamapper .core import Proximity
1919from tdamapper .utils .metrics import Metric , chebyshev , get_metric
2020from tdamapper .utils .vptree import VPTree
@@ -282,7 +282,7 @@ def __init__(
282282 self .leaf_radius = leaf_radius
283283 self .pivoting = pivoting
284284
285- def _get_center (self , x : NDArray [ np . float64 ] ) -> tuple [tuple , NDArray ]:
285+ def _get_center (self , x : PointLike ) -> tuple [tuple , NDArray [ np . float64 ] ]:
286286 offset = self ._offset (x )
287287 center = self ._phi (x )
288288 return tuple (offset ), center
@@ -291,17 +291,17 @@ def _get_overlap_frac(self, dim: int, overlap_vol_frac: float) -> float:
291291 beta = math .pow (1.0 - overlap_vol_frac , 1.0 / dim )
292292 return 1.0 - 1.0 / (2.0 - beta )
293293
294- def _offset (self , x : NDArray [ np . float64 ] ) -> NDArray [np .float64 ]:
294+ def _offset (self , x : PointLike ) -> NDArray [np .float64 ]:
295295 return np .minimum (self ._n_intervals - 1 , np .floor (self ._gamma_n (x )))
296296
297- def _phi (self , x : NDArray [ np . float64 ] ) -> NDArray [np .float64 ]:
297+ def _phi (self , x : PointLike ) -> NDArray [np .float64 ]:
298298 offset = self ._offset (x )
299299 return self ._gamma_n_inv (0.5 + offset )
300300
301- def _gamma_n (self , x : NDArray [ np . float64 ] ) -> NDArray [np .float64 ]:
301+ def _gamma_n (self , x : PointLike ) -> NDArray [np .float64 ]:
302302 return self ._n_intervals * (x - self ._min ) / self ._delta
303303
304- def _gamma_n_inv (self , x : NDArray [ np . float64 ] ) -> NDArray [np .float64 ]:
304+ def _gamma_n_inv (self , x : PointLike ) -> NDArray [np .float64 ]:
305305 return self ._min + self ._delta * x / self ._n_intervals
306306
307307 def _get_bounds (
@@ -481,15 +481,15 @@ def __init__(
481481 pivoting = pivoting ,
482482 )
483483
484- def _landmarks (self , X : NDArray [ np . float64 ] ) -> dict [tuple , NDArray [ np . float64 ] ]:
484+ def _landmarks (self , X : ArrayLike ) -> dict [tuple , PointLike ]:
485485 lmrks = {}
486486 for x in X :
487487 lmrk , _ = self ._get_center (x )
488488 if lmrk not in lmrks :
489489 lmrks [lmrk ] = x
490490 return lmrks
491491
492- def apply (self , X : NDArray [ np . float64 ] ) -> Generator [list [int ]]:
492+ def apply (self , X : ArrayLike ) -> Generator [list [int ]]:
493493 """
494494 Covers the dataset using landmarks.
495495
@@ -595,7 +595,7 @@ def _get_cubical_cover(self) -> Union[ProximityCubicalCover, StandardCubicalCove
595595 "'proximity'."
596596 )
597597
598- def fit (self , X : NDArray [ np . float64 ] ) -> CubicalCover :
598+ def fit (self , X : ArrayLike ) -> CubicalCover :
599599 """
600600 Train internal parameters.
601601
@@ -625,7 +625,7 @@ def search(self, x: Any) -> list[int]:
625625 """
626626 return self ._cubical_cover .search (x )
627627
628- def apply (self , X : NDArray [ np . float64 ] ) -> Generator [list [int ]]:
628+ def apply (self , X : ArrayLike ) -> Generator [list [int ]]:
629629 """
630630 Covers the dataset using hypercubes.
631631
0 commit comments