Skip to content

Commit 6037c5f

Browse files
committed
Improved types
1 parent adadc80 commit 6037c5f

5 files changed

Lines changed: 115 additions & 51 deletions

File tree

src/tdamapper/utils/vptree_flat/ball_search.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,37 @@
1-
from tdamapper.utils.vptree_flat.common import _mid
1+
from typing import Generic, Iterable, TypeVar
22

3+
from tdamapper.utils.metrics import Metric
4+
from tdamapper.utils.vptree_flat.common import VPArray, VPTreeType, _mid
35

4-
class BallSearch:
6+
T = TypeVar("T")
57

6-
def __init__(self, vpt, point, eps, inclusive=True):
8+
9+
class BallSearch(Generic[T]):
10+
11+
_arr: VPArray[T]
12+
_distance: Metric[T]
13+
_point: T
14+
_eps: float
15+
_inclusive: bool
16+
17+
def __init__(
    self, vpt: VPTreeType[T], point: T, eps: float, inclusive: bool = True
) -> None:
    """Capture the tree storage, the metric and the query parameters.

    :param vpt: Tree to search. Only ``_get_arr()`` and ``_get_distance()``
        are called on it here.
    :param point: Query point used as the center of the search.
    :param eps: Threshold that ``_inside`` compares distances against.
    :param inclusive: When True ``_inside`` accepts ``dist <= eps``,
        otherwise only ``dist < eps``.
    """
    self._arr = vpt._get_arr()
    self._distance = vpt._get_distance()
    self._point = point
    self._eps = eps
    self._inclusive = inclusive
1225

13-
def search(self) -> Iterable[T]:
    """Run the ball search.

    :return: The result of ``_search_iter``, passed through unchanged.
    """
    return self._search_iter()
1528

16-
def _inside(self, dist):
29+
def _inside(self, dist: float) -> bool:
1730
if self._inclusive:
1831
return dist <= self._eps
1932
return dist < self._eps
2033

21-
def _search_iter(self):
34+
def _search_iter(self) -> Iterable[T]:
2235
stack = [(0, self._arr.size())]
2336
result = []
2437
while stack:

src/tdamapper/utils/vptree_flat/builder.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
from random import randrange
2+
from typing import Generic, Iterable, TypeVar
23

34
import numpy as np
45

5-
from tdamapper.utils.vptree_flat.common import VPArray
6+
from tdamapper.utils.vptree_flat.common import VPArray, VPTreeType
7+
8+
T = TypeVar("T")
69

710

811
def _mid(start, end):
912
return (start + end) // 2
1013

1114

12-
class Builder:
15+
class Builder(Generic[T]):
1316

14-
def __init__(self, vpt, X):
17+
def __init__(self, vpt: VPTreeType[T], X: Iterable[T]) -> None:
1518
self._distance = vpt._get_distance()
1619

1720
dataset = [x for x in X]
@@ -29,17 +32,17 @@ def __init__(self, vpt, X):
2932
elif pivoting == "furthest":
3033
self._pivoting = self._pivoting_furthest
3134

32-
def _pivoting_disabled(self, start, end):
35+
def _pivoting_disabled(self, start: int, end: int) -> None:
3336
pass
3437

35-
def _pivoting_random(self, start, end):
38+
def _pivoting_random(self, start: int, end: int) -> None:
3639
if end <= start:
3740
return
3841
pivot = randrange(start, end)
3942
if pivot > start:
4043
self._arr.swap(start, pivot)
4144

42-
def _furthest(self, start, end, i):
45+
def _furthest(self, start: int, end: int, i: int) -> int:
4346
furthest_dist = 0.0
4447
furthest = start
4548
i_point = self._arr.get_point(i)
@@ -51,7 +54,7 @@ def _furthest(self, start, end, i):
5154
furthest_dist = j_dist
5255
return furthest
5356

54-
def _pivoting_furthest(self, start, end):
57+
def _pivoting_furthest(self, start: int, end: int) -> None:
5558
if end <= start:
5659
return
5760
rnd = randrange(start, end)
@@ -60,7 +63,7 @@ def _pivoting_furthest(self, start, end):
6063
if furthest > start:
6164
self._arr.swap(start, furthest)
6265

63-
def _update(self, start, end):
66+
def _update(self, start: int, end: int) -> None:
6467
self._pivoting(start, end)
6568
v_point = self._arr.get_point(start)
6669
is_terminal = self._arr.is_terminal(start)
@@ -69,11 +72,11 @@ def _update(self, start, end):
6972
self._arr.set_distance(i, self._distance(v_point, point))
7073
self._arr.set_terminal(i, is_terminal)
7174

72-
def build(self) -> VPArray[T]:
    """Run the construction and hand back the resulting array.

    :return: ``self._arr`` after ``_build_iter`` has run.
    """
    self._build_iter()
    return self._arr
7578

