11from random import randrange
2+ from typing import Callable , Generic , Iterable , TypeVar
23
34import numpy as np
45
5- from tdamapper .utils .vptree_hier .common import Leaf , Node , VPArray , _mid
6+ from tdamapper .utils .metrics import Metric
7+ from tdamapper .utils .vptree_hier .common import (
8+ Leaf ,
9+ Node ,
10+ Tree ,
11+ VPArray ,
12+ VPTreeType ,
13+ _mid ,
14+ )
615
16+ T = TypeVar ("T" )
717
8- class Builder :
918
10- def __init__ (self , vpt , X ):
19+ class Builder (Generic [T ]):
20+
21+ _arr : VPArray [T ]
22+ _leaf_capacity : int
23+ _leaf_radius : float
24+ _distance : Metric [T ]
25+ _pivoting : Callable [[int , int ], None ]
26+
27+ def __init__ (self , vpt : VPTreeType [T ], items : Iterable [T ]):
1128 self ._distance = vpt ._get_distance ()
1229
13- dataset = [x for x in X ]
30+ dataset = [x for x in items ]
1431 indices = np .array ([i for i in range (len (dataset ))])
15- distances = np .array ([0.0 for _ in X ])
32+ distances = np .array ([0.0 for _ in items ])
1633 self ._arr = VPArray (dataset , distances , indices )
1734
1835 self ._leaf_capacity = vpt .get_leaf_capacity ()
@@ -24,17 +41,17 @@ def __init__(self, vpt, X):
2441 elif pivoting == "furthest" :
2542 self ._pivoting = self ._pivoting_furthest
2643
27- def _pivoting_disabled (self , start , end ) :
44+ def _pivoting_disabled (self , start : int , end : int ) -> None :
2845 pass
2946
30- def _pivoting_random (self , start , end ) :
47+ def _pivoting_random (self , start : int , end : int ) -> None :
3148 if end <= start :
3249 return
3350 pivot = randrange (start , end )
3451 if pivot > start :
3552 self ._arr .swap (start , pivot )
3653
37- def _furthest (self , start , end , i ) :
54+ def _furthest (self , start : int , end : int , i : int ) -> int :
3855 furthest_dist = 0.0
3956 furthest = start
4057 i_point = self ._arr .get_point (i )
@@ -46,7 +63,7 @@ def _furthest(self, start, end, i):
4663 furthest_dist = j_dist
4764 return furthest
4865
49- def _pivoting_furthest (self , start , end ) :
66+ def _pivoting_furthest (self , start : int , end : int ) -> None :
5067 if end <= start :
5168 return
5269 rnd = randrange (start , end )
@@ -55,24 +72,26 @@ def _pivoting_furthest(self, start, end):
5572 if furthest > start :
5673 self ._arr .swap (start , furthest )
5774
58- def _update (self , start , end ) :
75+ def _update (self , start : int , end : int ) -> None :
5976 self ._pivoting (start , end )
6077 v_point = self ._arr .get_point (start )
6178 for i in range (start + 1 , end ):
6279 point = self ._arr .get_point (i )
6380 self ._arr .set_distance (i , self ._distance (v_point , point ))
6481
65- def build (self ):
82+ def build (self ) -> tuple [ Tree [ T ], VPArray [ T ]] :
6683 tree = self ._build_rec (0 , self ._arr .size ())
6784 return tree , self ._arr
6885
69- def _build_rec (self , start , end ) :
86+ def _build_rec (self , start : int , end : int ) -> Tree [ T ] :
7087 mid = _mid (start , end )
7188 self ._update (start , end )
7289 v_point = self ._arr .get_point (start )
7390 self ._arr .partition (start + 1 , end , mid )
7491 v_radius = self ._arr .get_distance (mid )
7592 self ._arr .set_distance (start , v_radius )
93+ left : Tree [T ]
94+ right : Tree [T ]
7695 if (end - start <= 2 * self ._leaf_capacity ) or (v_radius <= self ._leaf_radius ):
7796 left = Leaf (start + 1 , mid )
7897 right = Leaf (mid , end )
0 commit comments