
Commit 5ba4ec3

galenlynch and claude committed
perf: pre-allocate output in _apply_sos / _parallel_reduce_axis1
The previous dispatch had each parallel worker return ``(c0, c1, block)``
tuples; the calling thread then allocated the output array and copied each
block into place. That post-collection allocate-and-copy is wasted work
since the channel/time slices are non-overlapping, so workers can write
directly into a pre-allocated output.

Measured on a (30000, 384) float32 chunk with sosfiltfilt and n_workers=5:

  pattern                                wall (ms)   speedup
  E. sequential                             173.89     1.00×
  A. submit + collect + alloc + copy         75.66     2.30×   (current)
  B. pre-alloc, write in place               60.51     2.87×   (this PR)
  C. pool.map, write in place                63.55     2.74×
  D. manual threading.Thread                 64.76     2.69×

So we save ~15 ms wall per `_apply_sos` call (likewise for
`_parallel_reduce_axis1`) by dropping the redundant copy. Ideal 5× scaling
would be 34.78 ms; the remaining gap to ideal is the GIL-held Python
wrapper inside scipy's sosfiltfilt. The dispatch pattern doesn't matter
there (B/C/D are all within noise), so we keep the simpler submit/result
form.

Same pattern applied to common_reference._parallel_reduce_axis1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
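Patterns A and B from the table can be sketched side by side. This is a minimal illustration, not the SpikeInterface code: `stand_in_filter`, `block_bounds`, `collect_then_copy`, and `write_in_place` are hypothetical names, and a column-wise `np.cumsum` stands in for `sosfiltfilt`.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def stand_in_filter(x):
    # Hypothetical stand-in for a per-column filter such as sosfiltfilt.
    return np.cumsum(x, axis=0)


def block_bounds(n, n_workers):
    # Contiguous, non-overlapping (start, stop) blocks via ceiling division.
    block = (n + n_workers - 1) // n_workers
    return [(i, min(i + block, n)) for i in range(0, n, block)]


def collect_then_copy(traces, pool, n_workers):
    # Pattern A: workers return blocks; the caller allocates and copies.
    def work(c0, c1):
        return c0, c1, stand_in_filter(traces[:, c0:c1])

    bounds = block_bounds(traces.shape[1], n_workers)
    futures = [pool.submit(work, c0, c1) for c0, c1 in bounds]
    results = [fut.result() for fut in futures]
    out = np.empty(traces.shape, dtype=results[0][2].dtype)
    for c0, c1, block_out in results:
        out[:, c0:c1] = block_out  # serial copy on the calling thread
    return out


def write_in_place(traces, pool, n_workers):
    # Pattern B: probe the dtype, pre-allocate, let workers fill their slice.
    out_dtype = stand_in_filter(traces[:2, :1]).dtype
    out = np.empty(traces.shape, dtype=out_dtype)

    def work(c0, c1):
        out[:, c0:c1] = stand_in_filter(traces[:, c0:c1])

    bounds = block_bounds(traces.shape[1], n_workers)
    futures = [pool.submit(work, c0, c1) for c0, c1 in bounds]
    for fut in futures:
        fut.result()  # wait, and re-raise any exception from the worker
    return out


rng = np.random.default_rng(0)
traces = rng.standard_normal((3000, 384)).astype(np.float32)
with ThreadPoolExecutor(max_workers=5) as pool:
    a = collect_then_copy(traces, pool, 5)
    b = write_in_place(traces, pool, 5)
```

In pattern B the assignment into `out` still copies each block, but it runs inside the worker rather than serially on the calling thread after all results have been collected. Calling `fut.result()` on every future is still needed so that worker exceptions are not silently dropped.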
1 parent b7788e4 commit 5ba4ec3

2 files changed

Lines changed: 26 additions & 14 deletions

src/spikeinterface/preprocessing/common_reference.py

Lines changed: 11 additions & 6 deletions
@@ -244,6 +244,9 @@ def _parallel_reduce_axis1(self, traces):
         numpy's partition-based median and BLAS-backed mean release the GIL
         during per-row work, so Python-thread parallelism delivers real
         speedup (measured ~10× on 16 threads for 1M × 384 median).
+
+        Workers write directly into a pre-allocated output array — see
+        FilterRecordingSegment._apply_sos for the same pattern.
         """
         if self.n_workers == 1:
             return self.operator_func(traces, axis=1)
@@ -258,15 +261,17 @@ def _parallel_reduce_axis1(self, traces):
         block = (T + effective - 1) // effective
         bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)]
 
+        # Probe dtype: median/mean of a 1×C row gives the same dtype as the
+        # full reduction.
+        out_dtype = self.operator_func(traces[:1, :], axis=1).dtype
+        out = np.empty(T, dtype=out_dtype)
+
         def _work(t0, t1):
-            return t0, t1, self.operator_func(traces[t0:t1, :], axis=1)
+            out[t0:t1] = self.operator_func(traces[t0:t1, :], axis=1)
 
         futures = [pool.submit(_work, t0, t1) for t0, t1 in bounds]
-        results = [fut.result() for fut in futures]
-        out_dtype = results[0][2].dtype
-        out = np.empty(T, dtype=out_dtype)
-        for t0, t1, block_out in results:
-            out[t0:t1] = block_out
+        for fut in futures:
+            fut.result()
         return out
 
     def get_traces(self, start_frame, end_frame, channel_indices):
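The one-row dtype probe in this hunk can be checked in isolation. The sketch below assumes `operator_func` is `np.median` or `np.mean`, as the added comment suggests; `probe_reduce_dtype` is a hypothetical helper name, not part of the module.

```python
import numpy as np


def probe_reduce_dtype(operator_func, traces):
    # Reduce a single 1×C row to learn the output dtype cheaply.
    return operator_func(traces[:1, :], axis=1).dtype


int_traces = np.zeros((100, 384), dtype=np.int16)
f32_traces = np.zeros((100, 384), dtype=np.float32)

results = {}
for func in (np.median, np.mean):
    for traces in (int_traces, f32_traces):
        probed = probe_reduce_dtype(func, traces)
        full = func(traces, axis=1).dtype
        # Key by (reduction name, input dtype) for easy inspection.
        results[(func.__name__, traces.dtype.name)] = (probed, full)
```

For both reductions, integer input is promoted to float64 while float32 input stays float32, and the probed dtype matches that of the full reduction in each case.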

src/spikeinterface/preprocessing/filter.py

Lines changed: 15 additions & 8 deletions
@@ -241,6 +241,12 @@ def _apply_sos(self, fn, traces, axis=0):
         implementations of ``sosfiltfilt``/``sosfilt`` release the GIL during
         per-column work, so Python-thread parallelism delivers real speedup
         (measured ~3× on 8 threads for a 1M × 384 float32 chunk).
+
+        Workers write directly into a pre-allocated output array — eliminating
+        the per-block tuple return + post-loop allocate-and-copy that adds
+        ~15 ms of wall time per call on a (30k, 384) float32 chunk. Each
+        block writes into a non-overlapping channel slice, so concurrent
+        writes are safe.
         """
         if self.n_workers == 1:
             return fn(self.coeff, traces, axis=axis)
@@ -251,17 +257,18 @@ def _apply_sos(self, fn, traces, axis=0):
         block = (C + self.n_workers - 1) // self.n_workers
         bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)]
 
+        # Probe the output dtype on a tiny slice (longer than scipy's internal
+        # padlen of 6 * len(sos)) so we can pre-allocate. Cost: microseconds.
+        probe_len = max(64, 6 * self.coeff.shape[0] + 1)
+        out_dtype = fn(self.coeff, traces[:probe_len, :1], axis=axis).dtype
+        out = np.empty((traces.shape[0], C), dtype=out_dtype)
+
         def _work(c0, c1):
-            return c0, c1, fn(self.coeff, traces[:, c0:c1], axis=axis)
+            out[:, c0:c1] = fn(self.coeff, traces[:, c0:c1], axis=axis)
 
         futures = [pool.submit(_work, c0, c1) for c0, c1 in bounds]
-        results = [fut.result() for fut in futures]
-        # Allocate the output using the first block's dtype (scipy may promote
-        # int input to float64).
-        out_dtype = results[0][2].dtype
-        out = np.empty((traces.shape[0], C), dtype=out_dtype)
-        for c0, c1, block_out in results:
-            out[:, c0:c1] = block_out
+        for fut in futures:
+            fut.result()
         return out
 
     def get_traces(self, start_frame, end_frame, channel_indices):
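The probe length in this hunk has to exceed the pad length that `sosfiltfilt` applies by default. The sketch below reproduces the default as given in the scipy documentation, `3 * (2 * len(sos) + 1 - min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum()))`, using numpy only. `default_sosfiltfilt_padlen` and `commit_probe_len` are hypothetical helper names, and the all-ones `sos` arrays are placeholders chosen only for their shape, not real filters.

```python
import numpy as np


def default_sosfiltfilt_padlen(sos):
    # Default padlen per the scipy.signal.sosfiltfilt documentation.
    sos = np.atleast_2d(sos)
    n_sections = sos.shape[0]
    ntaps = 2 * n_sections + 1
    ntaps -= min(int((sos[:, 2] == 0).sum()), int((sos[:, 5] == 0).sum()))
    return 3 * ntaps


def commit_probe_len(sos):
    # Probe length as computed in the hunk above.
    return max(64, 6 * np.atleast_2d(sos).shape[0] + 1)


# Placeholder coefficient arrays: only shape and zero pattern matter here.
sos_5 = np.ones((5, 6))
sos_11 = np.ones((11, 6))

padlen_5, probe_5 = default_sosfiltfilt_padlen(sos_5), commit_probe_len(sos_5)
padlen_11, probe_11 = default_sosfiltfilt_padlen(sos_11), commit_probe_len(sos_11)
```

Under this formula, a filter with no zero-valued `b2`/`a2` coefficients has a default padlen of `6 * len(sos) + 3`. Because `sosfiltfilt` requires the filtered axis to be strictly longer than padlen, the 64-sample floor covers such filters up to 10 sections. With 11 or more sections, a probe of `6 * len(sos) + 1` samples would be too short under the same assumption.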

0 commit comments
