Skip to content

Commit 0615ca9

Browse files
authored
Merge pull request #3701 from rly/mda_channels
Set max_channel property of MdaSortingExtractor
2 parents 64d253c + c40ee90 commit 0615ca9

2 files changed

Lines changed: 34 additions & 2 deletions

File tree

src/spikeinterface/extractors/mdaextractors.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,13 @@ def get_traces(
180180
class MdaSortingExtractor(BaseSorting):
181181
"""Load MDA format data as a sorting extractor.
182182
183+
NOTE: As in the MDA format, the max_channel property indexes the channels that are given as input
184+
to the sorter.
185+
If sorting was run on a subset of channels of the recording, then the max_channel values are
186+
based on that subset, so care must be taken when associating these values with a recording.
187+
If additional sorting segments are added to this sorting extractor after initialization,
188+
then max_channel will not be updated. The max_channel indices begin at 1.
189+
183190
Parameters
184191
----------
185192
file_path : str or Path
@@ -202,14 +209,27 @@ def __init__(self, file_path, sampling_frequency):
202209
sorting_segment = MdaSortingSegment(firings)
203210
self.add_sorting_segment(sorting_segment)
204211

212+
# Store the max channel for each unit
213+
# Every spike assigned to a unit (label) has the same max channel
214+
# ref: https://github.com/SpikeInterface/spikeinterface/issues/3695#issuecomment-2663329006
215+
max_channels = []
216+
segment = self._sorting_segments[0]
217+
for unit_id in self.unit_ids:
218+
label_mask = segment._labels == unit_id
219+
# since all max channels are the same, we can just grab the first occurrence for the unit
220+
max_channel = segment._max_channels[label_mask][0]
221+
max_channels.append(max_channel)
222+
223+
self.set_property(key="max_channel", values=max_channels)
224+
205225
self._kwargs = {
206226
"file_path": str(Path(file_path).absolute()),
207227
"sampling_frequency": sampling_frequency,
208228
}
209229

210230
@staticmethod
211231
def write_sorting(sorting, save_path, write_primary_channels=False):
212-
assert sorting.get_num_segments() == 1, "MdaSorting.write_sorting() can only write a single segment " "sorting"
232+
assert sorting.get_num_segments() == 1, "MdaSorting.write_sorting() can only write a single segment sorting"
213233
unit_ids = sorting.get_unit_ids()
214234
times_list = []
215235
labels_list = []
@@ -223,7 +243,7 @@ def write_sorting(sorting, save_path, write_primary_channels=False):
223243
else:
224244
labels_list.append(np.ones(times.shape, dtype=int) * unit_index)
225245
if write_primary_channels:
226-
if "max_channel" in sorting.get_unit_property_names(unit_id):
246+
if "max_channel" in sorting.get_property_keys():
227247
primary_channels_list.append([sorting.get_unit_property(unit_id, "max_channel")] * times.shape[0])
228248
else:
229249
raise ValueError(

src/spikeinterface/extractors/tests/test_mdaextractors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,25 @@ def test_mda_extractors(create_cache_folder):
2121

2222
check_recordings_equal(rec, rec_mda, return_scaled=False)
2323

24+
# Write without setting max_channel
2425
MdaSortingExtractor.write_sorting(sort, cache_folder / "mdatest" / "firings.mda")
2526
sort_mda = MdaSortingExtractor(
2627
cache_folder / "mdatest" / "firings.mda", sampling_frequency=sort.get_sampling_frequency()
2728
)
2829

2930
check_sortings_equal(sort, sort_mda)
3031

32+
# Set a fake max channel (1-indexed) for each unit
33+
sort.set_property(key="max_channel", values=[i % rec.get_num_channels() + 1 for i in range(sort.get_num_units())])
34+
35+
# Write with setting max_channel
36+
MdaSortingExtractor.write_sorting(sort, cache_folder / "mdatest" / "firings.mda", write_primary_channels=True)
37+
sort_mda = MdaSortingExtractor(
38+
cache_folder / "mdatest" / "firings.mda", sampling_frequency=sort.get_sampling_frequency()
39+
)
40+
41+
check_sortings_equal(sort, sort_mda)
42+
3143

3244
if __name__ == "__main__":
3345
test_mda_extractors()

0 commit comments

Comments
 (0)