Skip to content

Commit e9c0e24

Browse files
committed
add pocketfft sdp and test
1 parent b581b38 commit e9c0e24

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

stumpy/sdp.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ def _oaconvolve_sliding_dot_product(Q, T):
7575
return oaconvolve(np.ascontiguousarray(Q[::-1]), T, mode="valid")
7676

7777

78+
def _pocketfft_sliding_dot_product(Q, T):
79+
"""
80+
Use scipy.fft._pocketfft to compute
81+
the sliding dot product.
82+
83+
Parameters
84+
----------
85+
Q : numpy.ndarray
86+
Query array or subsequence
87+
88+
T : numpy.ndarray
89+
Time series or sequence
90+
91+
Returns
92+
-------
93+
output : numpy.ndarray
94+
Sliding dot product between `Q` and `T`.
95+
"""
96+
n = len(T)
97+
m = len(Q)
98+
next_fast_n = next_fast_len(n, real=True)
99+
100+
tmp = np.empty((2, next_fast_n))
101+
tmp[0, :m] = Q[::-1]
102+
tmp[0, m:] = 0.0
103+
tmp[1, :n] = T
104+
tmp[1, n:] = 0.0
105+
fft_2d = r2c(True, tmp, axis=-1)
106+
107+
return c2r(False, np.multiply(fft_2d[0], fft_2d[1]), n=next_fast_n)[m - 1 : n]
108+
109+
78110
def _sliding_dot_product(Q, T):
79111
"""
80112
Compute the sliding dot product between `Q` and `T`

tests/test_sdp.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ def test_oaconvolve_sliding_dot_product(Q, T):
3636
npt.assert_almost_equal(ref_mp, comp_mp)
3737

3838

39+
@pytest.mark.parametrize("Q, T", test_data)
40+
def test_pocketfft_sliding_dot_product(Q, T):
41+
ref_mp = naive.rolling_window_dot_product(Q, T)
42+
comp_mp = sdp._pocketfft_sliding_dot_product(Q, T)
43+
npt.assert_almost_equal(ref_mp, comp_mp)
44+
45+
3946
@pytest.mark.parametrize("Q, T", test_data)
4047
def test_sliding_dot_product(Q, T):
4148
ref_mp = naive.rolling_window_dot_product(Q, T)

0 commit comments

Comments
 (0)