Skip to content

Commit 861f73a

Browse files
committed
wip
1 parent aa1a6f3 commit 861f73a

4 files changed

Lines changed: 26 additions & 9 deletions

File tree

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,16 @@ def plot_performances_ordered(self, *args, **kwargs):
9191

9292
return plot_performances_ordered(self, *args, **kwargs)
9393

94-
def plot_collisions(self, case_keys=None, axs=None, figsize=None):
94+
def plot_collisions(self, case_keys=None, metric="l2", mode="lines", show_legend=True, axs=None, figsize=None):
9595
if case_keys is None:
9696
case_keys = list(self.cases.keys())
9797
import matplotlib.pyplot as plt
9898

9999
if axs is None:
100100
fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
101101
axs = axs[0, :]
102+
else:
103+
fig = axs[0].figure
102104

103105

104106
for count, key in enumerate(case_keys):
@@ -108,10 +110,13 @@ def plot_collisions(self, case_keys=None, axs=None, figsize=None):
108110
plot_comparison_collision_by_similarity(
109111
self.get_result(key)["gt_collision"],
110112
templates_array,
113+
metric=metric,
111114
ax=ax,
112-
show_legend=True,
113-
mode="lines",
114-
good_only=False,
115+
show_legend=show_legend,
116+
mode=mode,
117+
# good_only=False,
118+
# good_only=False,
119+
good_only=True,
115120
)
116121

117122
ax.set_title(label)

src/spikeinterface/comparison/collision.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,23 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
214214

215215
tp1 = self.all_tp[ind1, ind2, :]
216216
fn1 = self.all_fn[ind1, ind2, :]
217-
recall1 = tp1 / (tp1 + fn1)
217+
recall1 = np.zeros(tp1.size)
218+
recall1[:] = np.nan
219+
mask = (tp1 + fn1) > 0
220+
if np.any(mask):
221+
recall1[mask] = tp1[mask] / (tp1[mask] + fn1[mask])
222+
218223
recall_scores.append(recall1)
219224
similarities.append(similarity_matrix[r, c])
220225
pair_names.append(f"{u1} {u2}")
221226

222227
tp2 = self.all_tp[ind2, ind1, :]
223228
fn2 = self.all_fn[ind2, ind1, :]
224-
recall2 = tp2 / (tp2 + fn2)
229+
recall2 = np.zeros(tp2.size)
230+
recall2[:] = np.nan
231+
mask = (tp2 + fn2) > 0
232+
if np.any(mask):
233+
recall2[mask] = tp2[mask] / (tp2[mask] + fn2[mask])
225234
recall_scores.append(recall2)
226235
similarities.append(similarity_matrix[r, c])
227236
pair_names.append(f"{u2} {u1}")

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
3535
"motion_correction": {"preset": "dredge_fast"},
3636
"merging": {"max_distance_um": 50},
3737
"clustering": {"method": "circus-clustering", "method_kwargs": dict()},
38-
"matching": {"method": "circus-omp-svd", "method_kwargs": dict()},
38+
# "matching": {"method": "circus-omp-svd", "method_kwargs": dict()},
39+
"matching": {"method": "wobble", "method_kwargs": dict()},
3940
"apply_preprocessing": True,
4041
"templates_from_svd": True,
4142
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
4243
"chunk_preprocessing": {"memory_limit": None},
4344
"multi_units_only": False,
44-
"job_kwargs": {"n_jobs": 0.75},
45+
# "job_kwargs": {"n_jobs": 0.75},
46+
"job_kwargs": {"n_jobs": None},
4547
"seed": 42,
4648
"deterministic_peaks_detection": False,
4749
"debug": False,

src/spikeinterface/sortingcomponents/clustering/circus.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class CircusClustering:
5454
"few_waveforms": None,
5555
"ms_before": 2.0,
5656
"ms_after": 2.0,
57-
"remove_small_snr": False,
57+
# "remove_small_snr": False,
58+
"remove_small_snr": True,
5859
"seed": None,
5960
"noise_threshold": 4,
6061
"rank": 5,

0 commit comments

Comments
 (0)