Skip to content

Commit b9f8eee

Browse files
committed
Improved types
1 parent c0c9226 commit b9f8eee

2 files changed

Lines changed: 12 additions & 12 deletions

File tree

src/tdamapper/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def run_mapper(
190190
logger.error(error)
191191
return None
192192

193-
mapper_config = MapperConfig(**kwargs)
193+
mapper_config: MapperConfig = MapperConfig(**kwargs)
194194

195195
lens_type = mapper_config.lens_type
196196
cover_scale_data = mapper_config.cover_scale_data
@@ -218,7 +218,7 @@ def run_mapper(
218218
elif lens_type == LENS_UMAP:
219219
lens = lens_umap(n_components=lens_umap_n_components)
220220

221-
cover: Cover
221+
cover: Optional[Cover]
222222
if cover_type == COVER_CUBICAL:
223223
cover = CubicalCover(
224224
n_intervals=cover_cubical_n_intervals,

src/tdamapper/cover.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import numpy as np
1515
from numpy.typing import NDArray
1616

17-
from tdamapper._common import ArrayLike, ParamsMixin, warn_user
17+
from tdamapper._common import ArrayLike, ParamsMixin, PointLike, warn_user
1818
from tdamapper.core import Proximity
1919
from tdamapper.utils.metrics import Metric, chebyshev, get_metric
2020
from tdamapper.utils.vptree import VPTree
@@ -282,7 +282,7 @@ def __init__(
282282
self.leaf_radius = leaf_radius
283283
self.pivoting = pivoting
284284

285-
def _get_center(self, x: NDArray[np.float64]) -> tuple[tuple, NDArray]:
285+
def _get_center(self, x: PointLike) -> tuple[tuple, NDArray[np.float64]]:
286286
offset = self._offset(x)
287287
center = self._phi(x)
288288
return tuple(offset), center
@@ -291,17 +291,17 @@ def _get_overlap_frac(self, dim: int, overlap_vol_frac: float) -> float:
291291
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
292292
return 1.0 - 1.0 / (2.0 - beta)
293293

294-
def _offset(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
294+
def _offset(self, x: PointLike) -> NDArray[np.float64]:
295295
return np.minimum(self._n_intervals - 1, np.floor(self._gamma_n(x)))
296296

297-
def _phi(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
297+
def _phi(self, x: PointLike) -> NDArray[np.float64]:
298298
offset = self._offset(x)
299299
return self._gamma_n_inv(0.5 + offset)
300300

301-
def _gamma_n(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
301+
def _gamma_n(self, x: PointLike) -> NDArray[np.float64]:
302302
return self._n_intervals * (x - self._min) / self._delta
303303

304-
def _gamma_n_inv(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
304+
def _gamma_n_inv(self, x: PointLike) -> NDArray[np.float64]:
305305
return self._min + self._delta * x / self._n_intervals
306306

307307
def _get_bounds(
@@ -481,15 +481,15 @@ def __init__(
481481
pivoting=pivoting,
482482
)
483483

484-
def _landmarks(self, X: NDArray[np.float64]) -> dict[tuple, NDArray[np.float64]]:
484+
def _landmarks(self, X: ArrayLike) -> dict[tuple, PointLike]:
485485
lmrks = {}
486486
for x in X:
487487
lmrk, _ = self._get_center(x)
488488
if lmrk not in lmrks:
489489
lmrks[lmrk] = x
490490
return lmrks
491491

492-
def apply(self, X: NDArray[np.float64]) -> Generator[list[int]]:
492+
def apply(self, X: ArrayLike) -> Generator[list[int]]:
493493
"""
494494
Covers the dataset using landmarks.
495495
@@ -595,7 +595,7 @@ def _get_cubical_cover(self) -> Union[ProximityCubicalCover, StandardCubicalCove
595595
"'proximity'."
596596
)
597597

598-
def fit(self, X: NDArray[np.float64]) -> CubicalCover:
598+
def fit(self, X: ArrayLike) -> CubicalCover:
599599
"""
600600
Train internal parameters.
601601
@@ -625,7 +625,7 @@ def search(self, x: Any) -> list[int]:
625625
"""
626626
return self._cubical_cover.search(x)
627627

628-
def apply(self, X: NDArray[np.float64]) -> Generator[list[int]]:
628+
def apply(self, X: ArrayLike) -> Generator[list[int]]:
629629
"""
630630
Covers the dataset using hypercubes.
631631

0 commit comments

Comments
 (0)