@@ -29,30 +29,24 @@ def run(self):
2929 self .result = {"sorting" : sorting }
3030
3131 def compute_result (self , residulal_peak_threshold = 6 , ** job_kwargs ):
32-
33- sorting = self .result ['sorting' ]
34- analyzer = create_sorting_analyzer (
35- sorting , self .recording , sparse = True , format = "memory" , ** job_kwargs
36- )
32+
33+ sorting = self .result ["sorting" ]
34+ analyzer = create_sorting_analyzer (sorting , self .recording , sparse = True , format = "memory" , ** job_kwargs )
3735 analyzer .compute ("random_spikes" )
3836 analyzer .compute ("templates" )
3937 analyzer .compute ("noise_levels" )
40- analyzer .compute (
41- {
42- "spike_amplitudes" : {},
43- "amplitude_scalings" : {"handle_collisions" : False }
44- },
45- ** job_kwargs )
38+ analyzer .compute ({"spike_amplitudes" : {}, "amplitude_scalings" : {"handle_collisions" : False }}, ** job_kwargs )
4639
4740 analyzer .compute ("quality_metrics" , ** job_kwargs )
4841
4942 residual , peaks = analyse_residual (
50- analyzer , detect_peaks_kwargs = dict (
43+ analyzer ,
44+ detect_peaks_kwargs = dict (
5145 method = "locally_exclusive" ,
5246 peak_sign = "neg" ,
5347 detect_threshold = residulal_peak_threshold ,
5448 ),
55- ** job_kwargs
49+ ** job_kwargs ,
5650 )
5751
5852 self .result ["sorter_analyzer" ] = analyzer
@@ -66,11 +60,9 @@ def compute_result(self, residulal_peak_threshold=6, **job_kwargs):
6660 ("multi_comp" , "pickle" ),
6761 ("sorter_analyzer" , "sorting_analyzer" ),
6862 ("peaks_from_residual" , "npy" ),
69-
7063 ]
7164
7265
73-
7466class SorterStudyWithoutGroundTruth (BenchmarkStudy ):
7567 """
7668 This class is an alternative to SorterStudy when the dataset do not have groundtruth.
@@ -87,19 +79,21 @@ def create_benchmark(self, key):
8779 sorter_folder = self .folder / "sorters" / self .key_to_str (key )
8880 benchmark = SorterBenchmarkWithoutGroundTruth (recording , gt_sorting , params , sorter_folder )
8981 return benchmark
90-
82+
9183 def _get_comparison_groups (self ):
9284 # multicomparison are done on all cases sharing the same dataset key.
9385 case_keys = list (self .cases .keys ())
9486 groups = {}
9587 for case_key in case_keys :
96- data_key = self .cases [case_key ][' dataset' ]
88+ data_key = self .cases [case_key ][" dataset" ]
9789 if data_key not in groups :
9890 groups [data_key ] = []
9991 groups [data_key ].append (case_key )
10092 return groups
10193
102- def compute_results (self , case_keys = None , verbose = False , delta_time = 0.4 , match_score = 0.5 , chance_score = 0.1 , ** result_params ):
94+ def compute_results (
95+ self , case_keys = None , verbose = False , delta_time = 0.4 , match_score = 0.5 , chance_score = 0.1 , ** result_params
96+ ):
10397 # Here we need a hack because the results is not computed case by case but all at once
10498
10599 assert case_keys is None , "SorterStudyWithoutGroundTruth do not permit compute_results for sub cases"
@@ -115,7 +109,7 @@ def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_s
115109
116110 for data_key , group in groups .items ():
117111
118- sorting_list = [self .get_result (key )[' sorting' ] for key in group ]
112+ sorting_list = [self .get_result (key )[" sorting" ] for key in group ]
119113 name_list = [key for key in group ]
120114 multi_comp = compare_multiple_sorters (
121115 sorting_list ,
@@ -132,7 +126,7 @@ def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_s
132126 # and then the same multi comp is stored for each case_key
133127 for key in case_keys :
134128 benchmark = self .benchmarks [key ]
135- benchmark .result [' multi_comp' ] = multi_comp
129+ benchmark .result [" multi_comp" ] = multi_comp
136130 benchmark .save_result (self .folder / "results" / self .key_to_str (key ))
137131
138132 def plot_residual_peak_amplitudes (self , figsize = None ):
@@ -143,7 +137,7 @@ def plot_residual_peak_amplitudes(self, figsize=None):
143137
144138 for data_key , group in groups .items ():
145139 fig , ax = plt .subplots (figsize = figsize )
146-
140+
147141 lim0 , lim1 = np .inf , - np .inf
148142
149143 for key in group :
@@ -158,7 +152,6 @@ def plot_residual_peak_amplitudes(self, figsize=None):
158152 if lim0 > 0 :
159153 lim0 = 0
160154
161-
162155 for key in group :
163156 peaks = self .get_result (key )["peaks_from_residual" ]
164157 print (peaks .size )
@@ -167,7 +160,7 @@ def plot_residual_peak_amplitudes(self, figsize=None):
167160 ax .plot (bins [:- 1 ], count , color = colors [key ], label = self .cases [key ]["label" ])
168161
169162 ax .legend ()
170-
163+
171164 # def plot_quality_metrics_comparison_on_agreement(self, qm_name='rp_contamination', figsize=None):
172165 # import matplotlib.pyplot as plt
173166
@@ -189,13 +182,13 @@ def plot_residual_peak_amplitudes(self, figsize=None):
189182
190183 # multi_comp = self.get_result(key1)['multi_comp']
191184 # comp = multi_comp.comparisons[key1, key2]
192-
185+
193186 # match_12 = comp.hungarian_match_12
194187 # if match_12.dtype.kind =='i':
195188 # mask = match_12.values != -1
196189 # if match_12.dtype.kind =='U':
197190 # mask = match_12.values != ''
198-
191+
199192 # common_unit1_ids = match_12[mask].index
200193 # common_unit2_ids = match_12[mask].values
201194 # metrics1 = self.get_result(key1)["sorter_analyzer"].get_extension("quality_metrics").get_data()
@@ -215,5 +208,3 @@ def plot_residual_peak_amplitudes(self, figsize=None):
215208 # ax.set_yticks([])
216209 # ax.set_xticklabels([])
217210 # ax.set_yticklabels([])
218-
219-
0 commit comments