99import time
1010
1111
12+ from spikeinterface .core import SortingAnalyzer , ChannelSparsity , NumpySorting
1213from spikeinterface .core import SortingAnalyzer , NumpySorting
1314from spikeinterface .core .job_tools import fix_job_kwargs , split_job_kwargs
1415from spikeinterface import load , create_sorting_analyzer , load_sorting_analyzer
@@ -46,6 +47,7 @@ def __init__(self, study_folder):
4647 self .levels = None
4748 self .colors_by_case = None
4849 self .colors_by_levels = {}
50+ self .labels_by_levels = {}
4951 self .scan_folder ()
5052
5153 @classmethod
@@ -120,8 +122,41 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
120122 rec , gt_sorting = data
121123
122124 if gt_sorting is not None :
125+ if "gt_unit_locations" in gt_sorting .get_property_keys ():
126+ # if real units locations is present then use it for a better sparsity
127+ # then the real max channel is used
128+ radius_um = 100.0
129+ channel_ids = rec .channel_ids
130+ unit_ids = gt_sorting .unit_ids
131+ gt_unit_locations = gt_sorting .get_property ("gt_unit_locations" )
132+ channel_locations = rec .get_channel_locations ()
133+ max_channel_indices = np .argmin (
134+ np .linalg .norm (
135+ gt_unit_locations [:, np .newaxis , :2 ] - channel_locations [np .newaxis , :], axis = 2
136+ ),
137+ axis = 1 ,
138+ )
139+ mask = np .zeros ((unit_ids .size , channel_ids .size ), dtype = "bool" )
140+ distances = np .linalg .norm (
141+ channel_locations [:, np .newaxis ] - channel_locations [np .newaxis , :], axis = 2
142+ )
143+ for unit_ind , unit_id in enumerate (unit_ids ):
144+ chan_ind = max_channel_indices [unit_ind ]
145+ (chan_inds ,) = np .nonzero (distances [chan_ind , :] <= radius_um )
146+ mask [unit_ind , chan_inds ] = True
147+ sparsity = ChannelSparsity (mask , unit_ids , channel_ids )
148+ sparse = False
149+ else :
150+ sparse = True
151+ sparsity = None
152+
123153 analyzer = create_sorting_analyzer (
124- gt_sorting , rec , sparse = True , format = "binary_folder" , folder = local_analyzer_folder
154+ gt_sorting ,
155+ rec ,
156+ sparse = sparse ,
157+ sparsity = sparsity ,
158+ format = "binary_folder" ,
159+ folder = local_analyzer_folder ,
125160 )
126161 analyzer .compute ("random_spikes" )
127162 analyzer .compute ("templates" )
@@ -135,6 +170,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
135170 analyzer = create_sorting_analyzer (
136171 gt_sorting , rec , sparse = False , format = "binary_folder" , folder = local_analyzer_folder
137172 )
173+
138174 else :
139175 # new case : analzyer
140176 assert isinstance (data , SortingAnalyzer )
@@ -254,6 +290,8 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):
254290
255291 for key in job_keys :
256292 benchmark = self .create_benchmark (key )
293+ if verbose :
294+ print ("### Run benchmark" , key , "###" )
257295 t0 = time .perf_counter ()
258296 benchmark .run (** job_kwargs )
259297 t1 = time .perf_counter ()
@@ -324,14 +362,16 @@ def get_run_times(self, case_keys=None):
324362 df .index .names = self .levels
325363 return df
326364
def get_grouped_keys_mapping(self, levels_to_group_by=None, case_keys=None):
    """
    Group case keys by a subset of the study levels.

    Parameters
    ----------
    levels_to_group_by : list | None
        Levels used to build the grouped keys; must all be members of
        ``self.levels``. When None, every case key maps to itself.
    case_keys : list | None
        Optionally a sub-list of case keys to consider.
        Defaults to all keys in ``self.cases``.

    Returns
    -------
    keys_mapping : dict
        Maps each grouped key to the list of original case keys it covers.
    labels : dict
        Maps each grouped key to a human-readable label.
    """
    if case_keys is None:
        case_keys = list(self.cases.keys())

    # Grouping is only meaningful when the study has several levels;
    # otherwise (or when no grouping is requested) use the identity mapping.
    if levels_to_group_by is None or self.levels is None or len(self.levels) == 1:
        keys_mapping = {key: [key] for key in case_keys}
    else:
        study_levels = list(self.levels)
        assert all(
            level in study_levels for level in levels_to_group_by
        ), f"levels_to_group_by must be in {study_levels}, got {levels_to_group_by}"
        keys_mapping = {}
        for key in case_keys:
            new_key = tuple(key[study_levels.index(level)] for level in levels_to_group_by)
            if len(new_key) == 1:
                # single grouping level: unwrap the 1-tuple to a scalar key
                new_key = new_key[0]
            keys_mapping.setdefault(new_key, []).append(key)

    if levels_to_group_by is None:
        labels = {key: self.cases[key]["label"] for key in case_keys}
    else:
        level_key = tuple(levels_to_group_by) if len(levels_to_group_by) > 1 else levels_to_group_by[0]
        if level_key in self.labels_by_levels:
            # labels pre-registered for this level combination take precedence
            labels = self.labels_by_levels[level_key]
        else:
            # next(iter(...), None) is safe on an empty mapping, unlike
            # list(keys_mapping.keys())[0] which raised IndexError
            key0 = next(iter(keys_mapping), None)
            if isinstance(key0, tuple):
                labels = {key: "-".join(key) for key in keys_mapping}
            else:
                labels = {key: key for key in keys_mapping}

    return keys_mapping, labels