76-
def _build_iter(self):
79+
def _build_iter(self) -> None:
7780
stack = [(0, self._arr.size())]
7881
while stack:
7982
start, end = stack.pop()
src/tdamapper/utils/vptree_flat/common.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,72 @@
1+
from __future__ import annotations
2+
3+
from typing import Generic, Iterable, Optional, Protocol, TypeVar
4+
5+
import numpy as np
6+
from numpy.typing import NDArray
7+
8+
from tdamapper._common import ArrayLike
9+
from tdamapper.utils.metrics import Metric
110
from tdamapper.utils.quickselect import quickselect, swap_all
211

12+
T = TypeVar("T")
13+
314

415
def _mid(start, end):
516
return (start + end) // 2
617

718

8-
class VPArray:
19+
class VPTreeType(Protocol[T]):
    """Structural interface of the tree object handed to the builder and
    to the search classes.

    Any object providing these methods satisfies the protocol; no
    inheritance is required.
    """

    def _get_arr(self) -> VPArray[T]:
        """Return the flat array that stores the tree."""
        ...

    def _get_distance(self) -> Metric[T]:
        """Return the metric used to compare two points."""
        ...

    def get_leaf_capacity(self) -> int:
        """Return the configured leaf capacity."""
        ...

    def get_leaf_radius(self) -> float:
        """Return the configured leaf radius."""
        ...

    def get_pivoting(self) -> Optional[str]:
        """Return the name of the pivoting strategy, or None."""
        ...
30+
31+
32+
class VPArray(Generic[T]):
    """Bundle of a dataset with three parallel arrays addressed by position.

    Position ``i`` is mapped to a dataset entry through ``indices[i]``,
    while ``distances[i]`` and ``is_terminal[i]`` hold the per-position
    distance and terminal flag.
    """

    def __init__(
        self,
        dataset: ArrayLike[T],
        distances: NDArray[np.float64],
        indices: NDArray[np.int64],
        is_terminal: NDArray[np.bool_],
    ) -> None:
        """Store the four containers without copying them.

        :param dataset: The points; read through ``indices``.
        :param distances: Per-position distance values.
        :param indices: Per-position index into ``dataset``.
        :param is_terminal: Per-position terminal flags.
        """
        self._dataset = dataset
        self._distances = distances
        self._indices = indices
        self._is_terminal = is_terminal

    def size(self) -> int:
        """Return the number of entries in the dataset."""
        return len(self._dataset)

    def get_point(self, i: int) -> T:
        """Return the dataset entry referenced by position ``i``."""
        return self._dataset[self._indices[i]]

    def get_points(self, s: int, e: int) -> Iterable[T]:
        """Lazily yield the dataset entries referenced by positions ``s:e``."""
        for x_index in self._indices[s:e]:
            yield self._dataset[x_index]

    def get_distance(self, i: int) -> float:
        """Return the distance stored at position ``i``."""
        return self._distances[i]

    def set_distance(self, i: int, dist: float) -> None:
        """Store ``dist`` as the distance at position ``i``."""
        self._distances[i] = dist

    def set_terminal(self, i: int, terminal: bool) -> None:
        """Store ``terminal`` as the terminal flag at position ``i``."""
        self._is_terminal[i] = terminal

    def is_terminal(self, i: int) -> bool:
        """Return the terminal flag stored at position ``i``."""
        return self._is_terminal[i]

    def swap(self, i: int, j: int) -> None:
        """Delegate to ``swap_all`` on the three parallel arrays.

        NOTE(review): presumably exchanges positions ``i`` and ``j`` in
        the distances, indices and terminal flags alike; confirm against
        ``swap_all``.
        """
        swap_all(self._distances, i, j, self._indices, self._is_terminal)

    def partition(self, s: int, e: int, k: int) -> None:
        """Delegate to ``quickselect`` on the distances, with the indices
        and terminal flags passed as companion arrays.

        NOTE(review): presumably reorders positions ``s:e`` around the
        ``k``-th distance and permutes the companions the same way;
        confirm against ``quickselect``.
        """
        quickselect(self._distances, s, e, k, self._indices, self._is_terminal)

src/tdamapper/utils/vptree_flat/knn_search.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,41 @@
1+
from typing import Generic, TypeVar
2+
13
from tdamapper.utils.heap import MaxHeap
2-
from tdamapper.utils.vptree_flat.common import _mid
4+
from tdamapper.utils.metrics import Metric
5+
from tdamapper.utils.vptree_flat.common import VPArray, VPTreeType, _mid
36

47
_PRE = 0
58
_POST = 1
69

10+
T = TypeVar("T")
11+
12+
13+
class KnnSearch(Generic[T]):
714

8-
class KnnSearch:
15+
_arr: VPArray[T]
16+
_distance: Metric[T]
17+
_point: T
18+
_neighbors: int
19+
_result: MaxHeap[float, T]
920

10-
def __init__(self, vpt: VPTreeType[T], point: T, neighbors: int) -> None:
    """Capture the tree storage, the metric and the query parameters.

    :param vpt: Tree to search. Only ``_get_arr()`` and ``_get_distance()``
        are called on it here.
    :param point: Query point.
    :param neighbors: Maximum number of items ``_get_items`` keeps.
    """
    self._arr = vpt._get_arr()
    self._distance = vpt._get_distance()
    self._point = point
    self._neighbors = neighbors
    # Starts unbounded; `_process` returns early for any distance that is
    # not below this radius.
    self._radius: float = float("inf")
    self._result = MaxHeap()
