
Commit a2fee4e

galenlynch and claude committed
perf: n_workers kwarg for FilterRecording + CommonReferenceRecording
Adds opt-in intra-chunk thread-parallelism to two preprocessors:
channel-split sosfilt/sosfiltfilt in FilterRecording, time-split
median/mean in CommonReferenceRecording. Default n_workers=1 preserves
existing behavior.

Per-caller-thread inner pools
-----------------------------
Each outer thread that calls ``get_traces()`` on a parallel-enabled
segment gets its own inner ThreadPoolExecutor, stored in a
``WeakKeyDictionary`` keyed by the calling ``Thread`` object. Rationale:

* Avoids the shared-pool queueing pathology that would occur if N outer
  workers (e.g., TimeSeriesChunkExecutor with n_jobs=N) all submitted
  into a single shared pool with fewer max_workers than outer callers.
  Under a shared pool, ``n_workers=2`` with ``n_jobs=24`` thrashed at
  3.36 s on the test pipeline; per-caller pools: 1.47 s.
* Keying by the Thread object (not the thread-id integer) avoids the
  thread-id-reuse hazard: thread IDs can be reused after a thread dies,
  which would cause a new thread to silently inherit a dead thread's
  pool.
* WeakKeyDictionary + weakref.finalize ensures automatic shutdown of
  the inner pool when the calling thread is garbage-collected. The
  finalizer calls ``pool.shutdown(wait=False)`` to avoid blocking the
  finalizer thread; in-flight tasks would be cancelled, but the owning
  thread submits and joins synchronously, so none exist when it exits.

(A minimal sketch of this pattern follows the message.)

When useful
-----------
* Direct ``get_traces()`` callers (interactive viewers, streaming
  consumers, mipmap-zarr tile builders) that don't use
  ``TimeSeriesChunkExecutor``.
* Default SI users who haven't tuned job_kwargs.
* RAM-constrained deployments that can't crank ``n_jobs`` to core
  count: on a 24-core host, ``n_jobs=6, n_workers=2`` gets within 8%
  of ``n_jobs=24, n_workers=1`` at ~1/4 the RAM.

Performance (1M × 384 float32 BP+CMR pipeline, 24-core host, thread engine)
---------------------------------------------------------------------------
=== Component-level (scipy/numpy only) ===
sosfiltfilt serial → 8 threads:  7.80 s → 2.67 s (2.92x)
np.median serial → 16 threads:   3.51 s → 0.33 s (10.58x)

=== Per-stage end-to-end (rec.get_traces) ===
Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x)
CMR median (global):             4.01 s → 0.81 s (4.95x)

=== CRE outer × inner Pareto, per-caller pools ===
outer=24, inner=1 each: 1.54 s (100% of peak)
outer=24, inner=8 each: 1.42 s (108% of peak; oversubscribed)
outer=12, inner=1 each: 1.59 s (97%, ~1/2 RAM of outer=24)
outer=6,  inner=2 each: 1.75 s (92%, ~1/4 RAM of outer=24)
outer=4,  inner=6 each: 1.83 s (87%, ~1/6 RAM with 24 threads)

Tests
-----
New ``test_parallel_pool_semantics.py`` verifies the per-caller-thread
contract: a single caller reuses one pool; concurrent callers get
distinct pools. Existing bandpass + CMR tests still pass.

Independent of the companion FIR phase-shift PR (perf/phase-shift-fir);
the two can land in either order.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
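For reference, the per-caller-pool pattern in miniature. This is an
editor's sketch of the mechanism the message describes, not the committed
code: ``PerCallerPools`` and its method names are illustrative, and the
real implementation lives on the recording segments.

    import threading
    import weakref
    from concurrent.futures import ThreadPoolExecutor
    from weakref import WeakKeyDictionary


    class PerCallerPools:
        """One inner ThreadPoolExecutor per calling thread (sketch)."""

        def __init__(self, n_workers: int):
            self.n_workers = n_workers
            # Keyed by the Thread *object*, not thread.ident: integer ids
            # can be reused after a thread dies, which would let a new
            # thread silently inherit a dead thread's pool.
            self._pools: WeakKeyDictionary = WeakKeyDictionary()
            self._lock = threading.Lock()

        def get(self) -> ThreadPoolExecutor:
            caller = threading.current_thread()
            with self._lock:
                pool = self._pools.get(caller)
                if pool is None:
                    pool = ThreadPoolExecutor(max_workers=self.n_workers)
                    # Shut the pool down once the calling thread is
                    # garbage-collected. wait=False keeps the finalizer
                    # thread from blocking; the owner submits and joins
                    # synchronously, so no tasks are in flight at exit.
                    weakref.finalize(caller, pool.shutdown, wait=False)
                    self._pools[caller] = pool
                return pool

Typical use inside a get_traces()-style hot path: fetch the caller's own
pool with ``pools.get()``, fan ``pool.submit(...)`` out over channel or
time blocks, and join all futures before returning, so each outer worker
composes with the inner pool without shared-pool queueing.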
1 parent 36ad576 commit a2fee4e

6 files changed

Lines changed: 612 additions & 3 deletions

benchmarks/preprocessing/bench_perf.py (new file)

Lines changed: 303 additions & 0 deletions
@@ -0,0 +1,303 @@
"""Benchmark script for the parallel bandpass + CMR speedups.

Runs head-to-head comparisons on synthetic NumpyRecording fixtures so the
numbers are reproducible without external ephys data:

1. Component-level (hot operation only, no SI plumbing):
   - scipy.signal.sosfiltfilt serial vs channel-parallel threads
   - np.median(axis=1) serial vs time-parallel threads
2. Per-stage end-to-end (``rec.get_traces()`` path):
   - BandpassFilterRecording stock vs n_workers=8
   - CommonReferenceRecording stock vs n_workers=16
3. CRE (``TimeSeriesChunkExecutor``) × inner (n_workers) interaction at
   matched chunk_duration="1s".

FilterRecordingSegment and CommonReferenceRecordingSegment use
**per-caller-thread inner pools** (WeakKeyDictionary keyed by the calling
Thread object). Each outer thread that calls get_traces() gets its own
inner ThreadPoolExecutor, so n_workers composes cleanly with CRE's outer
parallelism — no shared-pool queueing pathology. See
``tests/test_parallel_pool_semantics.py`` for the contract.

Measured on a 24-core x86_64 host with 1M x 384 float32 chunks (SI 0.103
dev, numpy 2.1, scipy 1.14, full get_traces() path end-to-end):

=== Component-level (hot kernel only, no SI plumbing) ===
sosfiltfilt serial → 8 threads:  7.80 s → 2.67 s (2.92x)
np.median serial → 16 threads:   3.51 s → 0.33 s (10.58x)

=== Per-stage end-to-end (rec.get_traces) ===
Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x)
CMR median (global):             4.01 s → 0.81 s (4.95x)

=== CRE outer × inner (chunk=1s, per-caller pools) ===
Bandpass: stock n=1 → stock n=8 thread: 7.42 s → 1.40 s (5.3x outer)
          n_workers=8 n=1:              3.18 s (2.3x inner)
          n_workers=8 n=8 thread:       1.24 s (combined)
CMR:      stock n=1 → stock n=8 thread: 3.98 s → 0.61 s (6.5x outer)
          n_workers=16 n=1:             1.58 s (2.5x inner)
          n_workers=16 n=8 thread:      0.36 s (11.0x combined)

Bandpass and CMR scale sub-linearly with thread count due to memory
bandwidth saturation; 2.7x / 5x per stage on 8 / 16 threads respectively
is consistent with the DRAM ceiling at these chunk sizes, not a
parallelism bug. Under CRE, the outer-vs-inner combination depends on
whether the inner pool has headroom over n_jobs — per-caller pools make
this deterministic regardless.

Run with ``python -m benchmarks.preprocessing.bench_perf`` from repo root.
"""

from __future__ import annotations

import time

import numpy as np
import scipy.signal

from spikeinterface import NumpyRecording
from spikeinterface.preprocessing import (
    BandpassFilterRecording,
    CommonReferenceRecording,
)


def _make_recording(T: int = 1_048_576, C: int = 384, fs: float = 30_000.0, dtype=np.float32):
    """Synthetic NumpyRecording matching typical Neuropixels shard shape."""
    rng = np.random.default_rng(0)
    traces = rng.standard_normal((T, C)).astype(dtype) * 100.0
    rec = NumpyRecording([traces], sampling_frequency=fs)
    return rec


def _time_get_traces(rec, *, n_reps=3, warmup=1):
    """Median-of-N timing of rec.get_traces() for the full single segment."""
    for _ in range(warmup):
        rec.get_traces()
    times = []
    for _ in range(n_reps):
        t0 = time.perf_counter()
        rec.get_traces()
        times.append(time.perf_counter() - t0)
    return float(np.median(times))


def _time_callable(fn, *, n_reps=3, warmup=1):
    """Best-of-N timing for a bare callable. Used for component-level benches
    where we want to isolate the hot operation from surrounding glue."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(n_reps):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    return float(min(times))


def _time_cre(executor, *, n_reps=2, warmup=1):
    """Min-of-N timing for a TimeSeriesChunkExecutor invocation."""
    for _ in range(warmup):
        executor.run()
    times = []
    for _ in range(n_reps):
        t0 = time.perf_counter()
        executor.run()
        times.append(time.perf_counter() - t0)
    return float(min(times))


def _cre_init(recording):
    return {"recording": recording}


def _cre_func(segment_index, start_frame, end_frame, worker_dict):
    worker_dict["recording"].get_traces(
        start_frame=start_frame, end_frame=end_frame, segment_index=segment_index
    )


def bench_sosfiltfilt_component():
    """Component-level bench: just scipy.signal.sosfiltfilt vs channel-parallel.

    Isolates the hot SOS operation from the full BandpassFilter.get_traces
    path so you can see the kernel-only speedup (no margin fetch, no dtype
    cast, no slice).
    """
    from concurrent.futures import ThreadPoolExecutor

    print("--- [component] sosfiltfilt (1M x 384 float32) ---")
    T, C = 1_048_576, 384
    rng = np.random.default_rng(0)
    x = rng.standard_normal((T, C)).astype(np.float32) * 100.0
    sos = scipy.signal.butter(5, [300.0, 6000.0], btype="bandpass", fs=30_000.0, output="sos")

    pool = ThreadPoolExecutor(max_workers=8)

    def parallel_call():
        block = (C + 8 - 1) // 8
        bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)]

        def _work(c0, c1):
            return c0, c1, scipy.signal.sosfiltfilt(sos, x[:, c0:c1], axis=0)

        results = [fut.result() for fut in [pool.submit(_work, c0, c1) for c0, c1 in bounds]]
        out = np.empty((T, C), dtype=results[0][2].dtype)
        for c0, c1, block_out in results:
            out[:, c0:c1] = block_out
        return out

    t_stock = _time_callable(lambda: scipy.signal.sosfiltfilt(sos, x, axis=0))
    t_par = _time_callable(parallel_call)
    pool.shutdown()
    print(f"  scipy.sosfiltfilt serial:    {t_stock:6.2f} s")
    print(f"  scipy.sosfiltfilt 8 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)")
    print()


def bench_median_component():
    """Component-level bench: just np.median(axis=1) vs threaded across time blocks."""
    from concurrent.futures import ThreadPoolExecutor

    print("--- [component] np.median axis=1 (1M x 384 float32) ---")
    T, C = 1_048_576, 384
    rng = np.random.default_rng(0)
    x = rng.standard_normal((T, C)).astype(np.float32) * 100.0

    pool = ThreadPoolExecutor(max_workers=16)

    def parallel_call():
        block = (T + 16 - 1) // 16
        bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)]

        def _work(t0, t1):
            return t0, t1, np.median(x[t0:t1, :], axis=1)

        results = [fut.result() for fut in [pool.submit(_work, t0, t1) for t0, t1 in bounds]]
        out = np.empty(T, dtype=results[0][2].dtype)
        for t0, t1, block_out in results:
            out[t0:t1] = block_out
        return out

    t_stock = _time_callable(lambda: np.median(x, axis=1))
    t_par = _time_callable(parallel_call)
    pool.shutdown()
    print(f"  np.median serial:     {t_stock:6.2f} s")
    print(f"  np.median 16 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)")
    print()


def bench_bandpass():
    """End-to-end bench: BandpassFilterRecording stock vs n_workers=8."""
    print("=== Bandpass (5th-order Butterworth 300-6000 Hz, 1M x 384 float32) ===")
    rec = _make_recording(dtype=np.float32)
    stock = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0)
    fast = BandpassFilterRecording(rec, freq_min=300.0, freq_max=6000.0, margin_ms=40.0, n_workers=8)

    t_stock = _time_get_traces(stock)
    t_fast = _time_get_traces(fast)
    print(f"  stock    (n_workers=1): {t_stock:6.2f} s")
    print(f"  parallel (n_workers=8): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)")
    # Equivalence check
    ref = stock.get_traces(start_frame=1000, end_frame=10_000)
    out = fast.get_traces(start_frame=1000, end_frame=10_000)
    assert np.allclose(out, ref, rtol=1e-5, atol=1e-4), "parallel bandpass output mismatch"
    print("  output matches stock within float32 tolerance")
    print()


def bench_cmr():
    """End-to-end bench: CommonReferenceRecording stock vs n_workers=16."""
    print("=== CMR median (global, 1M x 384 float32) ===")
    rec = _make_recording(dtype=np.float32)
    stock = CommonReferenceRecording(rec, operator="median", reference="global")
    fast = CommonReferenceRecording(rec, operator="median", reference="global", n_workers=16)

    t_stock = _time_get_traces(stock)
    t_fast = _time_get_traces(fast)
    print(f"  stock    (n_workers=1):  {t_stock:6.2f} s")
    print(f"  parallel (n_workers=16): {t_fast:6.2f} s ({t_stock / t_fast:4.2f}x)")
    ref = stock.get_traces(start_frame=1000, end_frame=10_000)
    out = fast.get_traces(start_frame=1000, end_frame=10_000)
    np.testing.assert_array_equal(out, ref)
    print("  output is bitwise-identical to stock")
    print()


def bench_bandpass_cre_interaction():
    """Bandpass: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism.

    At SI's default ``chunk_duration="1s"``, the intra-chunk ``n_workers``
    kwarg is only useful when outer CRE workers don't already saturate cores.
    When combined, the result depends on whether inner-pool ``max_workers``
    exceeds outer ``n_jobs``.
    """
    from spikeinterface.core.job_tools import TimeSeriesChunkExecutor

    print("=== Bandpass: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===")
    rec = _make_recording(dtype=np.float32)

    def make_cre(bp_rec, n_jobs):
        return TimeSeriesChunkExecutor(
            time_series=bp_rec, func=_cre_func, init_func=_cre_init, init_args=(bp_rec,),
            pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False,
        )

    t_stock_n1 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=1))
    t_stock_n8 = _time_cre(make_cre(BandpassFilterRecording(rec), n_jobs=8))
    t_fast_n1 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=1))
    t_fast_n8 = _time_cre(make_cre(BandpassFilterRecording(rec, n_workers=8), n_jobs=8))

    print(f"  {'config':<40} {'time':>8} {'vs baseline':>12}")
    print(f"  {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}")
    print(f"  {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)")
    print(f"  {'n_workers=8, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)")
    print(f"  {'n_workers=8, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)")
    print()


def bench_cmr_cre_interaction():
    """CMR: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism."""
    from spikeinterface.core.job_tools import TimeSeriesChunkExecutor

    print("=== CMR: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===")
    rec = _make_recording(dtype=np.float32)

    def make_cre(cmr_rec, n_jobs):
        return TimeSeriesChunkExecutor(
            time_series=cmr_rec, func=_cre_func, init_func=_cre_init, init_args=(cmr_rec,),
            pool_engine="thread", n_jobs=n_jobs, chunk_duration="1s", progress_bar=False,
        )

    t_stock_n1 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=1))
    t_stock_n8 = _time_cre(make_cre(CommonReferenceRecording(rec), n_jobs=8))
    t_fast_n1 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=1))
    t_fast_n8 = _time_cre(make_cre(CommonReferenceRecording(rec, n_workers=16), n_jobs=8))

    print(f"  {'config':<40} {'time':>8} {'vs baseline':>12}")
    print(f"  {'stock, CRE n=1 (baseline)':<40} {t_stock_n1:6.2f} s {'1.00×':>12}")
    print(f"  {'stock, CRE n=8 thread':<40} {t_stock_n8:6.2f} s {t_stock_n1/t_stock_n8:5.2f}× (outer only)")
    print(f"  {'n_workers=16, CRE n=1':<40} {t_fast_n1:6.2f} s {t_stock_n1/t_fast_n1:5.2f}× (inner only)")
    print(f"  {'n_workers=16, CRE n=8 thread':<40} {t_fast_n8:6.2f} s {t_stock_n1/t_fast_n8:5.2f}× (both)")
    print()


def main():
    print("### COMPONENT-LEVEL (hot operation only) ###")
    print()
    bench_sosfiltfilt_component()
    bench_median_component()

    print("### PER-STAGE END-TO-END (rec.get_traces()) ###")
    print()
    bench_bandpass()
    bench_cmr()

    print("### CRE OUTER × INNER (chunk=1s) ###")
    print()
    bench_bandpass_cre_interaction()
    bench_cmr_cre_interaction()


if __name__ == "__main__":
    main()
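
The pool-semantics contract from the Tests section, as a standalone
sketch. This is not the committed ``test_parallel_pool_semantics.py``;
the module-level registry below is an illustrative stand-in for the
segment-internal one. It checks the two stated properties: a single
caller reuses one pool, and distinct caller threads get distinct pools.

    import threading
    from concurrent.futures import ThreadPoolExecutor
    from weakref import WeakKeyDictionary

    _pools: WeakKeyDictionary = WeakKeyDictionary()
    _lock = threading.Lock()


    def _pool_for_caller(n_workers: int = 2) -> ThreadPoolExecutor:
        caller = threading.current_thread()
        with _lock:
            if caller not in _pools:
                _pools[caller] = ThreadPoolExecutor(max_workers=n_workers)
            return _pools[caller]


    def test_single_caller_reuses_pool():
        # Repeated calls from one thread return the same pool object.
        assert _pool_for_caller() is _pool_for_caller()


    def test_concurrent_callers_get_distinct_pools():
        seen = {}

        def grab(name):
            seen[name] = _pool_for_caller()

        t1 = threading.Thread(target=grab, args=("a",))
        t2 = threading.Thread(target=grab, args=("b",))
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        # Two distinct caller threads must not share an inner pool.
        assert seen["a"] is not seen["b"]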
