Skip to content

Commit 203869e

Browse files
committed
add sdp functions and their tests
1 parent 3c70d47 commit 203869e

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

stumpy/sdp.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22
from numba import njit
3+
from scipy.fft import next_fast_len
4+
from scipy.fft._pocketfft.basic import c2r, r2c
35
from scipy.signal import convolve
46

57
from . import config
@@ -65,3 +67,35 @@ def _convolve_sliding_dot_product(Q, T):
6567
# sequences fully overlap.
6668

6769
return convolve(np.flipud(Q), T, mode="valid")
70+
71+
72+
def _pocketfft_sliding_dot_product(Q, T):
73+
"""
74+
Use scipy.fft._pocketfft to compute
75+
the sliding dot product.
76+
77+
Parameters
78+
----------
79+
Q : numpy.ndarray
80+
Query array or subsequence
81+
82+
T : numpy.ndarray
83+
Time series or sequence
84+
85+
Returns
86+
-------
87+
output : numpy.ndarray
88+
Sliding dot product between `Q` and `T`.
89+
"""
90+
n = len(T)
91+
m = len(Q)
92+
next_fast_n = next_fast_len(n, real=True)
93+
94+
tmp = np.empty((2, next_fast_n))
95+
tmp[0, :m] = Q[::-1]
96+
tmp[0, m:] = 0.0
97+
tmp[1, :n] = T
98+
tmp[1, n:] = 0.0
99+
fft_2d = r2c(True, tmp, axis=-1)
100+
101+
return c2r(False, np.multiply(fft_2d[0], fft_2d[1]), n=next_fast_n)[m - 1 : n]

