|
| 1 | +"""Benchmark script for the parallel bandpass + CMR speedups. |
| 2 | +
|
| 3 | +Runs head-to-head comparisons on synthetic NumpyRecording fixtures so the |
| 4 | +numbers are reproducible without external ephys data: |
| 5 | +
|
| 6 | +1. Component-level (hot operation only, no SI plumbing): |
| 7 | + - scipy.signal.sosfiltfilt serial vs channel-parallel threads |
| 8 | + - np.median(axis=1) serial vs time-parallel threads |
| 9 | +2. Per-stage end-to-end (``rec.get_traces()`` path): |
| 10 | + - BandpassFilterRecording stock vs n_workers=8 |
| 11 | + - CommonReferenceRecording stock vs n_workers=16 |
| 12 | +3. CRE (``TimeSeriesChunkExecutor``) × inner (n_workers) interaction at |
| 13 | + matched chunk_duration="1s". |
| 14 | +
|
| 15 | +FilterRecordingSegment and CommonReferenceRecordingSegment use |
| 16 | +**per-caller-thread inner pools** (WeakKeyDictionary keyed by the calling |
| 17 | +Thread object). Each outer thread that calls get_traces() gets its own |
| 18 | +inner ThreadPoolExecutor, so n_workers composes cleanly with CRE's outer |
| 19 | +parallelism — no shared-pool queueing pathology. See |
| 20 | +``tests/test_parallel_pool_semantics.py`` for the contract. |
| 21 | +
|
| 22 | +Measured on a 24-core x86_64 host with 1M x 384 float32 chunks (SI 0.103 |
| 23 | +dev, numpy 2.1, scipy 1.14, full get_traces() path end-to-end): |
| 24 | +
|
| 25 | + === Component-level (hot kernel only, no SI plumbing) === |
| 26 | + sosfiltfilt serial → 8 threads: 7.80 s → 2.67 s (2.92x) |
| 27 | + np.median serial → 16 threads: 3.51 s → 0.33 s (10.58x) |
| 28 | +
|
| 29 | + === Per-stage end-to-end (rec.get_traces) === |
| 30 | + Bandpass (5th-order, 300-6k Hz): 8.59 s → 3.20 s (2.69x) |
| 31 | + CMR median (global): 4.01 s → 0.81 s (4.95x) |
| 32 | +
|
| 33 | + === CRE outer × inner (chunk=1s, per-caller pools) === |
| 34 | + Bandpass: stock n=1 → stock n=8 thread: 7.42 s → 1.40 s (5.3x outer) |
| 35 | + n_workers=8 n=1: 3.18 s (2.3x inner) |
| 36 | + n_workers=8 n=8 thread: 1.24 s (combined) |
| 37 | + CMR: stock n=1 → stock n=8 thread: 3.98 s → 0.61 s (6.5x outer) |
| 38 | + n_workers=16 n=1: 1.58 s (2.5x inner) |
| 39 | + n_workers=16 n=8 thread: 0.36 s (11.0x combined) |
| 40 | +
|
| 41 | +Bandpass and CMR scale sub-linearly with thread count due to memory |
| 42 | +bandwidth saturation; 2.7x / 5x per stage on 8 / 16 threads respectively |
| 43 | +is consistent with the DRAM ceiling at these chunk sizes, not a |
| 44 | +parallelism bug. Under CRE, the outer-vs-inner combination depends on |
| 45 | +whether the inner pool has headroom over n_jobs — per-caller pools make |
| 46 | +this deterministic regardless. |
| 47 | +
|
| 48 | +Run with ``python -m benchmarks.preprocessing.bench_perf`` from repo root. |
| 49 | +""" |
| 50 | + |
| 51 | +from __future__ import annotations |
| 52 | + |
| 53 | +import time |
| 54 | + |
| 55 | +import numpy as np |
| 56 | +import scipy.signal |
| 57 | + |
| 58 | +from spikeinterface import NumpyRecording |
| 59 | +from spikeinterface.preprocessing import ( |
| 60 | + BandpassFilterRecording, |
| 61 | + CommonReferenceRecording, |
| 62 | +) |
| 63 | + |
| 64 | + |
def _make_recording(T: int = 1_048_576, C: int = 384, fs: float = 30_000.0, dtype=np.float32):
    """Synthetic NumpyRecording matching typical Neuropixels shard shape.

    Fixed seed (0) keeps the fixture reproducible across runs so timings
    are comparable between benchmark invocations.
    """
    generator = np.random.default_rng(0)
    samples = generator.standard_normal((T, C)).astype(dtype) * 100.0
    return NumpyRecording([samples], sampling_frequency=fs)
