Skip to content

Commit ea2ef8c

Browse files
authored
Add gain and offset for read_kilosort_as_analyzer (#4428)
1 parent cb38613 commit ea2ef8c

1 file changed

Lines changed: 37 additions & 7 deletions

File tree

src/spikeinterface/extractors/phykilosortextractors.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
314314
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")
315315

316316

317-
def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
317+
def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offset_to_uV=None) -> SortingAnalyzer:
318318
"""
319319
Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
320320
above are supported. The function may work on older versions of Kilosort output,
@@ -326,13 +326,28 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
326326
Path to the output Phy folder (containing the params.py).
327327
unwhiten : bool, default: True
328328
Unwhiten the templates computed by kilosort.
329+
gain_to_uV : float | None, default: None
330+
The gain to apply to convert traces to uV
331+
offset_to_uV : float | None, default: None
332+
The offset to apply to the traces
329333
330334
Returns
331335
-------
332336
sorting_analyzer : SortingAnalyzer
333337
A SortingAnalyzer object.
334338
"""
335339

340+
if gain_to_uV is None:
341+
warnings.warn(
342+
"No `gain_to_uv` value given. Outputted data will be in dimensionless units. If you know the conversion factor, please pass it to the `read_kilosort_as_analyzer` function."
343+
)
344+
gain_to_uV = 1.0
345+
if offset_to_uV is None:
346+
warnings.warn(
347+
"No `offset_to_uV` value given. Outputted data may not be offset correctly. If you know the offset factor, please pass it to the `read_kilosort_as_analyzer` function."
348+
)
349+
offset_to_uV = 0.0
350+
336351
phy_path = Path(folder_path)
337352

338353
sorting = read_phy(phy_path)
@@ -371,7 +386,15 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
371386
# first compute random spikes. These do nothing, but are needed for si-gui to run
372387
sorting_analyzer.compute("random_spikes")
373388

374-
_make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten)
389+
_make_templates(
390+
sorting_analyzer,
391+
phy_path,
392+
sparsity.mask,
393+
sampling_frequency,
394+
gain_to_uV=gain_to_uV,
395+
offset_to_uV=offset_to_uV,
396+
unwhiten=unwhiten,
397+
)
375398
_make_locations(sorting_analyzer, phy_path)
376399

377400
sorting_analyzer._recording = None
@@ -429,15 +452,21 @@ def _make_sparsity_from_templates(sorting, recording, kilosort_output_path):
429452
return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids)
430453

431454

432-
def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True):
455+
def _make_templates(
456+
sorting_analyzer, kilosort_output_path, mask, sampling_frequency, gain_to_uV, offset_to_uV, unwhiten=True
457+
):
433458
"""Constructs a `templates` extension from the amplitudes numpy array
434459
in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`."""
435460

436461
template_extension = ComputeTemplates(sorting_analyzer)
437462

438463
whitened_templates = np.load(kilosort_output_path / "templates.npy")
439464
wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy")
440-
new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv) if unwhiten else whitened_templates
465+
new_templates = (
466+
_compute_unwhitened_templates(whitened_templates, wh_inv, gain_to_uV, offset_to_uV)
467+
if unwhiten
468+
else whitened_templates
469+
)
441470

442471
template_extension.data = {"average": new_templates}
443472

@@ -476,13 +505,14 @@ def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequ
476505
sorting_analyzer.extensions["templates"] = template_extension
477506

478507

479-
def _compute_unwhitened_templates(whitened_templates, wh_inv):
508+
def _compute_unwhitened_templates(whitened_templates, wh_inv, gain_to_uV, offset_to_uV):
480509
"""Constructs unwhitened templates from whitened_templates, by
481510
applying an inverse whitening matrix."""
482511

483512
# templates have dimension (num units) x (num samples) x (num channels)
484-
# whitening inverse has dimension (num units) x (num channels)
513+
# whitening inverse has dimension (num channels) x (num channels)
485514
# to undo whitening, we need do matrix multiplication on the channel index
486515
unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates)
487516

488-
return unwhitened_templates
517+
# then scale to physical units
518+
return unwhitened_templates * gain_to_uV + offset_to_uV

0 commit comments

Comments
 (0)