Skip to content

Commit 1f05b9d

Browse files
sharifhsnclaude
andcommitted
PERF: Custom per-segment FFT for Welch PSD mean averaging
For the common case (average='mean', output='power', no NaNs, data > 10 MB), bypass scipy.signal.spectrogram entirely and compute the PSD via batched per-segment rfft. This loops over segments (typically 1-3 for epoched data) instead of rows (potentially 100K+), with a single batched scipy.fft.rfft call per segment. On 320 epochs x 376 channels, psd_array_welch goes from ~170ms (chunked spectrogram, previous commit) to ~116ms (1.5x additional). Combined with the previous commit: 5047ms -> 116ms (43x total). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 200aabd commit 1f05b9d

1 file changed

Lines changed: 112 additions & 64 deletions

File tree

mne/time_frequency/psd.py

Lines changed: 112 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from functools import partial
77

88
import numpy as np
9-
from scipy.signal import spectrogram
9+
from scipy.fft import rfft
10+
from scipy.signal import get_window, spectrogram
1011

1112
from ..fixes import _reshape_view
1213
from ..parallel import parallel_func
@@ -60,17 +61,51 @@ def _decomp_aggregate_mask(epoch, func, average, freq_sl):
6061
return spect
6162

6263

64+
def _welch_mean(epoch, fs, window, nperseg, nfft, noverlap, detrend, freq_sl):
65+
"""Compute Welch PSD with mean averaging via batched per-segment FFT.
66+
67+
Processes one segment at a time across all rows simultaneously, avoiding
68+
the overhead of calling scipy.signal.spectrogram per-row. Each segment
69+
iteration uses a single batched rfft call over all rows in the chunk.
70+
"""
71+
win = get_window(window, nperseg)
72+
scale = 1.0 / (fs * np.dot(win, win))
73+
step = nperseg - noverlap
74+
seg_starts = np.arange(0, epoch.shape[-1] - nperseg + 1, step)
75+
n_seg = len(seg_starts)
76+
n_freqs = nfft // 2 + 1
77+
row_bytes = epoch[0].nbytes
78+
chunk_size = max(1, int(10e6 / row_bytes))
79+
n_rows = epoch.shape[0]
80+
result = np.empty((n_rows, n_freqs))
81+
for r0 in range(0, n_rows, chunk_size):
82+
r1 = min(r0 + chunk_size, n_rows)
83+
psd_acc = np.zeros((r1 - r0, n_freqs))
84+
for s in seg_starts:
85+
seg = epoch[r0:r1, s : s + nperseg].copy()
86+
if detrend:
87+
seg -= seg.mean(axis=-1, keepdims=True)
88+
seg *= win
89+
ft = rfft(seg, n=nfft, axis=-1)
90+
psd_acc += np.real(ft * ft.conj())
91+
psd_acc *= scale / n_seg
92+
# One-sided spectrum: double non-DC/Nyquist components
93+
if nfft % 2:
94+
psd_acc[:, 1:] *= 2
95+
else:
96+
psd_acc[:, 1:-1] *= 2
97+
result[r0:r1] = psd_acc
98+
return result[:, freq_sl]
99+
100+
63101
def _spect_func(epoch, func, freq_sl, average, *, output="power"):
64102
"""Aux function."""
65-
# Process in chunks to balance vectorization (scipy.signal.spectrogram
66-
# handles multi-row input efficiently) against memory usage.
67103
kwargs = dict(func=func, average=average, freq_sl=freq_sl)
68104
if epoch.nbytes > 10e6:
69105
# Process in chunks of rows instead of one-by-one. Each chunk is
70106
# passed to spectrogram as a 2D array, which is much faster than
71107
# calling spectrogram per-row via np.apply_along_axis.
72108
n_rows = epoch.shape[0]
73-
# Target ~10 MB per chunk (same threshold as the original code)
74109
row_bytes = epoch[0].nbytes
75110
chunk_size = max(1, int(10e6 / row_bytes))
76111
parts = []
@@ -252,69 +287,82 @@ def psd_array_welch(
252287
f"{n_overlap} overlap and {window} window"
253288
)
254289

