Skip to content

Commit 0f3cd0f

Browse files
"Allow" for 3D channel locations for compute unit locs and template metrics (#4572)
Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent b16101f commit 0f3cd0f

3 files changed

Lines changed: 57 additions & 8 deletions

File tree

src/spikeinterface/metrics/template/template_metrics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ def _set_params(
201201
if include_multi_channel_metrics or (
202202
metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names])
203203
):
204-
assert (
205-
self.sorting_analyzer.get_channel_locations().shape[1] == 2
206-
), "If multi-channel metrics are computed, channel locations must be 2D."
204+
if self.sorting_analyzer.get_channel_locations().shape[1] == 3:
205+
warnings.warn(
206+
"Multi-channel metrics assume 2D channel locations. We will assume the first two dimensions are the physically relevant ones"
207+
)
207208

208209
if metric_names is None:
209210
metric_names = get_single_channel_template_metric_names()
@@ -260,7 +261,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
260261
)
261262
all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True, operator=operator)
262263

263-
channel_locations = sorting_analyzer.get_channel_locations()
264+
analyzer_channel_locations = sorting_analyzer.get_channel_locations()
265+
# the template metrics only work for 2D probes. We warn users with 3D locations above.
266+
channel_locations = analyzer_channel_locations[:, :2]
264267

265268
main_channel_templates = []
266269
peaks_info = []

src/spikeinterface/postprocessing/localization_tools.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def compute_monopolar_triangulation(
104104
unit_location = np.zeros((unit_ids.size, 4), dtype="float64")
105105
for i, unit_id in enumerate(unit_ids):
106106
chan_inds = sparsity.unit_id_to_channel_indices[unit_id]
107-
local_contact_locations = contact_locations[chan_inds, :]
107+
local_contact_locations = contact_locations[chan_inds, :2]
108108

109109
# wf is (nsample, nchan) - chann is only nieghboor
110110
wf = templates[i, :, :][:, chan_inds]
@@ -182,7 +182,7 @@ def compute_center_of_mass(
182182
unit_location = np.zeros((len(unit_ids), 2), dtype="float64")
183183
for i, unit_id in enumerate(unit_ids):
184184
chan_inds = sparsity.unit_id_to_channel_indices[unit_id]
185-
local_contact_locations = contact_locations[chan_inds, :]
185+
local_contact_locations = contact_locations[chan_inds, :2]
186186

187187
wf = templates[i, :, :]
188188

@@ -247,7 +247,7 @@ def compute_grid_convolution(
247247
unit_location: np.array
248248
"""
249249

250-
contact_locations = sorting_analyzer_or_templates.get_channel_locations()
250+
contact_locations = sorting_analyzer_or_templates.get_channel_locations()[:, :2]
251251

252252
templates = get_dense_templates_array(
253253
sorting_analyzer_or_templates, return_in_uV=get_return_in_uV(sorting_analyzer_or_templates)
@@ -693,7 +693,7 @@ def compute_location_max_channel(
693693
extremum_channels_index = get_template_extremum_channel(
694694
templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index", operator=operator
695695
)
696-
contact_locations = templates_or_sorting_analyzer.get_channel_locations()
696+
contact_locations = templates_or_sorting_analyzer.get_channel_locations()[:, :2]
697697
if unit_ids is None:
698698
unit_ids = templates_or_sorting_analyzer.unit_ids
699699
else:

src/spikeinterface/postprocessing/tests/test_unit_locations.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
22
from spikeinterface.postprocessing import ComputeUnitLocations
33
import pytest
4+
from probeinterface import Probe
5+
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording
6+
import numpy as np
47

58

69
class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite):
@@ -18,3 +21,46 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite):
1821
)
1922
def test_extension(self, params):
2023
self.run_extension_tests(ComputeUnitLocations, params=params)
24+
25+
26+
def test_2d_and_3d_unit_localization():
27+
"""
28+
Our localization tools do not use the 3rd dimension of contact position.
29+
Hence if we pass the same data with a 2D probe and a 3D probe (with the
30+
same 2D positions), we should get the same result for all methods.
31+
32+
Also serves as an integration test of all the localization methods for
33+
2D and 3D channel locations.
34+
"""
35+
36+
# make a 2D synthetic recording
37+
positions_2D = [[0, 0], [0, 1], [1, 0], [1, 1]]
38+
39+
probe = Probe(ndim=2, si_units="um")
40+
probe.set_contacts(positions=positions_2D)
41+
probe.set_device_channel_indices(np.arange(4))
42+
43+
recording, sorting = generate_ground_truth_recording(num_channels=4, num_units=2, probe=probe, seed=1205)
44+
analyzer_2D = create_sorting_analyzer(sorting, recording, sparse=False)
45+
analyzer_2D.compute(["random_spikes", "templates"])
46+
47+
# make a 3D synthetic recording
48+
positions_3D = [[0, 0, 10], [0, 1, 15], [1, 0, 20], [1, 1, 100]]
49+
50+
probe = Probe(ndim=3, si_units="um")
51+
probe.set_contacts(positions=positions_3D, plane_axes=[[[1, 0, 0], [0, 1, 0]] * 4])
52+
53+
probe.set_device_channel_indices(np.arange(4))
54+
55+
recording_3D, sorting_3D = generate_ground_truth_recording(num_channels=4, num_units=2, probe=probe, seed=1205)
56+
analyzer_3D = create_sorting_analyzer(sorting_3D, recording_3D, sparse=False)
57+
analyzer_3D.compute(["random_spikes", "templates"])
58+
59+
for method in ["center_of_mass", "grid_convolution", "monopolar_triangulation", "max_channel"]:
60+
61+
analyzer_2D.compute("unit_locations", method=method)
62+
analyzer_3D.compute("unit_locations", method=method)
63+
unit_locations_2D = analyzer_2D.get_extension("unit_locations").get_data()
64+
unit_locations_3D = analyzer_3D.get_extension("unit_locations").get_data()
65+
66+
assert np.all(unit_locations_2D[:, :2] == unit_locations_3D[:, :2])

0 commit comments

Comments
 (0)