|
6 | 6 | from functools import partial |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | | -from scipy.signal import spectrogram |
| 9 | +from scipy.fft import rfft |
| 10 | +from scipy.signal import get_window, spectrogram |
10 | 11 |
|
11 | 12 | from ..fixes import _reshape_view |
12 | 13 | from ..parallel import parallel_func |
@@ -60,17 +61,51 @@ def _decomp_aggregate_mask(epoch, func, average, freq_sl): |
60 | 61 | return spect |
61 | 62 |
|
62 | 63 |
|
| 64 | +def _welch_mean(epoch, fs, window, nperseg, nfft, noverlap, detrend, freq_sl): |
| 65 | + """Compute Welch PSD with mean averaging via batched per-segment FFT. |
| 66 | +
|
| 67 | + Processes one segment at a time across all rows simultaneously, avoiding |
| 68 | + the overhead of calling scipy.signal.spectrogram per-row. Each segment |
| 69 | + iteration uses a single batched rfft call over all rows in the chunk. |
| 70 | + """ |
| 71 | + win = get_window(window, nperseg) |
| 72 | + scale = 1.0 / (fs * np.dot(win, win)) |
| 73 | + step = nperseg - noverlap |
| 74 | + seg_starts = np.arange(0, epoch.shape[-1] - nperseg + 1, step) |
| 75 | + n_seg = len(seg_starts) |
| 76 | + n_freqs = nfft // 2 + 1 |
| 77 | + row_bytes = epoch[0].nbytes |
| 78 | + chunk_size = max(1, int(10e6 / row_bytes)) |
| 79 | + n_rows = epoch.shape[0] |
| 80 | + result = np.empty((n_rows, n_freqs)) |
| 81 | + for r0 in range(0, n_rows, chunk_size): |
| 82 | + r1 = min(r0 + chunk_size, n_rows) |
| 83 | + psd_acc = np.zeros((r1 - r0, n_freqs)) |
| 84 | + for s in seg_starts: |
| 85 | + seg = epoch[r0:r1, s : s + nperseg].copy() |
| 86 | + if detrend: |
| 87 | + seg -= seg.mean(axis=-1, keepdims=True) |
| 88 | + seg *= win |
| 89 | + ft = rfft(seg, n=nfft, axis=-1) |
| 90 | + psd_acc += np.real(ft * ft.conj()) |
| 91 | + psd_acc *= scale / n_seg |
| 92 | + # One-sided spectrum: double non-DC/Nyquist components |
| 93 | + if nfft % 2: |
| 94 | + psd_acc[:, 1:] *= 2 |
| 95 | + else: |
| 96 | + psd_acc[:, 1:-1] *= 2 |
| 97 | + result[r0:r1] = psd_acc |
| 98 | + return result[:, freq_sl] |
| 99 | + |
| 100 | + |
63 | 101 | def _spect_func(epoch, func, freq_sl, average, *, output="power"): |
64 | 102 | """Aux function.""" |
65 | | - # Process in chunks to balance vectorization (scipy.signal.spectrogram |
66 | | - # handles multi-row input efficiently) against memory usage. |
67 | 103 | kwargs = dict(func=func, average=average, freq_sl=freq_sl) |
68 | 104 | if epoch.nbytes > 10e6: |
69 | 105 | # Process in chunks of rows instead of one-by-one. Each chunk is |
70 | 106 | # passed to spectrogram as a 2D array, which is much faster than |
71 | 107 | # calling spectrogram per-row via np.apply_along_axis. |
72 | 108 | n_rows = epoch.shape[0] |
73 | | - # Target ~10 MB per chunk (same threshold as the original code) |
74 | 109 | row_bytes = epoch[0].nbytes |
75 | 110 | chunk_size = max(1, int(10e6 / row_bytes)) |
76 | 111 | parts = [] |
@@ -252,69 +287,82 @@ def psd_array_welch( |
252 | 287 | f"{n_overlap} overlap and {window} window" |
253 | 288 | ) |
254 | 289 |
|
255 | | - parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs) |
256 | | - _func = partial( |
257 | | - spectrogram, |
258 | | - detrend=detrend, |
259 | | - noverlap=n_overlap, |
260 | | - nperseg=n_per_seg, |
261 | | - nfft=n_fft, |
262 | | - fs=sfreq, |
263 | | - window=window, |
264 | | - mode=mode, |
265 | | - ) |
266 | | - if nan_present and aligned_nan: |
267 | | - # Aligned NaNs across channels → treat as bad annotations. |
268 | | - good_mask = ~nan_mask_full |
269 | | - t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0]) |
270 | | - x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)] |
271 | | - # weights reflect the number of samples used from each span. For spans longer |
272 | | - # than `n_per_seg`, trailing samples may be discarded. For spans shorter than |
273 | | - # `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically |
274 | | - # reduces `n_per_seg` to match the span length (with a warning). |
275 | | - step = n_per_seg - n_overlap |
276 | | - span_lengths = [span.shape[-1] for span in x_splits] |
277 | | - weights = [ |
278 | | - w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths |
279 | | - ] |
280 | | - agg_func = partial(np.average, weights=weights) |
281 | | - if n_jobs > 1: |
282 | | - logger.info( |
283 | | - f"Data split into {len(x_splits)} (probably unequal) chunks due to " |
284 | | - '"bad_*" annotations. Parallelization may be sub-optimal.' |
285 | | - ) |
286 | | - if (np.array(span_lengths) < n_per_seg).any(): |
287 | | - logger.info( |
288 | | - "At least one good data span is shorter than n_per_seg, and will be " |
289 | | - "analyzed with a shorter window than the rest of the file." |
290 | | - ) |
| 290 | + # Fast path: for the common case of mean-averaged PSD without NaN |
| 291 | + # complications, compute directly via batched per-segment FFT. This |
| 292 | + # loops over segments (typically 1-3) rather than rows (potentially |
| 293 | + # 100K+), using a single batched rfft call per segment. |
| 294 | + psds = None |
| 295 | + if average == "mean" and output == "power" and not nan_present and x.nbytes > 10e6: |
| 296 | + psds = _welch_mean( |
| 297 | + x, sfreq, window, n_per_seg, n_fft, n_overlap, detrend, freq_sl |
| 298 | + ) |
291 | 299 |
|
292 | | - def func(*args, **kwargs): |
293 | | - # swallow SciPy warnings caused by short good data spans |
294 | | - with warnings.catch_warnings(): |
295 | | - warnings.filterwarnings( |
296 | | - action="ignore", |
297 | | - module="scipy", |
298 | | - category=UserWarning, |
299 | | - message=r"nperseg = \d+ is greater than input length", |
| 300 | + if psds is None: |
| 301 | + parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs) |
| 302 | + _func = partial( |
| 303 | + spectrogram, |
| 304 | + detrend=detrend, |
| 305 | + noverlap=n_overlap, |
| 306 | + nperseg=n_per_seg, |
| 307 | + nfft=n_fft, |
| 308 | + fs=sfreq, |
| 309 | + window=window, |
| 310 | + mode=mode, |
| 311 | + ) |
| 312 | + if nan_present and aligned_nan: |
| 313 | + # Aligned NaNs across channels → treat as bad annotations. |
| 314 | + good_mask = ~nan_mask_full |
| 315 | + t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0]) |
| 316 | + x_splits = [ |
| 317 | + x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets) |
| 318 | + ] |
| 319 | + # weights reflect the number of samples used from each span. |
| 320 | + step = n_per_seg - n_overlap |
| 321 | + span_lengths = [span.shape[-1] for span in x_splits] |
| 322 | + weights = [ |
| 323 | + w if w < n_per_seg else w - ((w - n_overlap) % step) |
| 324 | + for w in span_lengths |
| 325 | + ] |
| 326 | + agg_func = partial(np.average, weights=weights) |
| 327 | + if n_jobs > 1: |
| 328 | + logger.info( |
| 329 | + f"Data split into {len(x_splits)} (probably unequal) chunks " |
| 330 | + 'due to "bad_*" annotations. Parallelization may be ' |
| 331 | + "sub-optimal." |
| 332 | + ) |
| 333 | + if (np.array(span_lengths) < n_per_seg).any(): |
| 334 | + logger.info( |
| 335 | + "At least one good data span is shorter than n_per_seg, " |
| 336 | + "and will be analyzed with a shorter window than the rest " |
| 337 | + "of the file." |
300 | 338 | ) |
301 | | - return _func(*args, **kwargs) |
302 | 339 |
|
303 | | - else: |
304 | | - # Either no NaNs, or NaNs are not aligned across channels. |
305 | | - if nan_present and not aligned_nan: |
306 | | - logger.info( |
307 | | - "NaN masks are not aligned across channels; treating NaNs as " |
308 | | - "per-channel contamination." |
309 | | - ) |
310 | | - x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0] |
311 | | - agg_func = np.concatenate |
312 | | - func = _func |
313 | | - f_spect = parallel( |
314 | | - my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output) |
315 | | - for d in x_splits |
316 | | - ) |
317 | | - psds = agg_func(f_spect, axis=0) |
| 340 | + def func(*args, **kwargs): |
| 341 | + # swallow SciPy warnings caused by short good data spans |
| 342 | + with warnings.catch_warnings(): |
| 343 | + warnings.filterwarnings( |
| 344 | + action="ignore", |
| 345 | + module="scipy", |
| 346 | + category=UserWarning, |
| 347 | + message=r"nperseg = \d+ is greater than input length", |
| 348 | + ) |
| 349 | + return _func(*args, **kwargs) |
| 350 | + |
| 351 | + else: |
| 352 | + # Either no NaNs, or NaNs are not aligned across channels. |
| 353 | + if nan_present and not aligned_nan: |
| 354 | + logger.info( |
| 355 | + "NaN masks are not aligned across channels; treating NaNs " |
| 356 | + "as per-channel contamination." |
| 357 | + ) |
| 358 | + x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0] |
| 359 | + agg_func = np.concatenate |
| 360 | + func = _func |
| 361 | + f_spect = parallel( |
| 362 | + my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output) |
| 363 | + for d in x_splits |
| 364 | + ) |
| 365 | + psds = agg_func(f_spect, axis=0) |
318 | 366 | shape = dshape + (len(freqs),) |
319 | 367 | if average is None: |
320 | 368 | shape = shape + (-1,) |
|
0 commit comments