Skip to content

Commit 2ef382c

Browse files
authored
Merge pull request #3915 from alejoe91/fix-match-score-in-benchmark
Expose `match_score` to benchmark compute_results
2 parents acf848f + 9ec3c31 commit 2ef382c

3 files changed

Lines changed: 12 additions & 6 deletions

File tree

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ def run(self, **job_kwargs):
3838
self.result = {"sorting": sorting, "spikes": spikes}
3939
self.result["templates"] = self.templates
4040

41-
def compute_result(self, with_collision=False, **result_params):
41+
def compute_result(self, with_collision=False, match_score=0.5, exhaustive_gt=True):
4242
sorting = self.result["sorting"]
43-
comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True)
43+
comp = compare_sorter_to_ground_truth(
44+
self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt, match_score=match_score
45+
)
4446
self.result["gt_comparison"] = comp
4547
if with_collision:
4648
self.result["gt_collision"] = CollisionGTComparison(self.gt_sorting, sorting, exhaustive_gt=True)

src/spikeinterface/benchmark/benchmark_merging.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ def run(self, **job_kwargs):
3535

3636
self.result["sorting"] = merged_analyzer.sorting
3737

38-
def compute_result(self, **result_params):
38+
def compute_result(self, match_score=0.5, exhaustive_gt=True):
3939
sorting = self.result["sorting"]
40-
comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True)
40+
comp = compare_sorter_to_ground_truth(
41+
self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt, match_score=match_score
42+
)
4143
self.result["gt_comparison"] = comp
4244

4345
_run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("merged_pairs", "pickle"), ("outs", "pickle")]

src/spikeinterface/benchmark/benchmark_sorter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ def run(self):
2626
sorting = NumpySorting.from_sorting(raw_sorting)
2727
self.result = {"sorting": sorting}
2828

29-
def compute_result(self, exhaustive_gt=True):
29+
def compute_result(self, match_score=0.5, exhaustive_gt=True):
3030
# run becnhmark result
3131
sorting = self.result["sorting"]
32-
comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt)
32+
comp = compare_sorter_to_ground_truth(
33+
self.gt_sorting, sorting, exhaustive_gt=exhaustive_gt, match_score=match_score
34+
)
3335
self.result["gt_comparison"] = comp
3436

3537
_run_key_saved = [

0 commit comments

Comments
 (0)