@@ -180,6 +180,13 @@ def get_traces(
180180class MdaSortingExtractor (BaseSorting ):
181181 """Load MDA format data as a sorting extractor.
182182
183+ NOTE: As in the MDA format, the max_channel property indexes the channels that are given as input
184+ to the sorter.
185+ If sorting was run on a subset of channels of the recording, then the max_channel values are
186+ based on that subset, so care must be taken when associating these values with a recording.
187+ If additional sorting segments are added to this sorting extractor after initialization,
188+ then max_channel will not be updated. The max_channel indices begin at 1.
189+
183190 Parameters
184191 ----------
185192 file_path : str or Path
@@ -202,14 +209,27 @@ def __init__(self, file_path, sampling_frequency):
202209 sorting_segment = MdaSortingSegment (firings )
203210 self .add_sorting_segment (sorting_segment )
204211
212+ # Store the max channel for each unit
213+ # Every spike assigned to a unit (label) has the same max channel
214+ # ref: https://github.com/SpikeInterface/spikeinterface/issues/3695#issuecomment-2663329006
215+ max_channels = []
216+ segment = self ._sorting_segments [0 ]
217+ for unit_id in self .unit_ids :
218+ label_mask = segment ._labels == unit_id
219+ # since all max channels are the same, we can just grab the first occurrence for the unit
220+ max_channel = segment ._max_channels [label_mask ][0 ]
221+ max_channels .append (max_channel )
222+
223+ self .set_property (key = "max_channel" , values = max_channels )
224+
205225 self ._kwargs = {
206226 "file_path" : str (Path (file_path ).absolute ()),
207227 "sampling_frequency" : sampling_frequency ,
208228 }
209229
210230 @staticmethod
211231 def write_sorting (sorting , save_path , write_primary_channels = False ):
212- assert sorting .get_num_segments () == 1 , "MdaSorting.write_sorting() can only write a single segment " " sorting"
232+ assert sorting .get_num_segments () == 1 , "MdaSorting.write_sorting() can only write a single segment sorting"
213233 unit_ids = sorting .get_unit_ids ()
214234 times_list = []
215235 labels_list = []
@@ -223,7 +243,7 @@ def write_sorting(sorting, save_path, write_primary_channels=False):
223243 else :
224244 labels_list .append (np .ones (times .shape , dtype = int ) * unit_index )
225245 if write_primary_channels :
226- if "max_channel" in sorting .get_unit_property_names ( unit_id ):
246+ if "max_channel" in sorting .get_property_keys ( ):
227247 primary_channels_list .append ([sorting .get_unit_property (unit_id , "max_channel" )] * times .shape [0 ])
228248 else :
229249 raise ValueError (
0 commit comments