tests/test_sdp.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import inspect
2+
import warnings
3+
from operator import eq, lt
4+
5+
import numpy as np
6+
import pytest
7+
from numpy import testing as npt
8+
from scipy.fft import next_fast_len
9+
10+
from stumpy import sdp
11+
12+
# README
13+
# Real FFT algorithm performs more efficiently when the length
14+
# of the input array `arr` is composed of small prime factors.
15+
# The next_fast_len(arr, real=True) function from Scipy returns
16+
# the same length if len(arr) is composed of a subset of
17+
# prime numbers 2, 3, 5. Therefore, these radices are
18+
# considered as the most efficient for the real FFT algorithm.
19+
20+
# To ensure that the tests cover different cases, the following cases
21+
# are considered:
22+
# 1. len(T) is even, and len(T) == next_fast_len(len(T), real=True)
23+
# 2. len(T) is odd, and len(T) == next_fast_len(len(T), real=True)
24+
# 3. len(T) is even, and len(T) < next_fast_len(len(T), real=True)
25+
# 4. len(T) is odd, and len(T) < next_fast_len(len(T), real=True)
26+
# And 5. a special case of 1, where len(T) is power of 2.
27+
28+
# Therefore:
29+
# 1. len(T) is composed of 2 and a subset of {3, 5}
30+
# 2. len(T) is composed of a subset of {3, 5}
31+
# 3. len(T) is composed of a subset of {7, 11, 13, ...} and 2
32+
# 4. len(T) is composed of a subset of {7, 11, 13, ...}
33+
# 5. len(T) is power of 2
34+
35+
# In some cases, the prime factors are raised to a power of
36+
# certain degree to increase the length of array to be around
37+
# 1000-2000. This allows us to test sliding_dot_product for
38+
# wider range of query lengths.
39+
40+
test_inputs = [
41+
# Input format:
42+
# (
43+
# len(T),
44+
# remainder, # from `len(T) % 2`
45+
# comparator, # for len(T) comparator next_fast_len(len(T), real=True)
46+
# )
47+
(
48+
2 * (3**2) * (5**3),
49+
0,
50+
eq,
51+
), # = 2250, Even `len(T)`, and `len(T) == next_fast_len(len(T), real=True)`
52+
(
53+
(3**2) * (5**3),
54+
1,
55+
eq,
56+
), # = 1125, Odd `len(T)`, and `len(T) == next_fast_len(len(T), real=True)`.
57+
(
58+
2 * 7 * 11 * 13,
59+
0,
60+
lt,
61+
), # = 2002, Even `len(T)`, and `len(T) < next_fast_len(len(T), real=True)`
62+
(
63+
7 * 11 * 13,
64+
1,
65+
lt,
66+
), # = 1001, Odd `len(T)`, and `len(T) < next_fast_len(len(T), real=True)`
67+
]
68+
69+
70+
def naive_sliding_dot_product(Q, T):
71+
m = len(Q)
72+
l = T.shape[0] - m + 1
73+
out = np.empty(l)
74+
for i in range(l):
75+
out[i] = np.dot(Q, T[i : i + m])
76+
return out
77+
78+
79+
def get_sdp_functions():
80+
out = []
81+
for func_name, func in inspect.getmembers(sdp, inspect.isfunction):
82+
if func_name.endswith("sliding_dot_product"):
83+
out.append((func_name, func))
84+
85+
return out
86+
87+
88+
@pytest.mark.parametrize("n_T, remainder, comparator", test_inputs)
89+
def test_remainder(n_T, remainder, comparator):
90+
assert n_T % 2 == remainder
91+
92+
93+
@pytest.mark.parametrize("n_T, remainder, comparator", test_inputs)
94+
def test_comparator(n_T, remainder, comparator):
95+
shape = next_fast_len(n_T, real=True)
96+
assert comparator(n_T, shape)
97+
98+
99+
@pytest.mark.parametrize("n_T, remainder, comparator", test_inputs)
100+
def test_sdp(n_T, remainder, comparator):
101+
# test_sdp for cases 1-4
102+
103+
n_Q_prime = [
104+
2,
105+
3,
106+
5,
107+
7,
108+
11,
109+
13,
110+
17,
111+
19,
112+
23,
113+
29,
114+
31,
115+
37,
116+
41,
117+
43,
118+
47,
119+
53,
120+
59,
121+
61,
122+
67,
123+
71,
124+
73,
125+
79,
126+
83,
127+
89,
128+
97,
129+
]
130+
n_Q_power2 = [2, 4, 8, 16, 32, 64]
131+
n_Q_values = n_Q_prime + n_Q_power2 + [n_T]
132+
n_Q_values = sorted(n_Q for n_Q in set(n_Q_values) if n_Q <= n_T)
133+
134+
# utils.import_sdp_mods()
135+
for n_Q in n_Q_values:
136+
Q = np.random.rand(n_Q)
137+
T = np.random.rand(n_T)
138+
ref = naive_sliding_dot_product(Q, T)
139+
for func_name, func in get_sdp_functions():
140+
try:
141+
comp = func(Q, T)
142+
npt.assert_allclose(comp, ref)
143+
except Exception as e: # pragma: no cover
144+
msg = f"Error in {func_name}, with n_Q={n_Q} and n_T={n_T}"
145+
warnings.warn(msg)
146+
raise e
147+
148+
return
149+
150+
151+
def test_sdp_power2():
152+
# test for case 5. len(T) is power of 2
153+
pmin = 3
154+
pmax = 13
155+
156+
for func_name, func in get_sdp_functions():
157+
try:
158+
for q in range(pmin, pmax + 1):
159+
n_Q = 2**q
160+
for p in range(q, pmax + 1):
161+
n_T = 2**p
162+
Q = np.random.rand(n_Q)
163+
T = np.random.rand(n_T)
164+
165+
ref = naive_sliding_dot_product(Q, T)
166+
comp = func(Q, T)
167+
npt.assert_allclose(comp, ref)
168+
169+
except Exception as e: # pragma: no cover
170+
msg = f"Error in {func_name}, with q={q} and p={p}"
171+
warnings.warn(msg)
172+
raise e
173+
174+
return

0 commit comments

Comments
 (0)