Skip to content

Commit 6f14ae4

Browse files
committed
Improved types
1 parent a0dd7b7 commit 6f14ae4

1 file changed

Lines changed: 45 additions & 31 deletions

File tree

src/tdamapper/search.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
from __future__ import annotations
1010

1111
import math
12-
from typing import Any, Callable, Dict, List, Optional, Union
12+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1313

1414
import numpy as np
15+
from numpy.typing import NDArray
1516

1617
from tdamapper._common import ParamsMixin, warn_user
1718
from tdamapper.core import ArrayLike, PointLike
@@ -62,9 +63,8 @@ class BallSearch(ParamsMixin):
6263
Acceptable values are None, 'random', or 'furthest'. Defaults to None.
6364
"""
6465

65-
_radius: float
66-
_data: List[tuple[int, Any]]
6766
_vptree: VPTree
67+
_radius: float
6868

6969
def __init__(
7070
self,
@@ -96,15 +96,17 @@ def fit(self, X: ArrayLike) -> BallSearch:
9696
:return: The object itself.
9797
"""
9898
metric = get_metric(self.metric, **(self.metric_params or {}))
99+
metric_pullback = _Pullback(_snd, metric)
100+
data = list(enumerate(X))
101+
leaf_radius = self.leaf_radius or self.radius
99102
self._radius = self.radius
100-
self._data = list(enumerate(X))
101103
self._vptree = VPTree(
102-
self._data,
103-
metric=_Pullback(_snd, metric),
104+
data,
105+
metric=metric_pullback,
104106
metric_params=None,
105107
kind=self.kind,
106108
leaf_capacity=self.leaf_capacity,
107-
leaf_radius=self.leaf_radius or self.radius,
109+
leaf_radius=leaf_radius,
108110
pivoting=self.pivoting,
109111
)
110112
return self
@@ -160,7 +162,6 @@ class KNNSearch(ParamsMixin):
160162
"""
161163

162164
_neighbors: int
163-
_data: List[tuple[int, Any]]
164165
_vptree: VPTree
165166

166167
def __init__(
@@ -193,14 +194,16 @@ def fit(self, X: ArrayLike) -> KNNSearch:
193194
:return: The object itself.
194195
"""
195196
metric = get_metric(self.metric, **(self.metric_params or {}))
197+
metric_pullback = _Pullback(_snd, metric)
198+
data = list(enumerate(X))
199+
leaf_capacity = self.leaf_capacity or self.neighbors
196200
self._neighbors = self.neighbors
197-
self._data = list(enumerate(X))
198201
self._vptree = VPTree(
199-
self._data,
200-
metric=_Pullback(_snd, metric),
202+
data,
203+
metric=metric_pullback,
201204
metric_params=None,
202205
kind=self.kind,
203-
leaf_capacity=self.leaf_capacity or self.neighbors,
206+
leaf_capacity=leaf_capacity,
204207
leaf_radius=self.leaf_radius,
205208
pivoting=self.pivoting,
206209
)
@@ -275,19 +278,22 @@ def __init__(
275278
self.leaf_radius = leaf_radius
276279
self.pivoting = pivoting
277280

278-
def _get_center(self, x):
281+
def _get_center(
282+
self,
283+
x: NDArray[np.float64],
284+
) -> Tuple[Tuple[float], NDArray[np.float64]]:
279285
offset = self._offset(x)
280286
center = self._phi(x)
281287
return tuple(offset), center
282288

283-
def _get_overlap_frac(self, dim, overlap_vol_frac):
289+
def _get_overlap_frac(self, dim: int, overlap_vol_frac: float) -> float:
284290
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
285291
return 1.0 - 1.0 / (2.0 - beta)
286292

287-
def _offset(self, x):
293+
def _offset(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
288294
return np.minimum(self._n_intervals - 1, np.floor(self._gamma_n(x)))
289295

290-
def _phi(self, x):
296+
def _phi(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
291297
offset = self._offset(x)
292298
return self._gamma_n_inv(0.5 + offset)
293299

@@ -297,9 +303,11 @@ def _gamma_n(self, x):
297303
def _gamma_n_inv(self, x):
298304
return self._min + self._delta * x / self._n_intervals
299305

300-
def _get_bounds(self, X):
306+
def _get_bounds(
307+
self, X: NDArray[np.float64]
308+
) -> Optional[Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]]:
301309
if (X is None) or len(X) == 0:
302-
return
310+
return None
303311
_min, _max = X[0], X[0]
304312
eps = np.finfo(np.float64).eps
305313
_min = np.min(X, axis=0)
@@ -308,7 +316,7 @@ def _get_bounds(self, X):
308316
_delta[(_delta >= -eps) & (_delta <= eps)] = self._n_intervals
309317
return _min, _max, _delta
310318

311-
def fit(self, X: ArrayLike) -> CubicalSearch:
319+
def fit(self, X: NDArray[np.float64]) -> CubicalSearch:
312320
"""
313321
Train internal parameters.
314322
@@ -318,18 +326,22 @@ def fit(self, X: ArrayLike) -> CubicalSearch:
318326
:param X: A dataset of n points.
319327
:return: The object itself.
320328
"""
329+
if self.overlap_frac is not None and self.overlap_frac <= 0.0:
330+
raise ValueError("The parameter overlap_frac is expected to be > 0.0")
331+
if self.overlap_frac is not None and self.overlap_frac > 0.5:
332+
warn_user("The parameter overlap_frac is expected to be <= 0.5")
321333
X = np.asarray(X).reshape(len(X), -1).astype(float)
322-
if self.overlap_frac is None:
323-
dim = 1 if X.ndim == 1 else X.shape[1]
324-
self._overlap_frac = self._get_overlap_frac(dim, 0.5)
325-
else:
326-
self._overlap_frac = self.overlap_frac
334+
dim = 1 if X.ndim == 1 else X.shape[1]
327335
self._n_intervals = self.n_intervals
328-
if self._overlap_frac <= 0.0:
329-
raise ValueError("The parameter overlap_frac is expected to be " "> 0.0")
330-
if self._overlap_frac > 0.5:
331-
warn_user("The parameter overlap_frac is expected to be <= 0.5")
332-
self._min, self._max, self._delta = self._get_bounds(X)
336+
self._overlap_frac = (
337+
self.overlap_frac
338+
if self.overlap_frac is not None
339+
else self._get_overlap_frac(dim, 0.5)
340+
)
341+
bounds = self._get_bounds(X)
342+
if bounds is None:
343+
raise ValueError("The dataset is empty or not properly defined.")
344+
self._min, self._max, self._delta = bounds
333345
radius = 1.0 / (2.0 - 2.0 * self._overlap_frac)
334346
self._ball_search = BallSearch(
335347
radius,
@@ -342,7 +354,7 @@ def fit(self, X: ArrayLike) -> CubicalSearch:
342354
self._ball_search.fit(X)
343355
return self
344356

345-
def search(self, x: PointLike) -> List[int]:
357+
def search(self, x: NDArray[np.float64]) -> List[int]:
346358
"""
347359
Return a list of neighbors for the query point.
348360
@@ -402,7 +414,9 @@ def __init__(
402414
pivoting=pivoting,
403415
)
404416

405-
def landmarks(self, X: ArrayLike) -> Dict:
417+
def landmarks(
418+
self, X: NDArray[np.float64]
419+
) -> Dict[Tuple[float], NDArray[np.float64]]:
406420
"""
407421
Identify unique hypercubes based on the centers of the hypercubes that
408422
intersect the dataset.

0 commit comments

Comments
 (0)