diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..5b53360ca5 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -12,11 +12,12 @@ ComputeTemplates, create_sorting_analyzer, SortingAnalyzer, + aggregate_channels, ) from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations -from probeinterface import read_prb, Probe +from probeinterface import read_prb, Probe, ProbeGroup class BasePhyKilosortSortingExtractor(BaseSorting): @@ -314,7 +315,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") -def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer: +def read_kilosort_as_analyzer( + folder_path, recording=None, unwhiten=True, gain_to_uV=None, offset_to_uV=None +) -> SortingAnalyzer: """ Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and above are supported. The function may work on older versions of Kilosort output, @@ -324,6 +327,8 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse ---------- folder_path : str or Path Path to the output Phy folder (containing the params.py). + recording : BaseRecording | None, default: None + A spikeinterface Recording object which will be attached to the analyzer. If None, a dummy recording is used to build the analyzer and is detached afterwards. unwhiten : bool, default: True Unwhiten the templates computed by kilosort. 
gain_to_uV : float | None, default: None @@ -359,25 +364,45 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse if (phy_path / "probe.prb").is_file(): probegroup = read_prb(phy_path / "probe.prb") - if len(probegroup.probes) > 0: - warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.") - probe = probegroup.probes[0] elif (phy_path / "channel_positions.npy").is_file(): probe = Probe(si_units="um") channel_positions = np.load(phy_path / "channel_positions.npy") probe.set_contacts(channel_positions) - probe.set_device_channel_indices(range(probe.get_contact_count())) + channel_map = np.load(phy_path / "channel_map.npy") + probe.set_device_channel_indices(channel_map) + + probegroup = ProbeGroup() + probegroup.add_probe(probe) else: AssertionError(f"Cannot read probe layout from folder {phy_path}.") - # to make the initial analyzer, we'll use a fake recording and set it to None later - recording, _ = generate_ground_truth_recording( - probe=probe, - sampling_frequency=sampling_frequency, - durations=[duration], - num_units=1, - seed=1205, - ) + # Check that user-defined recording probe geometry is consistent with phy output + if recording is not None: + user_gave_recording = True + all_contact_positions = np.vstack([probe.contact_positions for probe in probegroup.probes]) + for recording_channel_location, probe_contact_position in zip( + recording.get_channel_locations(), all_contact_positions + ): + if not np.all(recording_channel_location == probe_contact_position): + raise ValueError( + "Recording channel locations from `recording` do not match the probe channel locations read from `folder_path`. " + "Hence there is an inconsistency in the probe layout or wiring between the recording and the sorting output. " + "Please resolve this inconsistency." 
+ ) + else: + user_gave_recording = False + # to make the initial analyzer, we'll use a fake recording and set it to None later + recordings = [] + for probe in probegroup.probes: + one_recording, _ = generate_ground_truth_recording( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + num_units=1, + seed=1205, + ) + recordings.append(one_recording) + recording = aggregate_channels(recordings) sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) @@ -397,7 +422,9 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse ) _make_locations(sorting_analyzer, phy_path) - sorting_analyzer._recording = None + if not user_gave_recording: + sorting_analyzer._recording = None + return sorting_analyzer @@ -413,14 +440,9 @@ def _make_locations(sorting_analyzer, kilosort_output_path): else: return - # Check that the spike locations vector is the same size as the spike vector + # Trim spike locations to the length of the sorting's spike vector (the sorting may hold fewer spikes than `spike_positions.npy`, e.g. if spikes were removed when loading) num_spikes = len(sorting_analyzer.sorting.to_spike_vector()) - num_spike_locs = len(locs_np) - if num_spikes != num_spike_locs: - warnings.warn( - "The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations." - ) - return + locs_np = locs_np[:num_spikes] num_dims = len(locs_np[0]) column_names = ["x", "y", "z"][:num_dims]