| 71 | + |
| 72 | + |
| 73 | +def _time_get_traces(rec, *, n_reps=3, warmup=1): |
| 74 | + """Median-of-N timing of rec.get_traces() for the full single segment.""" |
| 75 | + for _ in range(warmup): |
| 76 | + rec.get_traces() |
| 77 | + times = [] |
| 78 | + for _ in range(n_reps): |
| 79 | + t0 = time.perf_counter() |
| 80 | + rec.get_traces() |
| 81 | + times.append(time.perf_counter() - t0) |
| 82 | + return float(np.median(times)) |
| 83 | + |
| 84 | + |
| 85 | +def _time_callable(fn, *, n_reps=3, warmup=1): |
| 86 | + """Best-of-N timing for a bare callable. Used for component-level benches |
| 87 | + where we want to isolate the hot operation from surrounding glue.""" |
| 88 | + for _ in range(warmup): |
| 89 | + fn() |
| 90 | + times = [] |
| 91 | + for _ in range(n_reps): |
| 92 | + t0 = time.perf_counter() |
| 93 | + fn() |
| 94 | + times.append(time.perf_counter() - t0) |
| 95 | + return float(min(times)) |
| 96 | + |
| 97 | + |
| 98 | +def _time_cre(executor, *, n_reps=2, warmup=1): |
| 99 | + """Min-of-N timing for a TimeSeriesChunkExecutor invocation.""" |
| 100 | + for _ in range(warmup): |
| 101 | + executor.run() |
| 102 | + times = [] |
| 103 | + for _ in range(n_reps): |
| 104 | + t0 = time.perf_counter() |
| 105 | + executor.run() |
| 106 | + times.append(time.perf_counter() - t0) |
| 107 | + return float(min(times)) |
| 108 | + |
| 109 | + |
| 110 | +def _cre_init(recording): |
| 111 | + return {"recording": recording} |
| 112 | + |
| 113 | + |
| 114 | +def _cre_func(segment_index, start_frame, end_frame, worker_dict): |
| 115 | + worker_dict["recording"].get_traces( |
| 116 | + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index |
| 117 | + ) |
| 118 | + |
| 119 | + |
def bench_sosfiltfilt_component():
    """Component-level bench: just scipy.signal.sosfiltfilt vs channel-parallel.

    Isolates the hot SOS operation from the full BandpassFilter.get_traces
    path so you can see the kernel-only speedup (no margin fetch, no dtype
    cast, no slice). Filtering along axis=0 treats each channel
    independently, so splitting the channel axis across threads is exact.
    """
    from concurrent.futures import ThreadPoolExecutor

    print("--- [component] sosfiltfilt (1M x 384 float32) ---")
    T, C = 1_048_576, 384
    rng = np.random.default_rng(0)
    x = rng.standard_normal((T, C)).astype(np.float32) * 100.0
    sos = scipy.signal.butter(5, [300.0, 6000.0], btype="bandpass", fs=30_000.0, output="sos")

    n_threads = 8
    # Context manager guarantees the pool is torn down even if a timing rep
    # raises; the original leaked the executor on error.
    with ThreadPoolExecutor(max_workers=n_threads) as pool:

        def parallel_call():
            # Ceil-divide channels into n_threads contiguous blocks.
            block = (C + n_threads - 1) // n_threads
            bounds = [(c0, min(c0 + block, C)) for c0 in range(0, C, block)]

            def _work(c0, c1):
                return c0, c1, scipy.signal.sosfiltfilt(sos, x[:, c0:c1], axis=0)

            futures = [pool.submit(_work, c0, c1) for c0, c1 in bounds]
            results = [fut.result() for fut in futures]
            out = np.empty((T, C), dtype=results[0][2].dtype)
            for c0, c1, block_out in results:
                out[:, c0:c1] = block_out
            return out

        t_stock = _time_callable(lambda: scipy.signal.sosfiltfilt(sos, x, axis=0))
        t_par = _time_callable(parallel_call)

    print(f" scipy.sosfiltfilt serial: {t_stock:6.2f} s")
    print(f" scipy.sosfiltfilt 8 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)")
    print()
