11import unittest
2+ import logging
23import time
34
45import numpy as np
910from tdamapper .utils .vptree_flat import VPTree as FVPT
1011
1112from tests .ball_tree import SkBallTree
13+ import tests .setup_logging
1214
1315
1416dist = euclidean ()
@@ -21,57 +23,59 @@ def dataset(dim=10, num=1000):
2123 return [np .random .rand (dim ) for _ in range (num )]
2224
2325
24- def cover (vpt , X , r ):
25- covered_ids = set ()
26- for i , xi in enumerate (X ):
27- if i not in covered_ids :
28- neigh = vpt .ball_search (xi , r )
29- neigh_ids = [int (x [0 ]) for x in neigh ]
30- covered_ids .update (neigh_ids )
31- if neigh_ids :
32- yield neigh_ids
33-
34-
35- def run (X , r , dist , vp , ** kwargs ):
36- XX = np .array ([[i ] + [xi for xi in x ] for i , x in enumerate (X )])
37- d = lambda x , y : dist (x [1 :], y [1 :])
38- t0 = time .time ()
39- vpt = vp (XX , metric = d , ** kwargs )
40- list (cover (vpt , XX , r ))
41- t1 = time .time ()
42- print (f'time: { t1 - t0 } ' )
43-
44-
4526class TestVpSettings (unittest .TestCase ):
4627
28+ logger = logging .getLogger (__name__ )
29+
30+ def cover (self , vpt , X , r ):
31+ covered_ids = set ()
32+ for i , xi in enumerate (X ):
33+ if i not in covered_ids :
34+ neigh = vpt .ball_search (xi , r )
35+ neigh_ids = [int (x [0 ]) for x in neigh ]
36+ covered_ids .update (neigh_ids )
37+ if neigh_ids :
38+ yield neigh_ids
39+
40+ def run_bench (self , X , r , dist , vp , ** kwargs ):
41+ XX = np .array ([[i ] + [xi for xi in x ] for i , x in enumerate (X )])
42+ d = lambda x , y : dist (x [1 :], y [1 :])
43+ t0 = time .time ()
44+ vpt = vp (XX , metric = d , ** kwargs )
45+ list (self .cover (vpt , XX , r ))
46+ t1 = time .time ()
47+ self .logger .info (f'time: { t1 - t0 } ' )
48+
4749 def test_cover_random (self ):
4850 for r in [1.0 , 10.0 , 100.0 ]:
4951 for n in [100 , 1000 , 10000 ]:
50- print (f'============= n: { n } , r: { r } =============' )
52+ self .logger .info (f'============ Cover Bench Random ==========' )
53+ self .logger .info (f'[n: { n } , r: { r } ]' )
5154 X = dataset (num = n )
52- print ('>>>>>>> HVPT >>>>>>' )
53- run (X , r , dist , HVPT , leaf_radius = r , pivoting = 'random' )
54- run (X , r , dist , HVPT , leaf_radius = r , pivoting = 'furthest' )
55- print ('>>>>>>> FVPT >>>>>>' )
56- run (X , r , dist , FVPT , leaf_radius = r , pivoting = 'random' )
57- run (X , r , dist , FVPT , leaf_radius = r , pivoting = 'furthest' )
58- print ('>>>>>> SKBT >>>>>>' )
59- run (X , r , dist , SkBallTree )
60- run (X , r , dist , SkBallTree , leaf_radius = r )
61- print ('' )
55+ self . logger . info ('>>>>>>> HVPT >>>>>>' )
56+ self . run_bench (X , r , dist , HVPT , leaf_radius = r , pivoting = 'random' )
57+ self . run_bench (X , r , dist , HVPT , leaf_radius = r , pivoting = 'furthest' )
58+ self . logger . info ('>>>>>>> FVPT >>>>>>' )
59+ self . run_bench (X , r , dist , FVPT , leaf_radius = r , pivoting = 'random' )
60+ self . run_bench (X , r , dist , FVPT , leaf_radius = r , pivoting = 'furthest' )
61+ self . logger . info ('>>>>>> SKBT >>>>>>' )
62+ self . run_bench (X , r , dist , SkBallTree )
63+ self . run_bench (X , r , dist , SkBallTree , leaf_radius = r )
64+ self . logger . info ('' )
6265
6366 def test_cover_digits (self ):
6467 X , _ = load_digits (return_X_y = True )
6568 #X = PCA(n_components=3).fit_transform(X)
6669 for r in [1.0 , 10.0 , 100.0 ]:
67- print (f'============= r: { r } =============' )
68- print ('>>>>>>> HVPT >>>>>>' )
69- run (X , r , dist , HVPT , leaf_radius = r , pivoting = 'random' )
70- run (X , r , dist , HVPT , leaf_radius = r , pivoting = 'furthest' )
71- print ('>>>>>>> FVPT >>>>>>' )
72- run (X , r , dist , FVPT , leaf_radius = r , pivoting = 'random' )
73- run (X , r , dist , FVPT , leaf_radius = r , pivoting = 'furthest' )
74- print ('>>>>>> SKBT >>>>>>' )
75- run (X , r , dist , SkBallTree )
76- run (X , r , dist , SkBallTree , leaf_radius = r )
77- print ('' )
70+ self .logger .info (f'======= Cover Bench Digits =======' )
71+ self .logger .info (f'[r: { r } ]' )
72+ self .logger .info ('>>>>>>> HVPT >>>>>>' )
73+ self .run_bench (X , r , dist , HVPT , leaf_radius = r , pivoting = 'random' )
74+ self .run_bench (X , r , dist , HVPT , leaf_radius = r , pivoting = 'furthest' )
75+ self .logger .info ('>>>>>>> FVPT >>>>>>' )
76+ self .run_bench (X , r , dist , FVPT , leaf_radius = r , pivoting = 'random' )
77+ self .run_bench (X , r , dist , FVPT , leaf_radius = r , pivoting = 'furthest' )
78+ self .logger .info ('>>>>>> SKBT >>>>>>' )
79+ self .run_bench (X , r , dist , SkBallTree )
80+ self .run_bench (X , r , dist , SkBallTree , leaf_radius = r )
81+ self .logger .info ('' )
0 commit comments