1728

18-
def _get_items(self) -> list[T]:
    """Trim the heap to at most ``self._neighbors`` entries and list them.

    :return: The second component of every ``(key, item)`` pair left in
        ``self._result``, in the heap's own iteration order.
    """
    # Pop until at most `self._neighbors` entries remain.
    while len(self._result) > self._neighbors:
        self._result.pop()
    return [x for (_, x) in self._result]
2233

23-
def search(self) -> list[T]:
    """Run the nearest-neighbor search.

    :return: The items collected by ``_search_iter``, as reported by
        ``_get_items``.
    """
    # The return value of `_search_iter` is not used here; the result is
    # read back from the heap by `_get_items`.
    self._search_iter()
    return self._get_items()
2637

27-
def _process(self, x):
38+
def _process(self, x: T) -> float:
2839
dist = self._distance(self._point, x)
2940
if dist >= self._radius:
3041
return dist
@@ -35,7 +46,7 @@ def _process(self, x):
3546
self._radius, _ = self._result.top()
3647
return dist
3748

38-
def _search_iter(self):
49+
def _search_iter(self) -> list[T]:
3950
self._result = MaxHeap()
4051
stack = [(0, self._arr.size(), 0.0, _PRE)]
4152
while stack:
Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,58 @@
from typing import Any, Generic, Iterable, Optional, TypeVar, Union

import numpy as np
from numpy.typing import NDArray

from tdamapper.utils.metrics import Metric, get_metric
from tdamapper.utils.vptree_flat.ball_search import BallSearch
from tdamapper.utils.vptree_flat.builder import Builder
from tdamapper.utils.vptree_flat.common import VPArray
from tdamapper.utils.vptree_flat.knn_search import KnnSearch
510

11+
T = TypeVar("T")
12+
613

7-
class VPTree(Generic[T]):
    """Tree over a collection of points, built once at construction time
    and queried with ball and k-nearest-neighbor searches.
    """

    def __init__(
        self,
        X: Iterable[T],
        metric: Union[str, Metric[T]] = "euclidean",
        metric_params: Optional[dict[str, Any]] = None,
        leaf_capacity: int = 1,
        leaf_radius: float = 0.0,
        pivoting: Optional[str] = None,
    ) -> None:
        """Store the settings and build the tree from ``X``.

        :param X: Points to index.
        :param metric: Metric name or callable, resolved through
            ``get_metric``.
        :param metric_params: Extra keyword arguments forwarded to
            ``get_metric``; None means no extra arguments.
        :param leaf_capacity: Leaf capacity made available to the builder.
        :param leaf_radius: Leaf radius made available to the builder.
        :param pivoting: Name of the pivoting strategy, or None.
        """
        # The builder reads the settings back through this instance (for
        # example via `_get_distance()`), so they must be assigned before
        # it is constructed.
        self._metric = metric
        self._metric_params = metric_params
        self._leaf_capacity = leaf_capacity
        self._leaf_radius = leaf_radius
        self._pivoting = pivoting
        self._arr = Builder(self, X).build()

    def get_metric(self) -> Union[str, Metric[T]]:
        """Return the metric exactly as passed to the constructor."""
        return self._metric

    def get_metric_params(self) -> Optional[dict[str, Any]]:
        """Return the metric parameters exactly as passed to the constructor."""
        return self._metric_params

    def get_leaf_capacity(self) -> int:
        """Return the configured leaf capacity."""
        return self._leaf_capacity

    def get_leaf_radius(self) -> float:
        """Return the configured leaf radius."""
        return self._leaf_radius

    def get_pivoting(self) -> Optional[str]:
        """Return the configured pivoting strategy name, or None."""
        return self._pivoting

    def _get_arr(self) -> VPArray[T]:
        """Return the array produced by the builder."""
        return self._arr

    def _get_distance(self) -> Union[Metric[NDArray[np.float64]], Metric[T]]:
        """Resolve the configured metric into a callable.

        ``get_metric`` is invoked on every call; nothing is cached here.
        """
        metric_params = self._metric_params or {}
        return get_metric(self._metric, **metric_params)

    def ball_search(self, point: T, eps: float, inclusive: bool = True) -> Iterable[T]:
        """Run a ball search around ``point``.

        :param point: Center of the search.
        :param eps: Distance threshold.
        :param inclusive: When True a distance equal to ``eps`` is
            accepted; otherwise it must be strictly smaller.
        :return: The result of ``BallSearch.search``.
        """
        return BallSearch(self, point, eps, inclusive).search()

    def knn_search(self, point: T, k: int) -> Iterable[T]:
        """Run a nearest-neighbor search around ``point``.

        :param point: Query point.
        :param k: Maximum number of items to return.
        :return: The result of ``KnnSearch.search``.
        """
        return KnnSearch(self, point, k).search()

0 commit comments

Comments
 (0)