Skip to content

Commit 1a129ad

Browse files
Slice epsilons with dead channels to match gains during agc. (#4430)
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 3dbfacf commit 1a129ad

2 files changed

Lines changed: 26 additions & 3 deletions

File tree

src/spikeinterface/preprocessing/highpass_spatial_filter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __init__(
131131
rms_values = recording.get_property("noise_level_rms_raw")
132132
else:
133133
random_slice_kwargs = {} if random_slice_kwargs is None else random_slice_kwargs
134-
rms_values = get_noise_levels(recording, method="rms", return_scaled=False, **random_slice_kwargs)
134+
rms_values = get_noise_levels(recording, method="rms", return_in_uV=False, **random_slice_kwargs)
135135

136136
# Pre-compute spatial filtering parameters
137137
butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn)
@@ -308,7 +308,9 @@ def agc(traces, window, epsilons):
308308

309309
dead_channels = np.sum(gain, axis=0) == 0
310310

311-
traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilons, gain[:, ~dead_channels])
311+
traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(
312+
epsilons[~dead_channels], gain[:, ~dead_channels]
313+
)
312314

313315
return traces, gain
314316

src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import spikeinterface.core as si
77
import spikeinterface.preprocessing as spre
88
import spikeinterface.extractors as se
9-
from spikeinterface.core import generate_recording
9+
from spikeinterface.core import generate_recording, NumpyRecording
1010
import importlib.util
1111

1212
ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
@@ -103,6 +103,27 @@ def test_highpass_spatial_filter_synthetic_data(num_channels, ntr_pad, ntr_tap,
103103
assert raw_traces.shape == si_filtered.shape
104104

105105

106+
def test_highpass_spatial_filter_with_dead_channels():
107+
"""Regression test: AGC must handle dead (all-zero) channels without broadcast error.
108+
109+
PR #4286 changed epsilon from a scalar to a per-channel array, but the agc()
110+
function indexed gain with ~dead_channels without applying the same mask to
111+
epsilons, causing a broadcast error when any channels had zero signal.
112+
"""
113+
num_channels = 32
114+
rec = generate_recording(num_channels=num_channels, durations=[0.5])
115+
# Materialize traces and zero out 3 channels to make them "dead"
116+
traces = rec.get_traces().copy()
117+
traces[:, [0, 15, 31]] = 0.0
118+
rec_with_dead = NumpyRecording(
119+
traces_list=[traces], sampling_frequency=rec.sampling_frequency, channel_ids=rec.channel_ids
120+
)
121+
rec_with_dead.set_probe(rec.get_probe(), in_place=True)
122+
filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2)
123+
result = filtered.get_traces()
124+
assert result.shape == traces.shape
125+
126+
106127
@pytest.mark.parametrize("dtype", [np.int16, np.float32, np.float64])
107128
def test_dtype_stability(dtype):
108129
"""

0 commit comments

Comments
 (0)