Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions src/spikeinterface/extractors/phykilosortextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
ComputeTemplates,
create_sorting_analyzer,
SortingAnalyzer,
aggregate_channels,
)
from spikeinterface.core.core_tools import define_function_from_class

from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
from probeinterface import read_prb, Probe
from probeinterface import read_prb, Probe, ProbeGroup


class BasePhyKilosortSortingExtractor(BaseSorting):
Expand Down Expand Up @@ -314,7 +315,9 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")


def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer:
def read_kilosort_as_analyzer(
folder_path, recording=None, unwhiten=True, gain_to_uV=None, offset_to_uV=None
) -> SortingAnalyzer:
"""
Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
above are supported. The function may work on older versions of Kilosort output,
Expand All @@ -324,6 +327,8 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
----------
folder_path : str or Path
Path to the output Phy folder (containing the params.py).
recording : BaseRecording
A spikeinterface Recording object which will be attached to the analyzer
unwhiten : bool, default: True
Unwhiten the templates computed by kilosort.
gain_to_uV : float | None, default: None
Expand Down Expand Up @@ -359,25 +364,45 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse

if (phy_path / "probe.prb").is_file():
probegroup = read_prb(phy_path / "probe.prb")
if len(probegroup.probes) > 0:
warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.")
probe = probegroup.probes[0]
elif (phy_path / "channel_positions.npy").is_file():
probe = Probe(si_units="um")
channel_positions = np.load(phy_path / "channel_positions.npy")
probe.set_contacts(channel_positions)
probe.set_device_channel_indices(range(probe.get_contact_count()))
channel_map = np.load(phy_path / "channel_map.npy")
probe.set_device_channel_indices(channel_map)

probegroup = ProbeGroup()
probegroup.add_probe(probe)
else:
AssertionError(f"Cannot read probe layout from folder {phy_path}.")

# to make the initial analyzer, we'll use a fake recording and set it to None later
recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)
# Check that user-defined recording probe geometry is consistent with phy output
if recording is not None:
user_gave_recording = True
all_contact_positions = np.vstack([probe.contact_positions for probe in probegroup.probes])
for recording_channel_location, probe_contact_position in zip(
recording.get_channel_locations(), all_contact_positions
):
if not np.all(recording_channel_location == probe_contact_position):
raise ValueError(
"Recording channel locations from `recording` do not match probe channel locations from `folder_path/probe.prb`."
"Hence there is an inconsistency between probe layout or wiring between the recording and sorting output."
"Please resolve this inconsistency."
)
else:
user_gave_recording = False
# to make the initial analyzer, we'll use a fake recording and set it to None later
recordings = []
for probe in probegroup.probes:
one_recording, _ = generate_ground_truth_recording(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
num_units=1,
seed=1205,
)
recordings.append(one_recording)
recording = aggregate_channels(recordings)

sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

Expand All @@ -397,7 +422,9 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse
)
_make_locations(sorting_analyzer, phy_path)

sorting_analyzer._recording = None
if not user_gave_recording:
sorting_analyzer._recording = None

return sorting_analyzer


Expand All @@ -413,14 +440,9 @@ def _make_locations(sorting_analyzer, kilosort_output_path):
else:
return

# Check that the spike locations vector is the same size as the spike vector
# When recording is given, need to trim spike locations to match spikes in sorting
num_spikes = len(sorting_analyzer.sorting.to_spike_vector())
num_spike_locs = len(locs_np)
if num_spikes != num_spike_locs:
warnings.warn(
"The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations."
)
return
locs_np = locs_np[:num_spikes]

num_dims = len(locs_np[0])
column_names = ["x", "y", "z"][:num_dims]
Expand Down
Loading