Skip to content

Commit 80f0944

Browse files
committed
Improved logging
1 parent 3588b0c commit 80f0944

File tree

4 files changed

+64
-49
lines changed

4 files changed

+64
-49
lines changed

tests/setup_logging.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import logging
2+
3+
4+
logging.basicConfig(
5+
level=logging.INFO,
6+
format='%(asctime)s %(module)s %(levelname)s: %(message)s',
7+
handlers=[
8+
logging.StreamHandler() # Logs to console
9+
]
10+
)

tests/test_bench_cover.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import logging
23
import time
34

45
import numpy as np
@@ -9,6 +10,7 @@
910
from tdamapper.utils.vptree_flat import VPTree as FVPT
1011

1112
from tests.ball_tree import SkBallTree
13+
import tests.setup_logging
1214

1315

1416
dist = euclidean()
@@ -21,57 +23,59 @@ def dataset(dim=10, num=1000):
2123
return [np.random.rand(dim) for _ in range(num)]
2224

2325

24-
def cover(vpt, X, r):
25-
covered_ids = set()
26-
for i, xi in enumerate(X):
27-
if i not in covered_ids:
28-
neigh = vpt.ball_search(xi, r)
29-
neigh_ids = [int(x[0]) for x in neigh]
30-
covered_ids.update(neigh_ids)
31-
if neigh_ids:
32-
yield neigh_ids
33-
34-
35-
def run(X, r, dist, vp, **kwargs):
36-
XX = np.array([[i] + [xi for xi in x] for i, x in enumerate(X)])
37-
d = lambda x, y: dist(x[1:], y[1:])
38-
t0 = time.time()
39-
vpt = vp(XX, metric=d, **kwargs)
40-
list(cover(vpt, XX, r))
41-
t1 = time.time()
42-
print(f'time: {t1 - t0}')
43-
44-
4526
class TestVpSettings(unittest.TestCase):
4627

28+
logger = logging.getLogger(__name__)
29+
30+
def cover(self, vpt, X, r):
31+
covered_ids = set()
32+
for i, xi in enumerate(X):
33+
if i not in covered_ids:
34+
neigh = vpt.ball_search(xi, r)
35+
neigh_ids = [int(x[0]) for x in neigh]
36+
covered_ids.update(neigh_ids)
37+
if neigh_ids:
38+
yield neigh_ids
39+
40+
def run_bench(self, X, r, dist, vp, **kwargs):
41+
XX = np.array([[i] + [xi for xi in x] for i, x in enumerate(X)])
42+
d = lambda x, y: dist(x[1:], y[1:])
43+
t0 = time.time()
44+
vpt = vp(XX, metric=d, **kwargs)
45+
list(self.cover(vpt, XX, r))
46+
t1 = time.time()
47+
self.logger.info(f'time: {t1 - t0}')
48+
4749
def test_cover_random(self):
4850
for r in [1.0, 10.0, 100.0]:
4951
for n in [100, 1000, 10000]:
50-
print(f'============= n: {n}, r: {r} =============')
52+
self.logger.info(f'============ Cover Bench Random ==========')
53+
self.logger.info(f'[n: {n}, r: {r}]')
5154
X = dataset(num=n)
52-
print('>>>>>>> HVPT >>>>>>')
53-
run(X, r, dist, HVPT, leaf_radius=r, pivoting='random')
54-
run(X, r, dist, HVPT, leaf_radius=r, pivoting='furthest')
55-
print('>>>>>>> FVPT >>>>>>')
56-
run(X, r, dist, FVPT, leaf_radius=r, pivoting='random')
57-
run(X, r, dist, FVPT, leaf_radius=r, pivoting='furthest')
58-
print('>>>>>> SKBT >>>>>>')
59-
run(X, r, dist, SkBallTree)
60-
run(X, r, dist, SkBallTree, leaf_radius=r)
61-
print('')
55+
self.logger.info('>>>>>>> HVPT >>>>>>')
56+
self.run_bench(X, r, dist, HVPT, leaf_radius=r, pivoting='random')
57+
self.run_bench(X, r, dist, HVPT, leaf_radius=r, pivoting='furthest')
58+
self.logger.info('>>>>>>> FVPT >>>>>>')
59+
self.run_bench(X, r, dist, FVPT, leaf_radius=r, pivoting='random')
60+
self.run_bench(X, r, dist, FVPT, leaf_radius=r, pivoting='furthest')
61+
self.logger.info('>>>>>> SKBT >>>>>>')
62+
self.run_bench(X, r, dist, SkBallTree)
63+
self.run_bench(X, r, dist, SkBallTree, leaf_radius=r)
64+
self.logger.info('')
6265