373418
@@ -383,6 +428,8 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):
383428
384429 job_keys = []
385430 for key in case_keys :
431+ if verbose :
432+ print ("### Compute result" , key , "###" )
386433 benchmark = self .benchmarks [key ]
387434 assert benchmark is not None
388435 benchmark .compute_result (** result_params )
@@ -438,9 +485,9 @@ def get_gt_unit_locations(self, case_key):
438485 unit_locations_ext = sorting_analyzer .get_extension ("unit_locations" )
439486 return unit_locations_ext .get_data ()
440487
def get_templates(self, key, operator="average", outputs="numpy"):
    """
    Return the templates stored on the sorting analyzer of one case.

    Parameters
    ----------
    key : tuple | str
        Case key identifying which sorting analyzer to use.
    operator : str, default: "average"
        Aggregation operator forwarded to the "templates" extension's get_data().
    outputs : str, default: "numpy"
        Output format forwarded to the "templates" extension's get_data().

    Returns
    -------
    templates
        The templates data, in the format selected by `outputs`.
    """
    sorting_analyzer = self.get_sorting_analyzer(case_key=key)
    # fixed: previous code called the misspelled ``get_extenson`` (AttributeError)
    templates = sorting_analyzer.get_extension("templates").get_data(operator=operator, outputs=outputs)
    return templates
445492
446493 def compute_metrics (self , case_keys = None , metric_names = ["snr" , "firing_rate" ], force = False , ** job_kwargs ):
@@ -668,15 +715,15 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc
668715 gt_sorting = comp .sorting1
669716 sorting = comp .sorting2
670717
671- count_units .loc [key , "num_gt" ] = len (gt_sorting .get_unit_ids ())
672- count_units .loc [key , "num_sorter" ] = len (sorting .get_unit_ids ())
673- count_units .loc [key , "num_well_detected" ] = comp .count_well_detected_units (well_detected_score )
718+ count_units .at [key , "num_gt" ] = len (gt_sorting .get_unit_ids ())
719+ count_units .at [key , "num_sorter" ] = len (sorting .get_unit_ids ())
720+ count_units .at [key , "num_well_detected" ] = comp .count_well_detected_units (well_detected_score )
674721
675722 if comp .exhaustive_gt :
676- count_units .loc [key , "num_redundant" ] = comp .count_redundant_units (redundant_score )
677- count_units .loc [key , "num_overmerged" ] = comp .count_overmerged_units (overmerged_score )
678- count_units .loc [key , "num_false_positive" ] = comp .count_false_positive_units (redundant_score )
679- count_units .loc [key , "num_bad" ] = comp .count_bad_units ()
723+ count_units .at [key , "num_redundant" ] = comp .count_redundant_units (redundant_score )
724+ count_units .at [key , "num_overmerged" ] = comp .count_overmerged_units (overmerged_score )
725+ count_units .at [key , "num_false_positive" ] = comp .count_false_positive_units (redundant_score )
726+ count_units .at [key , "num_bad" ] = comp .count_bad_units ()
680727
681728 return count_units
682729
0 commit comments