|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import Generic, Iterable, Optional, Protocol, TypeVar |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from numpy.typing import NDArray |
| 7 | + |
| 8 | +from tdamapper._common import ArrayLike |
| 9 | +from tdamapper.utils.metrics import Metric |
1 | 10 | from tdamapper.utils.quickselect import quickselect, swap_all |
2 | 11 |
|
| 12 | +T = TypeVar("T") |
| 13 | + |
3 | 14 |
|
4 | 15 | def _mid(start, end): |
5 | 16 | return (start + end) // 2 |
6 | 17 |
|
7 | 18 |
|
8 | | -class VPArray: |
| 19 | +class VPTreeType(Protocol[T]): |
| 20 | + |
| 21 | + def _get_arr(self) -> VPArray[T]: ... |
| 22 | + |
| 23 | + def _get_distance(self) -> Metric[T]: ... |
| 24 | + |
| 25 | + def get_leaf_capacity(self) -> int: ... |
| 26 | + |
| 27 | + def get_leaf_radius(self) -> float: ... |
| 28 | + |
| 29 | + def get_pivoting(self) -> Optional[str]: ... |
| 30 | + |
| 31 | + |
| 32 | +class VPArray(Generic[T]): |
9 | 33 |
|
10 | | - def __init__(self, dataset, distances, indices, is_terminal): |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + dataset: ArrayLike[T], |
| 37 | + distances: NDArray[np.float64], |
| 38 | + indices: NDArray[np.int64], |
| 39 | + is_terminal: NDArray[np.bool_], |
| 40 | + ): |
11 | 41 | self._dataset = dataset |
12 | 42 | self._distances = distances |
13 | 43 | self._indices = indices |
14 | 44 | self._is_terminal = is_terminal |
15 | 45 |
|
16 | | - def size(self): |
| 46 | + def size(self) -> int: |
17 | 47 | return len(self._dataset) |
18 | 48 |
|
19 | | - def get_point(self, i): |
| 49 | + def get_point(self, i: int) -> T: |
20 | 50 | return self._dataset[self._indices[i]] |
21 | 51 |
|
22 | | - def get_points(self, s, e): |
| 52 | + def get_points(self, s: int, e: int) -> Iterable[T]: |
23 | 53 | for x_index in self._indices[s:e]: |
24 | 54 | yield self._dataset[x_index] |
25 | 55 |
|
26 | | - def get_distance(self, i): |
| 56 | + def get_distance(self, i: int) -> float: |
27 | 57 | return self._distances[i] |
28 | 58 |
|
29 | | - def set_distance(self, i, dist): |
| 59 | + def set_distance(self, i: int, dist: float) -> None: |
30 | 60 | self._distances[i] = dist |
31 | 61 |
|
32 | | - def set_terminal(self, i, terminal): |
| 62 | + def set_terminal(self, i: int, terminal: bool) -> None: |
33 | 63 | self._is_terminal[i] = terminal |
34 | 64 |
|
35 | | - def is_terminal(self, i): |
| 65 | + def is_terminal(self, i: int) -> bool: |
36 | 66 | return self._is_terminal[i] |
37 | 67 |
|
38 | | - def swap(self, i, j): |
| 68 | + def swap(self, i: int, j: int) -> None: |
39 | 69 | swap_all(self._distances, i, j, self._indices, self._is_terminal) |
40 | 70 |
|
41 | | - def partition(self, s, e, k): |
| 71 | + def partition(self, s: int, e: int, k: int) -> None: |
42 | 72 | quickselect(self._distances, s, e, k, self._indices, self._is_terminal) |
0 commit comments