| 156 | + |
| 157 | + |
def bench_median_component():
    """Component-level bench: just np.median(axis=1) vs threaded across time blocks.

    Row-wise medians are independent of one another, so partitioning the
    time axis into contiguous blocks is exact, not approximate.
    """
    from concurrent.futures import ThreadPoolExecutor

    print("--- [component] np.median axis=1 (1M x 384 float32) ---")
    T, C = 1_048_576, 384
    rng = np.random.default_rng(0)
    x = rng.standard_normal((T, C)).astype(np.float32) * 100.0

    n_threads = 16
    # Context manager guarantees the pool is torn down even if a timing rep
    # raises; the original leaked the executor on error.
    with ThreadPoolExecutor(max_workers=n_threads) as pool:

        def parallel_call():
            # Ceil-divide the time axis into n_threads contiguous blocks.
            block = (T + n_threads - 1) // n_threads
            bounds = [(t0, min(t0 + block, T)) for t0 in range(0, T, block)]

            def _work(t0, t1):
                return t0, t1, np.median(x[t0:t1, :], axis=1)

            futures = [pool.submit(_work, t0, t1) for t0, t1 in bounds]
            results = [fut.result() for fut in futures]
            out = np.empty(T, dtype=results[0][2].dtype)
            for t0, t1, block_out in results:
                out[t0:t1] = block_out
            return out

        t_stock = _time_callable(lambda: np.median(x, axis=1))
        t_par = _time_callable(parallel_call)

    print(f" np.median serial: {t_stock:6.2f} s")
    print(f" np.median 16 threads: {t_par:6.2f} s ({t_stock / t_par:4.2f}x)")
    print()
| 188 | + |
| 189 | + |
def bench_bandpass():
    """End-to-end bench: BandpassFilterRecording stock vs n_workers=8."""
    print("=== Bandpass (5th-order Butterworth 300-6000 Hz, 1M x 384 float32) ===")
    recording = _make_recording(dtype=np.float32)
    bp_kwargs = dict(freq_min=300.0, freq_max=6000.0, margin_ms=40.0)
    baseline = BandpassFilterRecording(recording, **bp_kwargs)
    threaded = BandpassFilterRecording(recording, n_workers=8, **bp_kwargs)

    elapsed_base = _time_get_traces(baseline)
    elapsed_fast = _time_get_traces(threaded)
    print(f" stock (n_workers=1): {elapsed_base:6.2f} s")
    print(f" parallel (n_workers=8): {elapsed_fast:6.2f} s ({elapsed_base / elapsed_fast:4.2f}x)")
    # Equivalence check: the same window through both pipelines must agree
    # within float32 noise.
    window = dict(start_frame=1000, end_frame=10_000)
    expected = baseline.get_traces(**window)
    actual = threaded.get_traces(**window)
    assert np.allclose(actual, expected, rtol=1e-5, atol=1e-4), "parallel bandpass output mismatch"
    print(" output matches stock within float32 tolerance")
    print()
| 207 | + |
| 208 | + |
def bench_cmr():
    """End-to-end bench: CommonReferenceRecording stock vs n_workers=16."""
    print("=== CMR median (global, 1M x 384 float32) ===")
    recording = _make_recording(dtype=np.float32)
    cmr_kwargs = dict(operator="median", reference="global")
    baseline = CommonReferenceRecording(recording, **cmr_kwargs)
    threaded = CommonReferenceRecording(recording, n_workers=16, **cmr_kwargs)

    elapsed_base = _time_get_traces(baseline)
    elapsed_fast = _time_get_traces(threaded)
    print(f" stock (n_workers=1): {elapsed_base:6.2f} s")
    print(f" parallel (n_workers=16): {elapsed_fast:6.2f} s ({elapsed_base / elapsed_fast:4.2f}x)")
    # Equivalence check: the parallel path is expected to reproduce stock
    # bit-for-bit, so this assert is exact rather than tolerance-based.
    window = dict(start_frame=1000, end_frame=10_000)
    expected = baseline.get_traces(**window)
    actual = threaded.get_traces(**window)
    np.testing.assert_array_equal(actual, expected)
    print(" output is bitwise-identical to stock")
    print()
