Skip to content

Commit 17b671b

Browse files
Fixed #22 Propose new pyfftw_sdp (#25)
* propose revised pyfftw in challenger * minor fix * update files * improve comments * fix black formatting * addressed comments * fixed unnecessary import * add comment to improve clarity * addressed comments * increase the default value of max_n parameter * added test function * renamed test function * fixed minor issue * minor change * Revised comment * addressed comments * minor changes * minor change in docstring * minor change in docstring * add comment to test function to show its purpose * avoid tracking len(T) via attribute
1 parent 027e2c0 commit 17b671b

3 files changed

Lines changed: 123 additions & 33 deletions

File tree

sdp/pyfftw_sdp.py

Lines changed: 104 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,48 +4,120 @@
44

55
class SLIDING_DOT_PRODUCT:
66
# https://stackoverflow.com/a/30615425/2955541
7-
def __init__(self):
8-
self.m = 0
9-
self.n = 0
10-
self.threads = 1
11-
self.rfft_Q_obj = None
12-
self.rfft_T_obj = None
13-
self.irfft_obj = None
14-
15-
def __call__(self, Q, T):
16-
if Q.shape[0] != self.m or T.shape[0] != self.n:
17-
self.m = Q.shape[0]
18-
self.n = T.shape[0]
19-
shape = pyfftw.next_fast_len(self.n)
20-
self.rfft_Q_obj = pyfftw.builders.rfft(
21-
np.empty(self.m), overwrite_input=True, n=shape, threads=self.threads
7+
def __init__(self, max_n=2**20):
8+
"""
9+
Parameters
10+
----------
11+
max_n : int
12+
Maximum length to preallocate arrays for. This will be the size of the
13+
real-valued array. A complex-valued array of size `1 + (max_n // 2)`
14+
will also be preallocated.
15+
"""
16+
# Preallocate arrays
17+
self.real_arr = pyfftw.empty_aligned(max_n, dtype="float64")
18+
self.complex_arr = pyfftw.empty_aligned(1 + (max_n // 2), dtype="complex128")
19+
20+
# Store FFTW objects, keyed by (next_fast_n, n_threads, planning_flag)
21+
self.rfft_objects = {}
22+
self.irfft_objects = {}
23+
24+
def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_MEASURE"):
25+
"""
26+
Compute the sliding dot product between `Q` and `T` using FFTW via pyfftw.
27+
28+
Parameters
29+
----------
30+
Q : numpy.ndarray
31+
Query array or subsequence.
32+
33+
T : numpy.ndarray
34+
Time series or sequence.
35+
36+
n_threads : int, default=1
37+
Number of threads to use for FFTW computations.
38+
39+
planning_flag : str, default="FFTW_MEASURE"
40+
The planning flag that will be used in FFTW for planning.
41+
See pyfftw documentation for details. Current options include:
42+
"FFTW_ESTIMATE", "FFTW_MEASURE", "FFTW_PATIENT", and "FFTW_EXHAUSTIVE".
43+
44+
Returns
45+
-------
46+
out : numpy.ndarray
47+
Sliding dot product between `Q` and `T`.
48+
"""
49+
m = Q.shape[0]
50+
n = T.shape[0]
51+
next_fast_n = pyfftw.next_fast_len(n)
52+
53+
# Update preallocated arrays if needed
54+
if next_fast_n > len(self.real_arr):
55+
self.real_arr = pyfftw.empty_aligned(next_fast_n, dtype="float64")
56+
self.complex_arr = pyfftw.empty_aligned(
57+
1 + (next_fast_n // 2), dtype="complex128"
2258
)
23-
self.rfft_T_obj = pyfftw.builders.rfft(
24-
np.empty(self.n), overwrite_input=True, n=shape, threads=self.threads
59+
60+
real_arr = self.real_arr[:next_fast_n]
61+
complex_arr = self.complex_arr[: 1 + (next_fast_n // 2)]
62+
63+
# Get or create FFTW objects
64+
key = (next_fast_n, n_threads, planning_flag)
65+
66+
rfft_obj = self.rfft_objects.get(key, None)
67+
if rfft_obj is None:
68+
rfft_obj = pyfftw.FFTW(
69+
input_array=real_arr,
70+
output_array=complex_arr,
71+
direction="FFTW_FORWARD",
72+
flags=(planning_flag,),
73+
threads=n_threads,
2574
)
26-
self.irfft_obj = pyfftw.builders.irfft(
27-
self.rfft_Q_obj.output_array,
28-
overwrite_input=True,
29-
n=shape,
30-
threads=self.threads,
75+
self.rfft_objects[key] = rfft_obj
76+
else:
77+
rfft_obj.update_arrays(real_arr, complex_arr)
78+
79+
irfft_obj = self.irfft_objects.get(key, None)
80+
if irfft_obj is None:
81+
irfft_obj = pyfftw.FFTW(
82+
input_array=complex_arr,
83+
output_array=real_arr,
84+
direction="FFTW_BACKWARD",
85+
flags=(planning_flag, "FFTW_DESTROY_INPUT"),
86+
threads=n_threads,
3187
)
88+
self.irfft_objects[key] = irfft_obj
89+
else:
90+
irfft_obj.update_arrays(complex_arr, real_arr)
91+
92+
# RFFT(T)
93+
real_arr[:n] = T
94+
real_arr[n:] = 0.0
95+
rfft_obj.execute() # output is in complex_arr
96+
complex_arr_T = complex_arr.copy()
97+
98+
# RFFT(Q)
99+
# Scale by 1/next_fast_n to account for
100+
# FFTW's unnormalized inverse FFT via execute()
101+
real_arr[:m] = Q[::-1] / next_fast_n
102+
real_arr[m:] = 0.0
103+
rfft_obj.execute() # output is in complex_arr
104+
105+
# RFFT(T) * RFFT(Q)
106+
np.multiply(complex_arr, complex_arr_T, out=complex_arr)
32107

33-
Qr = Q[::-1] # Reverse/flip Q
34-
rfft_padded_Q = self.rfft_Q_obj(Qr)
35-
rfft_padded_T = self.rfft_T_obj(T)
108+
# IRFFT (input is in complex_arr)
109+
irfft_obj.execute() # output is in real_arr
36110

37-
return self.irfft_obj(np.multiply(rfft_padded_Q, rfft_padded_T)).real[
38-
self.m - 1 : self.n
39-
]
111+
return real_arr[m - 1 : n]
40112

41113

42114
_sliding_dot_product = SLIDING_DOT_PRODUCT()
43115

44116

45-
def setup(Q, T):
46-
_sliding_dot_product(Q, T)
117+
def setup(Q, T, n_threads=1, planning_flag="FFTW_MEASURE"):
118+
_sliding_dot_product(Q, T, n_threads=n_threads, planning_flag=planning_flag)
47119
return
48120

49121

50-
def sliding_dot_product(Q, T):
51-
return _sliding_dot_product(Q, T)
122+
def sliding_dot_product(Q, T, n_threads=1, planning_flag="FFTW_MEASURE"):
123+
return _sliding_dot_product(Q, T, n_threads=n_threads, planning_flag=planning_flag)

test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,21 @@ def test_setup():
192192
raise e
193193

194194
return
195+
196+
197+
def test_pyfftw_sdp_max_n():
198+
# When `len(T)` is larger than `max_n` in pyfftw_sdp,
199+
# the internal preallocated arrays should be resized.
200+
# This test checks that functionality.
201+
from sdp.pyfftw_sdp import SLIDING_DOT_PRODUCT
202+
203+
T = np.random.rand(2**12)
204+
Q = np.random.rand(2**8)
205+
206+
sliding_dot_product = SLIDING_DOT_PRODUCT(max_n=2**10)
207+
comp = sliding_dot_product(Q, T)
208+
ref = naive_sliding_dot_product(Q, T)
209+
210+
np.testing.assert_allclose(comp, ref)
211+
212+
return

timing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
if __name__ == "__main__":
1111
parser = argparse.ArgumentParser(
12-
description="./timing.py -noheader -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger"
12+
description="./timing.py -pmin 6 -pmax 23 -pdiff 3 pyfftw challenger"
1313
)
1414
parser.add_argument("-noheader", default=False, action="store_true")
1515
parser.add_argument(

0 commit comments

Comments
 (0)