-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy pathvptree.py
More file actions
99 lines (88 loc) · 3.55 KB
/
vptree.py
File metadata and controls
99 lines (88 loc) · 3.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
A module for fast knn and range searches, depending only on a given metric
"""
from tdamapper.utils.vptree_flat.vptree import VPTree as FVPT
from tdamapper.utils.vptree_hier.vptree import VPTree as HVPT
class VPTree:
    """
    Facade over the flat and hierarchical vantage point tree (vp-tree)
    implementations, exposing range queries and k-nearest-neighbor queries.

    The constructor selects one of the two implementations according to
    `kind` and forwards every other argument to it unchanged. Both query
    methods delegate to the selected implementation.

    :param X: A dataset of n points.
    :type X: array-like of shape (n, m) or list-like of length n
    :param metric: The metric defining the distance between points. Any
        value compatible with `tdamapper.utils.metrics.get_metric` is
        accepted. Defaults to 'euclidean'.
    :type metric: str or callable
    :param metric_params: Extra parameters for the metric function, handed
        to `tdamapper.utils.metrics.get_metric`. Defaults to None.
    :type metric_params: dict, optional
    :param kind: Which implementation to use: 'flat' or 'hierarchical'.
        Defaults to 'flat'.
    :type kind: str
    :param leaf_capacity: The maximum number of points stored in a leaf
        node of the tree. Must be a positive value. Defaults to 1.
    :type leaf_capacity: int
    :param leaf_radius: The radius of the leaf nodes. Defaults to 0.0.
    :type leaf_radius: float
    :param pivoting: The pivoting method used when building the tree.
        Acceptable values are None, 'random', or 'furthest'. Defaults to
        None.
    :type pivoting: str or callable, optional
    :raises ValueError: If `kind` is neither 'flat' nor 'hierarchical'.
    """

    def __init__(
        self,
        X,
        metric="euclidean",
        metric_params=None,
        kind="flat",
        leaf_capacity=1,
        leaf_radius=0.0,
        pivoting=None,
    ):
        # Ordered (name, implementation) pairs; the first name equal to
        # `kind` wins. Plain equality is used so that no hashing of `kind`
        # is required.
        implementations = (("flat", FVPT), ("hierarchical", HVPT))
        tree_cls = next(
            (impl for name, impl in implementations if kind == name),
            None,
        )
        if tree_cls is None:
            raise ValueError(f"Unknown kind of vptree: {kind}")

        # Everything except `kind` is passed through as-is.
        options = {
            "metric": metric,
            "metric_params": metric_params,
            "leaf_capacity": leaf_capacity,
            "leaf_radius": leaf_radius,
            "pivoting": pivoting,
        }
        self.__vpt = tree_cls(X, **options)

    def ball_search(self, point, eps, inclusive=True):
        """
        Find all points lying within a given radius of a query point.

        The call is delegated to the underlying vp-tree implementation.

        :param point: The query point around which to search.
        :type point: object, list, or array-like
        :param eps: The search radius. Must be positive.
        :type eps: float
        :param inclusive: Whether points at distance exactly `eps` from
            `point` are included. Defaults to True.
        :type inclusive: bool
        :return: A list of the points within the given radius of the query
            point.
        :rtype: list
        """
        tree = self.__vpt
        return tree.ball_search(point, eps, inclusive=inclusive)

    def knn_search(self, point, k):
        """
        Find the k nearest neighbors of a query point.

        The call is delegated to the underlying vp-tree implementation.

        :param point: The query point whose neighbors are searched for.
        :type point: object, list, or array-like
        :param k: The number of nearest neighbors to return. Must be
            positive.
        :type k: int
        :return: A list of the k nearest neighbors of the query point.
        :rtype: list
        """
        tree = self.__vpt
        return tree.knn_search(point, k)