| 225 | + |
| 226 | + |
def bench_bandpass_cre_interaction():
    """Bandpass: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism.

    At SI's default ``chunk_duration="1s"``, the intra-chunk ``n_workers``
    kwarg is only useful when outer CRE workers don't already saturate cores.
    When combined, the result depends on whether inner-pool ``max_workers``
    exceeds outer ``n_jobs``.
    """
    from spikeinterface.core.job_tools import TimeSeriesChunkExecutor

    print("=== Bandpass: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===")
    recording = _make_recording(dtype=np.float32)

    def measure(bp_kwargs, n_jobs):
        # Fresh recording wrapper per config, executed through a thread CRE.
        filtered = BandpassFilterRecording(recording, **bp_kwargs)
        executor = TimeSeriesChunkExecutor(
            time_series=filtered,
            func=_cre_func,
            init_func=_cre_init,
            init_args=(filtered,),
            pool_engine="thread",
            n_jobs=n_jobs,
            chunk_duration="1s",
            progress_bar=False,
        )
        return _time_cre(executor)

    base = measure({}, 1)
    outer_only = measure({}, 8)
    inner_only = measure({"n_workers": 8}, 1)
    combined = measure({"n_workers": 8}, 8)

    print(f" {'config':<40} {'time':>8} {'vs baseline':>12}")
    print(f" {'stock, CRE n=1 (baseline)':<40} {base:6.2f} s {'1.00×':>12}")
    print(f" {'stock, CRE n=8 thread':<40} {outer_only:6.2f} s {base/outer_only:5.2f}× (outer only)")
    print(f" {'n_workers=8, CRE n=1':<40} {inner_only:6.2f} s {base/inner_only:5.2f}× (inner only)")
    print(f" {'n_workers=8, CRE n=8 thread':<40} {combined:6.2f} s {base/combined:5.2f}× (both)")
    print()
| 257 | + |
| 258 | + |
def bench_cmr_cre_interaction():
    """CMR: outer (TimeSeriesChunkExecutor) × inner (n_workers) parallelism."""
    from spikeinterface.core.job_tools import TimeSeriesChunkExecutor

    print("=== CMR: outer (CRE) × inner (n_workers), 1M × 384 float32, chunk=1s ===")
    recording = _make_recording(dtype=np.float32)

    def measure(cmr_kwargs, n_jobs):
        # Fresh recording wrapper per config, executed through a thread CRE.
        referenced = CommonReferenceRecording(recording, **cmr_kwargs)
        executor = TimeSeriesChunkExecutor(
            time_series=referenced,
            func=_cre_func,
            init_func=_cre_init,
            init_args=(referenced,),
            pool_engine="thread",
            n_jobs=n_jobs,
            chunk_duration="1s",
            progress_bar=False,
        )
        return _time_cre(executor)

    base = measure({}, 1)
    outer_only = measure({}, 8)
    inner_only = measure({"n_workers": 16}, 1)
    combined = measure({"n_workers": 16}, 8)

    print(f" {'config':<40} {'time':>8} {'vs baseline':>12}")
    print(f" {'stock, CRE n=1 (baseline)':<40} {base:6.2f} s {'1.00×':>12}")
    print(f" {'stock, CRE n=8 thread':<40} {outer_only:6.2f} s {base/outer_only:5.2f}× (outer only)")
    print(f" {'n_workers=16, CRE n=1':<40} {inner_only:6.2f} s {base/inner_only:5.2f}× (inner only)")
    print(f" {'n_workers=16, CRE n=8 thread':<40} {combined:6.2f} s {base/combined:5.2f}× (both)")
    print()
| 283 | + |
| 284 | + |
def main():
    """Run all benchmark sections in order: component, per-stage, CRE interaction."""
    sections = (
        ("### COMPONENT-LEVEL (hot operation only) ###",
         (bench_sosfiltfilt_component, bench_median_component)),
        ("### PER-STAGE END-TO-END (rec.get_traces()) ###",
         (bench_bandpass, bench_cmr)),
        ("### CRE OUTER × INNER (chunk=1s) ###",
         (bench_bandpass_cre_interaction, bench_cmr_cre_interaction)),
    )
    for header, benches in sections:
        print(header)
        print()
        for bench in benches:
            bench()
| 300 | + |
| 301 | + |
# Script entry point: `python -m benchmarks.preprocessing.bench_perf` from repo root.
if __name__ == "__main__":
    main()
0 commit comments