@@ -29,30 +29,24 @@ def run(self):
2929 self .result = {"sorting" : sorting }
3030
3131 def compute_result (self , residulal_peak_threshold = 6 , ** job_kwargs ):
32-
33- sorting = self .result ['sorting' ]
34- analyzer = create_sorting_analyzer (
35- sorting , self .recording , sparse = True , format = "memory" , ** job_kwargs
36- )
32+
33+ sorting = self .result ["sorting" ]
34+ analyzer = create_sorting_analyzer (sorting , self .recording , sparse = True , format = "memory" , ** job_kwargs )
3735 analyzer .compute ("random_spikes" )
3836 analyzer .compute ("templates" )
3937 analyzer .compute ("noise_levels" )
40- analyzer .compute (
41- {
42- "spike_amplitudes" : {},
43- "amplitude_scalings" : {"handle_collisions" : False }
44- },
45- ** job_kwargs )
38+ analyzer .compute ({"spike_amplitudes" : {}, "amplitude_scalings" : {"handle_collisions" : False }}, ** job_kwargs )
4639
4740 analyzer .compute ("quality_metrics" , ** job_kwargs )
4841
4942 residual , peaks = analyse_residual (
50- analyzer , detect_peaks_kwargs = dict (
43+ analyzer ,
44+ detect_peaks_kwargs = dict (
5145 method = "locally_exclusive" ,
5246 peak_sign = "neg" ,
5347 detect_threshold = residulal_peak_threshold ,
5448 ),
55- ** job_kwargs
49+ ** job_kwargs ,
5650 )
5751
5852 self .result ["sorter_analyzer" ] = analyzer
@@ -65,11 +59,9 @@ def compute_result(self, residulal_peak_threshold=6, **job_kwargs):
6559 ("multi_comp" , "pickle" ),
6660 ("sorter_analyzer" , "sorting_analyzer" ),
6761 ("peaks_from_residual" , "npy" ),
68-
6962 ]
7063
7164
72-
7365class SorterStudyWithoutGroundTruth (BenchmarkStudy ):
7466 """
7567 This class is an alternative to SorterStudy when the dataset do not have groundtruth
@@ -84,19 +76,21 @@ def create_benchmark(self, key):
8476 sorter_folder = self .folder / "sorters" / self .key_to_str (key )
8577 benchmark = SorterBenchmarkWithoutGroundTruth (recording , gt_sorting , params , sorter_folder )
8678 return benchmark
87-
79+
8880 def _get_comparison_groups (self ):
8981 # multicomparison are done on all cases sharing the same dataset key.
9082 case_keys = list (self .cases .keys ())
9183 groups = {}
9284 for case_key in case_keys :
93- data_key = self .cases [case_key ][' dataset' ]
85+ data_key = self .cases [case_key ][" dataset" ]
9486 if data_key not in groups :
9587 groups [data_key ] = []
9688 groups [data_key ].append (case_key )
9789 return groups
9890
99- def compute_results (self , case_keys = None , verbose = False , delta_time = 0.4 , match_score = 0.5 , chance_score = 0.1 , ** result_params ):
91+ def compute_results (
92+ self , case_keys = None , verbose = False , delta_time = 0.4 , match_score = 0.5 , chance_score = 0.1 , ** result_params
93+ ):
10094 # Here we need a hack because the results is not computed case by case but all at once
10195
10296 assert case_keys is None , "SorterStudyWithoutGroundTruth do not permit compute_results for sub cases"
@@ -112,7 +106,7 @@ def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_s
112106
113107 for data_key , group in groups .items ():
114108
115- sorting_list = [self .get_result (key )[' sorting' ] for key in group ]
109+ sorting_list = [self .get_result (key )[" sorting" ] for key in group ]
116110 name_list = [key for key in group ]
117111 multi_comp = compare_multiple_sorters (
118112 sorting_list ,
@@ -129,7 +123,7 @@ def compute_results(self, case_keys=None, verbose=False, delta_time=0.4, match_s
129123 # and then the same multi comp is stored for each case_key
130124 for key in case_keys :
131125 benchmark = self .benchmarks [key ]
132- benchmark .result [' multi_comp' ] = multi_comp
126+ benchmark .result [" multi_comp" ] = multi_comp
133127 benchmark .save_result (self .folder / "results" / self .key_to_str (key ))
134128
135129 def plot_residual_peak_amplitudes (self , figsize = None ):
@@ -140,7 +134,7 @@ def plot_residual_peak_amplitudes(self, figsize=None):
140134
141135 for data_key , group in groups .items ():
142136 fig , ax = plt .subplots (figsize = figsize )
143-
137+
144138 lim0 , lim1 = np .inf , - np .inf
145139
146140 for key in group :
@@ -155,7 +149,6 @@ def plot_residual_peak_amplitudes(self, figsize=None):
155149 if lim0 > 0 :
156150 lim0 = 0
157151
158-
159152 for key in group :
160153 peaks = self .get_result (key )["peaks_from_residual" ]
161154 print (peaks .size )
@@ -164,6 +157,7 @@ def plot_residual_peak_amplitudes(self, figsize=None):
164157 ax .plot (bins [:- 1 ], count , color = colors [key ], label = self .cases [key ]["label" ])
165158
166159 ax .legend ()
160+
167161 # def plot_quality_metrics_comparison_on_agreement(self, qm_name='rp_contamination', figsize=None):
168162 # import matplotlib.pyplot as plt
169163
@@ -185,13 +179,13 @@ def plot_residual_peak_amplitudes(self, figsize=None):
185179
186180 # multi_comp = self.get_result(key1)['multi_comp']
187181 # comp = multi_comp.comparisons[key1, key2]
188-
182+
189183 # match_12 = comp.hungarian_match_12
190184 # if match_12.dtype.kind =='i':
191185 # mask = match_12.values != -1
192186 # if match_12.dtype.kind =='U':
193187 # mask = match_12.values != ''
194-
188+
195189 # common_unit1_ids = match_12[mask].index
196190 # common_unit2_ids = match_12[mask].values
197191 # metrics1 = self.get_result(key1)["sorter_analyzer"].get_extension("quality_metrics").get_data()
@@ -212,7 +206,6 @@ def plot_residual_peak_amplitudes(self, figsize=None):
212206 # ax.set_xticklabels([])
213207 # ax.set_yticklabels([])
214208
215-
216209 # def plot_quality_metrics_comparison_on_non_agreement(self, qm_name='rp_contamination', figsize=None):
217210 # import matplotlib.pyplot as plt
218211
@@ -223,4 +216,3 @@ def plot_residual_peak_amplitudes(self, figsize=None):
223216 # fig, ax = plt.subplots(figsize=figsize)
224217 # for key in group:
225218 # pass
226-
0 commit comments