Skip to content

Commit 2dab60c

Browse files
committed
Added types
1 parent 6037c5f commit 2dab60c

10 files changed

Lines changed: 189 additions & 73 deletions

File tree

src/tdamapper/utils/vptree_flat/ball_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Generic, Iterable, TypeVar
1+
from typing import Generic, TypeVar
22

33
from tdamapper.utils.metrics import Metric
44
from tdamapper.utils.vptree_flat.common import VPArray, VPTreeType, _mid
@@ -23,15 +23,15 @@ def __init__(
2323
self._eps = eps
2424
self._inclusive = inclusive
2525

26-
def search(self) -> Iterable[T]:
26+
def search(self) -> list[T]:
2727
return self._search_iter()
2828

2929
def _inside(self, dist: float) -> bool:
3030
if self._inclusive:
3131
return dist <= self._eps
3232
return dist < self._eps
3333

34-
def _search_iter(self) -> Iterable[T]:
34+
def _search_iter(self) -> list[T]:
3535
stack = [(0, self._arr.size())]
3636
result = []
3737
while stack:

src/tdamapper/utils/vptree_flat/builder.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from random import randrange
2-
from typing import Generic, Iterable, TypeVar
2+
from typing import Callable, Generic, Iterable, TypeVar
33

44
import numpy as np
55

6+
from tdamapper.utils.metrics import Metric
67
from tdamapper.utils.vptree_flat.common import VPArray, VPTreeType
78

89
T = TypeVar("T")
@@ -14,13 +15,19 @@ def _mid(start, end):
1415

1516
class Builder(Generic[T]):
1617

17-
def __init__(self, vpt: VPTreeType[T], X: Iterable[T]) -> None:
18+
_arr: VPArray[T]
19+
_leaf_capacity: int
20+
_leaf_radius: float
21+
_distance: Metric[T]
22+
_pivoting: Callable[[int, int], None]
23+
24+
def __init__(self, vpt: VPTreeType[T], items: Iterable[T]) -> None:
1825
self._distance = vpt._get_distance()
1926

20-
dataset = [x for x in X]
27+
dataset = [x for x in items]
2128
indices = np.array([i for i in range(len(dataset))])
22-
distances = np.array([0.0 for _ in X])
23-
is_terminal = np.array([False for _ in X])
29+
distances = np.array([0.0 for _ in items])
30+
is_terminal = np.array([False for _ in items])
2431
self._arr = VPArray(dataset, distances, indices, is_terminal)
2532

2633
self._leaf_capacity = vpt.get_leaf_capacity()

src/tdamapper/utils/vptree_flat/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
T = TypeVar("T")
1313

1414

15-
def _mid(start, end):
15+
def _mid(start: int, end: int) -> int:
1616
return (start + end) // 2
1717

1818

src/tdamapper/utils/vptree_flat/knn_search.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def _process(self, x: T) -> float:
4343
while len(self._result) > self._neighbors:
4444
self._result.pop()
4545
if len(self._result) == self._neighbors:
46-
self._radius, _ = self._result.top()
46+
radius, _ = self._result.top()
47+
if radius is not None:
48+
self._radius = radius
4749
return dist
4850

4951
def _search_iter(self) -> list[T]:

src/tdamapper/utils/vptree_flat/vptree.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313

1414
class VPTree(Generic[T]):
1515

16+
_metric: Union[str, Metric[T]]
17+
_metric_params: Optional[dict[str, Any]]
18+
_leaf_capacity: int
19+
_leaf_radius: float
20+
_pivoting: Optional[str]
21+
1622
def __init__(
1723
self,
18-
X: Iterable[T],
24+
items: Iterable[T],
1925
metric: Union[str, Metric[T]] = "euclidean",
2026
metric_params: Optional[dict[str, Any]] = None,
2127
leaf_capacity: int = 1,
@@ -27,7 +33,7 @@ def __init__(
2733
self._leaf_capacity = leaf_capacity
2834
self._leaf_radius = leaf_radius
2935
self._pivoting = pivoting
30-
self._arr = Builder(self, X).build()
36+
self._arr = Builder(self, items).build()
3137

3238
def get_metric(self) -> Union[str, Metric[T]]:
3339
return self._metric

src/tdamapper/utils/vptree_hier/ball_search.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
1-
class BallSearch:
1+
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union
22

3-
def __init__(self, vpt, point, eps, inclusive=True):
3+
from tdamapper.utils.metrics import Metric
4+
from tdamapper.utils.vptree_hier.common import Tree, VPArray, VPTreeType
5+
6+
T = TypeVar("T")
7+
8+
9+
class BallSearch(Generic[T]):
10+
11+
_tree: Tree[T]
12+
_arr: VPArray[T]
13+
_distance: Metric[T]
14+
_point: T
15+
_eps: float
16+
_inclusive: bool
17+
_result: list[T]
18+
19+
def __init__(
20+
self, vpt: VPTreeType[T], point: T, eps: float, inclusive: bool = True
21+
) -> None:
422
self._tree = vpt._get_tree()
523
self._arr = vpt._get_arr()
624
self._distance = vpt._get_distance()
@@ -9,17 +27,17 @@ def __init__(self, vpt, point, eps, inclusive=True):
927
self._inclusive = inclusive
1028
self._result = []
1129

12-
def search(self):
30+
def search(self) -> list[T]:
1331
self._result.clear()
1432
self._search_rec(self._tree)
1533
return self._result
1634

17-
def _inside(self, dist):
35+
def _inside(self, dist: float) -> bool:
1836
if self._inclusive:
1937
return dist <= self._eps
2038
return dist < self._eps
2139

22-
def _search_rec(self, tree):
40+
def _search_rec(self, tree: Tree[T]) -> None:
2341
if tree.is_terminal():
2442
start, end = tree.get_bounds()
2543
for x in self._arr.get_points(start, end):

src/tdamapper/utils/vptree_hier/builder.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
11
from random import randrange
2+
from typing import Callable, Generic, Iterable, TypeVar
23

34
import numpy as np
45

5-
from tdamapper.utils.vptree_hier.common import Leaf, Node, VPArray, _mid
6+
from tdamapper.utils.metrics import Metric
7+
from tdamapper.utils.vptree_hier.common import (
8+
Leaf,
9+
Node,
10+
Tree,
11+
VPArray,
12+
VPTreeType,
13+
_mid,
14+
)
615

16+
T = TypeVar("T")
717

8-
class Builder:
918

10-
def __init__(self, vpt, X):
19+
class Builder(Generic[T]):
20+
21+
_arr: VPArray[T]
22+
_leaf_capacity: int
23+
_leaf_radius: float
24+
_distance: Metric[T]
25+
_pivoting: Callable[[int, int], None]
26+
27+
def __init__(self, vpt: VPTreeType[T], items: Iterable[T]):
1128
self._distance = vpt._get_distance()
1229

13-
dataset = [x for x in X]
30+
dataset = [x for x in items]
1431
indices = np.array([i for i in range(len(dataset))])
15-
distances = np.array([0.0 for _ in X])
32+
distances = np.array([0.0 for _ in items])
1633
self._arr = VPArray(dataset, distances, indices)
1734

1835
self._leaf_capacity = vpt.get_leaf_capacity()
@@ -24,17 +41,17 @@ def __init__(self, vpt, X):
2441
elif pivoting == "furthest":
2542
self._pivoting = self._pivoting_furthest
2643

27-
def _pivoting_disabled(self, start, end):
44+
def _pivoting_disabled(self, start: int, end: int) -> None:
2845
pass
2946

30-
def _pivoting_random(self, start, end):
47+
def _pivoting_random(self, start: int, end: int) -> None:
3148
if end <= start:
3249
return
3350
pivot = randrange(start, end)
3451
if pivot > start:
3552
self._arr.swap(start, pivot)
3653

37-
def _furthest(self, start, end, i):
54+
def _furthest(self, start: int, end: int, i: int) -> int:
3855
furthest_dist = 0.0
3956
furthest = start
4057
i_point = self._arr.get_point(i)
@@ -46,7 +63,7 @@ def _furthest(self, start, end, i):
4663
furthest_dist = j_dist
4764
return furthest
4865

49-
def _pivoting_furthest(self, start, end):
66+
def _pivoting_furthest(self, start: int, end: int) -> None:
5067
if end <= start:
5168
return
5269
rnd = randrange(start, end)
@@ -55,24 +72,26 @@ def _pivoting_furthest(self, start, end):
5572
if furthest > start:
5673
self._arr.swap(start, furthest)
5774

58-
def _update(self, start, end):
75+
def _update(self, start: int, end: int) -> None:
5976
self._pivoting(start, end)
6077
v_point = self._arr.get_point(start)
6178
for i in range(start + 1, end):
6279
point = self._arr.get_point(i)
6380
self._arr.set_distance(i, self._distance(v_point, point))
6481

65-
def build(self):
82+
def build(self) -> tuple[Tree[T], VPArray[T]]:
6683
tree = self._build_rec(0, self._arr.size())
6784
return tree, self._arr
6885

69-
def _build_rec(self, start, end):
86+
def _build_rec(self, start: int, end: int) -> Tree[T]:
7087
mid = _mid(start, end)
7188
self._update(start, end)
7289
v_point = self._arr.get_point(start)
7390
self._arr.partition(start + 1, end, mid)
7491
v_radius = self._arr.get_distance(mid)
7592
self._arr.set_distance(start, v_radius)
93+
left: Tree[T]
94+
right: Tree[T]
7695
if (end - start <= 2 * self._leaf_capacity) or (v_radius <= self._leaf_radius):
7796
left = Leaf(start + 1, mid)
7897
right = Leaf(mid, end)
Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,103 @@
1+
from __future__ import annotations
2+
3+
from typing import Generic, Iterable, Optional, Protocol, TypeVar, Union
4+
5+
import numpy as np
6+
from numpy.typing import NDArray
7+
8+
from tdamapper._common import ArrayLike
9+
from tdamapper.utils.metrics import Metric
110
from tdamapper.utils.quickselect import quickselect, swap_all
211

12+
T = TypeVar("T")
13+
314

4-
def _mid(start, end):
15+
def _mid(start: int, end: int) -> int:
516
return (start + end) // 2
617

718

8-
class VPArray:
19+
class VPTreeType(Protocol[T]):
20+
21+
def _get_arr(self) -> VPArray[T]: ...
22+
23+
def _get_distance(self) -> Metric[T]: ...
24+
25+
def _get_tree(self) -> Tree[T]: ...
926

10-
def __init__(self, dataset, distances, indices):
27+
def get_leaf_capacity(self) -> int: ...
28+
29+
def get_leaf_radius(self) -> float: ...
30+
31+
def get_pivoting(self) -> Optional[str]: ...
32+
33+
34+
class VPArray(Generic[T]):
35+
36+
def __init__(
37+
self,
38+
dataset: ArrayLike[T],
39+
distances: NDArray[np.float64],
40+
indices: NDArray[np.bool_],
41+
):
1142
self._dataset = dataset
1243
self._distances = distances
1344
self._indices = indices
1445

15-
def size(self):
46+
def size(self) -> int:
1647
return len(self._dataset)
1748

18-
def get_point(self, i):
49+
def get_point(self, i: int) -> T:
1950
return self._dataset[self._indices[i]]
2051

21-
def get_points(self, s, e):
52+
def get_points(self, s: int, e: int) -> Iterable[T]:
2253
for x_index in self._indices[s:e]:
2354
yield self._dataset[x_index]
2455

25-
def get_distance(self, i):
56+
def get_distance(self, i: int) -> float:
2657
return self._distances[i]
2758

28-
def set_distance(self, i, dist):
59+
def set_distance(self, i: int, dist: float) -> None:
2960
self._distances[i] = dist
3061

31-
def swap(self, i, j):
62+
def swap(self, i: int, j: int) -> None:
3263
swap_all(self._distances, i, j, self._indices)
3364

34-
def partition(self, s, e, k):
65+
def partition(self, s: int, e: int, k: int) -> None:
3566
quickselect(self._distances, s, e, k, self._indices)
3667

3768

38-
class Node:
69+
class Node(Generic[T]):
3970

40-
def __init__(self, radius, center, left, right):
71+
def __init__(self, radius: float, center: T, left: Tree[T], right: Tree[T]):
4172
self._radius = radius
4273
self._center = center
4374
self._left = left
4475
self._right = right
4576

46-
def get_ball(self):
77+
def get_ball(self) -> tuple[float, T]:
4778
return self._radius, self._center
4879

49-
def is_terminal(self):
80+
def is_terminal(self) -> bool:
5081
return False
5182

52-
def get_left(self):
83+
def get_left(self) -> Tree[T]:
5384
return self._left
5485

55-
def get_right(self):
86+
def get_right(self) -> Tree[T]:
5687
return self._right
5788

5889

5990
class Leaf:
6091

61-
def __init__(self, start, end):
92+
def __init__(self, start: int, end: int) -> None:
6293
self._start = start
6394
self._end = end
6495

65-
def get_bounds(self):
96+
def get_bounds(self) -> tuple[int, int]:
6697
return self._start, self._end
6798

68-
def is_terminal(self):
99+
def is_terminal(self) -> bool:
69100
return True
101+
102+
103+
Tree = Union[Node[T], Leaf]

0 commit comments

Comments
 (0)