Skip to content

Commit 3e07f19

Browse files
committed
Added profiling decorator. Improved tests
1 parent 93d3a98 commit 3e07f19

6 files changed

Lines changed: 51 additions & 6 deletions

File tree

benchmarks/benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from tdamapper.core import TrivialClustering
1313

1414

15+
def _identity(x):
16+
return x
17+
18+
1519
def _segment(cardinality, dimension, noise=0.1, start=None, end=None):
1620
if start is None:
1721
start = np.zeros(dimension)
@@ -70,7 +74,7 @@ def fit(self, X, y=None):
7074
def run_gm(X, n, p):
7175
t0 = time.time()
7276
pipe = gm.make_mapper_pipeline(
73-
filter_func=lambda x: x,
77+
filter_func=_identity,
7478
cover=gm.CubicalCover(n_intervals=n, overlap_frac=p),
7579
clusterer=TrivialEstimator(),
7680
)

src/tdamapper/_common.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
This module provides common functionalities for internal use.
33
"""
44

5+
import cProfile
6+
import io
7+
import pstats
58
import warnings
69

710
import numpy as np
@@ -147,3 +150,22 @@ def clone(obj):
147150
obj_noargs = type(obj)()
148151
obj_noargs.set_params(**params)
149152
return obj_noargs
153+
154+
155+
def profile(n_lines=10):
156+
def decorator(func):
157+
def wrapper(*args, **kwargs):
158+
profiler = cProfile.Profile()
159+
profiler.enable()
160+
result = func(*args, **kwargs)
161+
profiler.disable()
162+
163+
s = io.StringIO()
164+
ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
165+
ps.print_stats(n_lines)
166+
print(s.getvalue())
167+
return result
168+
169+
return wrapper
170+
171+
return decorator

src/tdamapper/utils/metrics.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ def minkowski(p):
114114
return euclidean()
115115
elif np.isinf(p):
116116
return chebyshev()
117-
return lambda x, y: _metrics.minkowski(p, x, y)
117+
118+
def dist(x, y):
119+
return _metrics.minkowski(p, x, y)
120+
121+
return dist
118122

119123

120124
def cosine():

tests/test_bench_cover.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
from sklearn.datasets import load_digits
77

8+
from tdamapper._common import profile
89
from tdamapper.utils.metrics import euclidean
910
from tdamapper.utils.vptree_flat import VPTree as FVPT
1011
from tdamapper.utils.vptree_hier import VPTree as HVPT
@@ -21,6 +22,10 @@ def dataset(dim=10, num=1000):
2122
return [np.random.rand(dim) for _ in range(num)]
2223

2324

25+
def dist_proj(x, y):
26+
return dist(x[1:], x[1:])
27+
28+
2429
class TestVpSettings(unittest.TestCase):
2530

2631
setup_logging()
@@ -39,11 +44,12 @@ def cover(self, vpt, X, r):
3944
def run_bench(self, X, r, dist, vp, **kwargs):
4045
XX = np.array([[i] + [xi for xi in x] for i, x in enumerate(X)])
4146
t0 = time.time()
42-
vpt = vp(XX, metric=lambda x, y: dist(x[1:], y[1:]), **kwargs)
47+
vpt = vp(XX, metric=dist_proj, **kwargs)
4348
list(self.cover(vpt, XX, r))
4449
t1 = time.time()
4550
self.logger.info(f"time: {t1 - t0}")
4651

52+
@profile(n_lines=20)
4753
def test_cover_random(self):
4854
for r in [1.0, 10.0, 100.0]:
4955
for n in [100, 1000, 10000]:
@@ -61,6 +67,7 @@ def test_cover_random(self):
6167
self.run_bench(X, r, dist, SkBallTree, leaf_radius=r)
6268
self.logger.info("")
6369

70+
@profile(n_lines=20)
6471
def test_cover_digits(self):
6572
X, _ = load_digits(return_X_y=True)
6673
# X = PCA(n_components=3).fit_transform(X)

tests/test_bench_vptree.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,11 @@ def _test_knn_search_naive(self, data, name):
9393
d(np.array([0.0]), np.array([0.0])) # jit-compile numba
9494
t0 = time()
9595
for val in data:
96-
data.sort(key=lambda x: d(x, val))
96+
97+
def _dist_key(x):
98+
return d(x, val)
99+
100+
data.sort(key=_dist_key)
97101
[x for x in data[: self.k]]
98102
t1 = time()
99103
self.logger.info(f"{name}: {t1 - t0}")

tests/test_unit_proximity.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@ def dataset(dim=1, num=10000):
1010
return [np.random.rand(dim) for _ in range(num)]
1111

1212

13+
def absdist(x, y):
14+
return abs(x - y)
15+
16+
1317
class TestProximity(unittest.TestCase):
1418

1519
def test_ball_proximity(self):
1620
data = list(range(100))
17-
cover = BallCover(radius=10, metric=lambda x, y: abs(x - y))
21+
cover = BallCover(radius=10, metric=absdist)
1822
cover.fit(data)
1923
for x in data:
2024
result = cover.search(x)
@@ -23,7 +27,7 @@ def test_ball_proximity(self):
2327

2428
def test_knn_proximity(self):
2529
data = list(range(100))
26-
cover = KNNCover(neighbors=11, metric=lambda x, y: abs(x - y))
30+
cover = KNNCover(neighbors=11, metric=absdist)
2731
cover.fit(data)
2832
for x in range(5, 94):
2933
result = cover.search(x)

0 commit comments

Comments
 (0)