|
4 | 4 |
|
5 | 5 | class SLIDING_DOT_PRODUCT: |
6 | 6 | # https://stackoverflow.com/a/30615425/2955541 |
7 | | - def __init__(self): |
8 | | - self.m = 0 |
9 | | - self.n = 0 |
10 | | - self.threads = 1 |
11 | | - self.rfft_Q_obj = None |
12 | | - self.rfft_T_obj = None |
13 | | - self.irfft_obj = None |
14 | | - |
15 | | - def __call__(self, Q, T): |
16 | | - if Q.shape[0] != self.m or T.shape[0] != self.n: |
17 | | - self.m = Q.shape[0] |
18 | | - self.n = T.shape[0] |
19 | | - shape = pyfftw.next_fast_len(self.n) |
20 | | - self.rfft_Q_obj = pyfftw.builders.rfft( |
21 | | - np.empty(self.m), overwrite_input=True, n=shape, threads=self.threads |
| 7 | + def __init__(self, max_n=2**20): |
| 8 | + """ |
| 9 | + Parameters |
| 10 | + ---------- |
| 11 | + max_n : int |
| 12 | + Maximum length to preallocate arrays for. This will be the size of the |
| 13 | + real-valued array. A complex-valued array of size `1 + (max_n // 2)` |
| 14 | + will also be preallocated. |
| 15 | + """ |
| 16 | + # Preallocate arrays |
| 17 | + self.real_arr = pyfftw.empty_aligned(max_n, dtype="float64") |
| 18 | + self.complex_arr = pyfftw.empty_aligned(1 + (max_n // 2), dtype="complex128") |
| 19 | + |
| 20 | + # Store FFTW objects, keyed by (next_fast_n, n_threads, planning_flag) |
| 21 | + self.rfft_objects = {} |
| 22 | + self.irfft_objects = {} |
| 23 | + |
| 24 | + def __call__(self, Q, T, n_threads=1, planning_flag="FFTW_MEASURE"): |
| 25 | + """ |
| 26 | + Compute the sliding dot product between `Q` and `T` using FFTW via pyfftw. |
| 27 | +
|
| 28 | + Parameters |
| 29 | + ---------- |
| 30 | + Q : numpy.ndarray |
| 31 | + Query array or subsequence. |
| 32 | +
|
| 33 | + T : numpy.ndarray |
| 34 | + Time series or sequence. |
| 35 | +
|
| 36 | + n_threads : int, default=1 |
| 37 | + Number of threads to use for FFTW computations. |
| 38 | +
|
| 39 | + planning_flag : str, default="FFTW_MEASURE" |
| 40 | + The planning flag that will be used in FFTW for planning. |
| 41 | + See pyfftw documentation for details. Current options include: |
| 42 | + "FFTW_ESTIMATE", "FFTW_MEASURE", "FFTW_PATIENT", and "FFTW_EXHAUSTIVE". |
| 43 | +
|
| 44 | + Returns |
| 45 | + ------- |
| 46 | + out : numpy.ndarray |
| 47 | + Sliding dot product between `Q` and `T`. |
| 48 | + """ |
| 49 | + m = Q.shape[0] |
| 50 | + n = T.shape[0] |
| 51 | + next_fast_n = pyfftw.next_fast_len(n) |
| 52 | + |
| 53 | + # Update preallocated arrays if needed |
| 54 | + if next_fast_n > len(self.real_arr): |
| 55 | + self.real_arr = pyfftw.empty_aligned(next_fast_n, dtype="float64") |
| 56 | + self.complex_arr = pyfftw.empty_aligned( |
| 57 | + 1 + (next_fast_n // 2), dtype="complex128" |
22 | 58 | ) |
23 | | - self.rfft_T_obj = pyfftw.builders.rfft( |
24 | | - np.empty(self.n), overwrite_input=True, n=shape, threads=self.threads |
| 59 | + |
| 60 | + real_arr = self.real_arr[:next_fast_n] |
| 61 | + complex_arr = self.complex_arr[: 1 + (next_fast_n // 2)] |
| 62 | + |
| 63 | + # Get or create FFTW objects |
| 64 | + key = (next_fast_n, n_threads, planning_flag) |
| 65 | + |
| 66 | + rfft_obj = self.rfft_objects.get(key, None) |
| 67 | + if rfft_obj is None: |
| 68 | + rfft_obj = pyfftw.FFTW( |
| 69 | + input_array=real_arr, |
| 70 | + output_array=complex_arr, |
| 71 | + direction="FFTW_FORWARD", |
| 72 | + flags=(planning_flag,), |
| 73 | + threads=n_threads, |
25 | 74 | ) |
26 | | - self.irfft_obj = pyfftw.builders.irfft( |
27 | | - self.rfft_Q_obj.output_array, |
28 | | - overwrite_input=True, |
29 | | - n=shape, |
30 | | - threads=self.threads, |
| 75 | + self.rfft_objects[key] = rfft_obj |
| 76 | + else: |
| 77 | + rfft_obj.update_arrays(real_arr, complex_arr) |
| 78 | + |
| 79 | + irfft_obj = self.irfft_objects.get(key, None) |
| 80 | + if irfft_obj is None: |
| 81 | + irfft_obj = pyfftw.FFTW( |
| 82 | + input_array=complex_arr, |
| 83 | + output_array=real_arr, |
| 84 | + direction="FFTW_BACKWARD", |
| 85 | + flags=(planning_flag, "FFTW_DESTROY_INPUT"), |
| 86 | + threads=n_threads, |
31 | 87 | ) |
| 88 | + self.irfft_objects[key] = irfft_obj |
| 89 | + else: |
| 90 | + irfft_obj.update_arrays(complex_arr, real_arr) |
| 91 | + |
| 92 | + # RFFT(T) |
| 93 | + real_arr[:n] = T |
| 94 | + real_arr[n:] = 0.0 |
| 95 | + rfft_obj.execute() # output is in complex_arr |
| 96 | + complex_arr_T = complex_arr.copy() |
| 97 | + |
| 98 | + # RFFT(Q) |
| 99 | + # Scale by 1/next_fast_n to account for |
| 100 | + # FFTW's unnormalized inverse FFT via execute() |
| 101 | + real_arr[:m] = Q[::-1] / next_fast_n |
| 102 | + real_arr[m:] = 0.0 |
| 103 | + rfft_obj.execute() # output is in complex_arr |
| 104 | + |
| 105 | + # RFFT(T) * RFFT(Q) |
| 106 | + np.multiply(complex_arr, complex_arr_T, out=complex_arr) |
32 | 107 |
|
33 | | - Qr = Q[::-1] # Reverse/flip Q |
34 | | - rfft_padded_Q = self.rfft_Q_obj(Qr) |
35 | | - rfft_padded_T = self.rfft_T_obj(T) |
| 108 | + # IRFFT (input is in complex_arr) |
| 109 | + irfft_obj.execute() # output is in real_arr |
36 | 110 |
|
37 | | - return self.irfft_obj(np.multiply(rfft_padded_Q, rfft_padded_T)).real[ |
38 | | - self.m - 1 : self.n |
39 | | - ] |
| 111 | + return real_arr[m - 1 : n] |
40 | 112 |
|
41 | 113 |
|
42 | 114 | _sliding_dot_product = SLIDING_DOT_PRODUCT() |
43 | 115 |
|
44 | 116 |
|
45 | | -def setup(Q, T): |
46 | | - _sliding_dot_product(Q, T) |
| 117 | +def setup(Q, T, n_threads=1, planning_flag="FFTW_MEASURE"): |
| 118 | + _sliding_dot_product(Q, T, n_threads=n_threads, planning_flag=planning_flag) |
47 | 119 | return |
48 | 120 |
|
49 | 121 |
|
50 | | -def sliding_dot_product(Q, T): |
51 | | - return _sliding_dot_product(Q, T) |
| 122 | +def sliding_dot_product(Q, T, n_threads=1, planning_flag="FFTW_MEASURE"): |
| 123 | + return _sliding_dot_product(Q, T, n_threads=n_threads, planning_flag=planning_flag) |
0 commit comments