6366
def test_cover_digits(self):
6467
X, _ = load_digits(return_X_y=True)
6568
#X = PCA(n_components=3).fit_transform(X)
6669
for r in [1.0, 10.0, 100.0]:
67-
print(f'============= r: {r} =============')
68-
print('>>>>>>> HVPT >>>>>>')
69-
run(X, r, dist, HVPT, leaf_radius=r, pivoting='random')
70-
run(X, r, dist, HVPT, leaf_radius=r, pivoting='furthest')
71-
print('>>>>>>> FVPT >>>>>>')
72-
run(X, r, dist, FVPT, leaf_radius=r, pivoting='random')
73-
run(X, r, dist, FVPT, leaf_radius=r, pivoting='furthest')
74-
print('>>>>>> SKBT >>>>>>')
75-
run(X, r, dist, SkBallTree)
76-
run(X, r, dist, SkBallTree, leaf_radius=r)
77-
print('')
70+
self.logger.info(f'======= Cover Bench Digits =======')
71+
self.logger.info(f'[r: {r}]')
72+
self.logger.info('>>>>>>> HVPT >>>>>>')
73+
self.run_bench(X, r, dist, HVPT, leaf_radius=r, pivoting='random')
74+
self.run_bench(X, r, dist, HVPT, leaf_radius=r, pivoting='furthest')
75+
self.logger.info('>>>>>>> FVPT >>>>>>')
76+
self.run_bench(X, r, dist, FVPT, leaf_radius=r, pivoting='random')
77+
self.run_bench(X, r, dist, FVPT, leaf_radius=r, pivoting='furthest')
78+
self.logger.info('>>>>>> SKBT >>>>>>')
79+
self.run_bench(X, r, dist, SkBallTree)
80+
self.run_bench(X, r, dist, SkBallTree, leaf_radius=r)
81+
self.logger.info('')

tests/test_bench_metrics.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import logging
23
import timeit
34

45
import pandas as pd
@@ -8,6 +9,8 @@
89

910
import tdamapper.utils.metrics as metrics
1011

12+
import tests.setup_logging
13+
1114

1215
@numba.njit(fastmath=True)
1316
def euclidean_numpy(a, b):
@@ -106,7 +109,9 @@ def run_bench(X):
106109

107110
class TestBenchMetrics(unittest.TestCase):
108111

112+
logger = logging.getLogger(__name__)
113+
109114
def test_bench(self):
110115
X = np.random.rand(1000, 1000)
111116
df_bench = run_bench(X)
112-
print(df_bench)
117+
self.logger.info(df_bench)

tests/test_bench_vptree.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tdamapper.utils.vptree_flat import VPTree as FVPT
1111

1212
from tests.ball_tree import SkBallTree
13+
import tests.setup_logging
1314

1415

1516
dist = euclidean()
@@ -30,11 +31,6 @@ class TestBenchmark(unittest.TestCase):
3031

3132
logger = logging.getLogger(__name__)
3233

33-
logging.basicConfig(
34-
format = '%(asctime)s %(module)s %(levelname)s: %(message)s',
35-
datefmt = '%m/%d/%Y %I:%M:%S %p',
36-
level = logging.INFO)
37-
3834
def test_bench(self):
3935
self.logger.info('==== Dataset random =============')
4036
self._test_compare(dataset())

0 commit comments

Comments
 (0)