255-
parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
256-
_func = partial(
257-
spectrogram,
258-
detrend=detrend,
259-
noverlap=n_overlap,
260-
nperseg=n_per_seg,
261-
nfft=n_fft,
262-
fs=sfreq,
263-
window=window,
264-
mode=mode,
265-
)
266-
if nan_present and aligned_nan:
267-
# Aligned NaNs across channels → treat as bad annotations.
268-
good_mask = ~nan_mask_full
269-
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
270-
x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
271-
# weights reflect the number of samples used from each span. For spans longer
272-
# than `n_per_seg`, trailing samples may be discarded. For spans shorter than
273-
# `n_per_seg`, the wrapped function (`scipy.signal.spectrogram`) automatically
274-
# reduces `n_per_seg` to match the span length (with a warning).
275-
step = n_per_seg - n_overlap
276-
span_lengths = [span.shape[-1] for span in x_splits]
277-
weights = [
278-
w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
279-
]
280-
agg_func = partial(np.average, weights=weights)
281-
if n_jobs > 1:
282-
logger.info(
283-
f"Data split into {len(x_splits)} (probably unequal) chunks due to "
284-
'"bad_*" annotations. Parallelization may be sub-optimal.'
285-
)
286-
if (np.array(span_lengths) < n_per_seg).any():
287-
logger.info(
288-
"At least one good data span is shorter than n_per_seg, and will be "
289-
"analyzed with a shorter window than the rest of the file."
290-
)
290+
# Fast path: for the common case of mean-averaged PSD without NaN
291+
# complications, compute directly via batched per-segment FFT. This
292+
# loops over segments (typically 1-3) rather than rows (potentially
293+
# 100K+), using a single batched rfft call per segment.
294+
psds = None
295+
if average == "mean" and output == "power" and not nan_present and x.nbytes > 10e6:
296+
psds = _welch_mean(
297+
x, sfreq, window, n_per_seg, n_fft, n_overlap, detrend, freq_sl
298+
)
291299

292-
def func(*args, **kwargs):
293-
# swallow SciPy warnings caused by short good data spans
294-
with warnings.catch_warnings():
295-
warnings.filterwarnings(
296-
action="ignore",
297-
module="scipy",
298-
category=UserWarning,
299-
message=r"nperseg = \d+ is greater than input length",
300+
if psds is None:
301+
parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
302+
_func = partial(
303+
spectrogram,
304+
detrend=detrend,
305+
noverlap=n_overlap,
306+
nperseg=n_per_seg,
307+
nfft=n_fft,
308+
fs=sfreq,
309+
window=window,
310+
mode=mode,
311+
)
312+
if nan_present and aligned_nan:
313+
# Aligned NaNs across channels → treat as bad annotations.
314+
good_mask = ~nan_mask_full
315+
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
316+
x_splits = [
317+
x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)
318+
]
319+
# weights reflect the number of samples used from each span.
320+
step = n_per_seg - n_overlap
321+
span_lengths = [span.shape[-1] for span in x_splits]
322+
weights = [
323+
w if w < n_per_seg else w - ((w - n_overlap) % step)
324+
for w in span_lengths
325+
]
326+
agg_func = partial(np.average, weights=weights)
327+
if n_jobs > 1:
328+
logger.info(
329+
f"Data split into {len(x_splits)} (probably unequal) chunks "
330+
'due to "bad_*" annotations. Parallelization may be '
331+
"sub-optimal."
332+
)
333+
if (np.array(span_lengths) < n_per_seg).any():
334+
logger.info(
335+
"At least one good data span is shorter than n_per_seg, "
336+
"and will be analyzed with a shorter window than the rest "
337+
"of the file."
300338
)
301-
return _func(*args, **kwargs)
302339

303-
else:
304-
# Either no NaNs, or NaNs are not aligned across channels.
305-
if nan_present and not aligned_nan:
306-
logger.info(
307-
"NaN masks are not aligned across channels; treating NaNs as "
308-
"per-channel contamination."
309-
)
310-
x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
311-
agg_func = np.concatenate
312-
func = _func
313-
f_spect = parallel(
314-
my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
315-
for d in x_splits
316-
)
317-
psds = agg_func(f_spect, axis=0)
340+
def func(*args, **kwargs):
341+
# swallow SciPy warnings caused by short good data spans
342+
with warnings.catch_warnings():
343+
warnings.filterwarnings(
344+
action="ignore",
345+
module="scipy",
346+
category=UserWarning,
347+
message=r"nperseg = \d+ is greater than input length",
348+
)
349+
return _func(*args, **kwargs)
350+
351+
else:
352+
# Either no NaNs, or NaNs are not aligned across channels.
353+
if nan_present and not aligned_nan:
354+
logger.info(
355+
"NaN masks are not aligned across channels; treating NaNs "
356+
"as per-channel contamination."
357+
)
358+
x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
359+
agg_func = np.concatenate
360+
func = _func
361+
f_spect = parallel(
362+
my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
363+
for d in x_splits
364+
)
365+
psds = agg_func(f_spect, axis=0)
318366
shape = dshape + (len(freqs),)
319367
if average is None:
320368
shape = shape + (-1,)

0 commit comments

Comments
 (0)