Skip to content

Commit 8491e73

Browse files
authored
Merge branch 'main' into dev
2 parents cab6857 + 37d2c4b commit 8491e73

2 files changed

Lines changed: 27 additions & 1 deletion

File tree

src/spikeinterface/core/baserecording.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,13 @@ def _channel_slice(self, channel_ids, renamed_channel_ids=None):
731731
def _remove_channels(self, remove_channel_ids):
732732
from .channelslice import ChannelSliceRecording
733733

734+
recording_channel_ids = self.get_channel_ids()
735+
non_present_channel_ids = list(set(remove_channel_ids).difference(recording_channel_ids))
736+
if len(non_present_channel_ids) != 0:
737+
raise ValueError(
738+
f"`remove_channel_ids` {non_present_channel_ids} are not in recording ids {recording_channel_ids}."
739+
)
740+
734741
new_channel_ids = self.channel_ids[~np.isin(self.channel_ids, remove_channel_ids)]
735742
sub_recording = ChannelSliceRecording(self, new_channel_ids)
736743
return sub_recording

src/spikeinterface/core/tests/test_channelslicerecording.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,26 @@ def test_failure_with_non_unique_channel_ids():
7777
seed = 10
7878
rec = generate_recording(num_channels=4, durations=durations, set_probe=False, seed=seed)
7979
with pytest.raises(AssertionError):
80-
rec_sliced = ChannelSliceRecording(rec, channel_ids=[0, 1], renamed_channel_ids=[0, 0])
80+
rec_sliced = ChannelSliceRecording(rec, channel_ids=["0", "1"], renamed_channel_ids=[0, 0])
81+
82+
83+
def test_remove_channels():
84+
"""
85+
Check that `remove_channels` returns a recording with the correct channels removed, and that
86+
it raises an error if non-existent channels are given.
87+
"""
88+
durations = [1.0]
89+
seed = 1205
90+
91+
# Note: generated recordings have channel ids: '0', '1', '2', '3', ...
92+
rec = generate_recording(num_channels=4, durations=durations, set_probe=False, seed=seed)
93+
94+
rec_sliced = rec.remove_channels(remove_channel_ids=["0", "2"])
95+
rec_sliced_channel_ids = rec_sliced.get_channel_ids()
96+
assert np.all(rec_sliced_channel_ids == np.array(["1", "3"]))
97+
98+
with pytest.raises(ValueError):
99+
rec_sliced = rec.remove_channels(remove_channel_ids=[0, "1"])
81100

82101

83102
if __name__ == "__main__":

0 commit comments

Comments
 (0)