11from spikeinterface .postprocessing .tests .common_extension_tests import AnalyzerExtensionCommonTestSuite
22from spikeinterface .postprocessing import ComputeUnitLocations
33import pytest
4+ from probeinterface import Probe
5+ from spikeinterface .core import create_sorting_analyzer , generate_ground_truth_recording
6+ import numpy as np
47
58
69class TestUnitLocationsExtension (AnalyzerExtensionCommonTestSuite ):
@@ -18,3 +21,46 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite):
1821 )
1922 def test_extension (self , params ):
2023 self .run_extension_tests (ComputeUnitLocations , params = params )
24+
25+
26+ def test_2d_and_3d_unit_localization ():
27+ """
28+ Our localization tools do not use the 3rd dimension of contact position.
29+ Hence if we pass the same data with a 2D probe and a 3D probe (with the
30+ same 2D positions), we should get the same result for all methods.
31+
32+ Also serves as an integration test of all the localization methods for
33+ 2D and 3D channel locations.
34+ """
35+
36+ # make a 2D synthetic recording
37+ positions_2D = [[0 , 0 ], [0 , 1 ], [1 , 0 ], [1 , 1 ]]
38+
39+ probe = Probe (ndim = 2 , si_units = "um" )
40+ probe .set_contacts (positions = positions_2D )
41+ probe .set_device_channel_indices (np .arange (4 ))
42+
43+ recording , sorting = generate_ground_truth_recording (num_channels = 4 , num_units = 2 , probe = probe , seed = 1205 )
44+ analyzer_2D = create_sorting_analyzer (sorting , recording , sparse = False )
45+ analyzer_2D .compute (["random_spikes" , "templates" ])
46+
47+ # make a 3D synthetic recording
48+ positions_3D = [[0 , 0 , 10 ], [0 , 1 , 15 ], [1 , 0 , 20 ], [1 , 1 , 100 ]]
49+
50+ probe = Probe (ndim = 3 , si_units = "um" )
51+ probe .set_contacts (positions = positions_3D , plane_axes = [[[1 , 0 , 0 ], [0 , 1 , 0 ]] * 4 ])
52+
53+ probe .set_device_channel_indices (np .arange (4 ))
54+
55+ recording_3D , sorting_3D = generate_ground_truth_recording (num_channels = 4 , num_units = 2 , probe = probe , seed = 1205 )
56+ analyzer_3D = create_sorting_analyzer (sorting_3D , recording_3D , sparse = False )
57+ analyzer_3D .compute (["random_spikes" , "templates" ])
58+
59+ for method in ["center_of_mass" , "grid_convolution" , "monopolar_triangulation" , "max_channel" ]:
60+
61+ analyzer_2D .compute ("unit_locations" , method = method )
62+ analyzer_3D .compute ("unit_locations" , method = method )
63+ unit_locations_2D = analyzer_2D .get_extension ("unit_locations" ).get_data ()
64+ unit_locations_3D = analyzer_3D .get_extension ("unit_locations" ).get_data ()
65+
66+ assert np .all (unit_locations_2D [:, :2 ] == unit_locations_3D [:, :2 ])
0 commit comments