|
6 | 6 | import spikeinterface.core as si |
7 | 7 | import spikeinterface.preprocessing as spre |
8 | 8 | import spikeinterface.extractors as se |
9 | | -from spikeinterface.core import generate_recording |
| 9 | +from spikeinterface.core import generate_recording, NumpyRecording |
10 | 10 | import importlib.util |
11 | 11 |
|
12 | 12 | ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) |
@@ -103,6 +103,27 @@ def test_highpass_spatial_filter_synthetic_data(num_channels, ntr_pad, ntr_tap, |
103 | 103 | assert raw_traces.shape == si_filtered.shape |
104 | 104 |
|
105 | 105 |
|
| 106 | +def test_highpass_spatial_filter_with_dead_channels(): |
| 107 | + """Regression test: AGC must handle dead (all-zero) channels without broadcast error. |
| 108 | +
|
| 109 | + PR #4286 changed epsilon from a scalar to a per-channel array, but the agc() |
| 110 | + function indexed gain with ~dead_channels without applying the same mask to |
| 111 | + epsilons, causing a broadcast error when any channels had zero signal. |
| 112 | + """ |
| 113 | + num_channels = 32 |
| 114 | + rec = generate_recording(num_channels=num_channels, durations=[0.5]) |
| 115 | + # Materialize traces and zero out 3 channels to make them "dead" |
| 116 | + traces = rec.get_traces().copy() |
| 117 | + traces[:, [0, 15, 31]] = 0.0 |
| 118 | + rec_with_dead = NumpyRecording( |
| 119 | + traces_list=[traces], sampling_frequency=rec.sampling_frequency, channel_ids=rec.channel_ids |
| 120 | + ) |
| 121 | + rec_with_dead.set_probe(rec.get_probe(), in_place=True) |
| 122 | + filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2) |
| 123 | + result = filtered.get_traces() |
| 124 | + assert result.shape == traces.shape |
| 125 | + |
| 126 | + |
106 | 127 | @pytest.mark.parametrize("dtype", [np.int16, np.float32, np.float64]) |
107 | 128 | def test_dtype_stability(dtype): |
108 | 129 | """ |
|
0 commit comments