diff --git a/src/spikeinterface/exporters/__init__.py b/src/spikeinterface/exporters/__init__.py
index 50fcc304d1..1bec3e5cab 100644
--- a/src/spikeinterface/exporters/__init__.py
+++ b/src/spikeinterface/exporters/__init__.py
@@ -1,2 +1,3 @@
 from .to_phy import export_to_phy
 from .report import export_report
+from .to_ibl import export_to_ibl
diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py
new file mode 100644
index 0000000000..2c4aec4cd2
--- /dev/null
+++ b/src/spikeinterface/exporters/to_ibl.py
@@ -0,0 +1,309 @@
+from __future__ import annotations
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import numpy.typing as npt
+from tqdm.auto import tqdm
+
+from spikeinterface.core import ChannelSparsity, SortingAnalyzer
+from spikeinterface.core.template_tools import get_template_extremum_channel
+from spikeinterface.exporters import (
+    export_to_phy,
+)
+from spikeinterface.exporters.to_ibl_utils import (
+    WindowGenerator,
+    fscale,
+    hp,
+    rms,
+    save_object_npy,
+)
+
+
+def export_to_ibl(
+    analyzer: SortingAnalyzer,
+    output_folder: str | Path,
+    rms_win_length_sec=3,
+    welch_win_length_samples=1024,
+    total_secs=100,
+    only_ibl_specific_steps=False,
+    compute_pc_features: bool = False,  # PC features are typically not needed for the IBL GUI
+    compute_amplitudes: bool = True,
+    sparsity: Optional[ChannelSparsity] = None,
+    copy_binary: bool = True,
+    remove_if_exists: bool = False,
+    template_mode: str = "median",
+    dtype: Optional[npt.DTypeLike] = None,
+    verbose: bool = True,
+    use_relative_path: bool = False,
+    **job_kwargs,
+):
+    """
+    Exports a sorting analyzer to the IBL GUI format (the Phy format plus some extra files).
+
+    Parameters
+    ----------
+    analyzer: SortingAnalyzer
+        The sorting analyzer object.
+    output_folder: str | Path
+        The output folder where the phy template-gui files are saved
+    rms_win_length_sec: float, default: 3
+        The window length in seconds for the RMS calculation.
+    welch_win_length_samples: int, default: 1024
+        The window length in samples for the Welch method.
+    total_secs: int, default: 100
+        The total number of seconds to use for the spectral density calculation.
+    only_ibl_specific_steps: bool, default: False
+        If True, only the IBL-specific steps are run (i.e. skips calling `export_to_phy`)
+    compute_pc_features: bool, default: False
+        If True, PC features are computed
+    compute_amplitudes: bool, default: True
+        If True, waveform amplitudes are computed
+    sparsity: ChannelSparsity or None, default: None
+        The sparsity object (currently only respected for the phy part of the export)
+    copy_binary: bool, default: True
+        If True, the recording is copied and saved in the phy "output_folder"
+    remove_if_exists: bool, default: False
+        If True and "output_folder" exists, it is removed and overwritten
+    template_mode: str, default: "median"
+        Parameter "mode" used to compute the templates (passed through to `export_to_phy`)
+    dtype: dtype or None, default: None
+        Dtype to save binary data
+    verbose: bool, default: True
+        If True, output is verbose
+    use_relative_path : bool, default: False
+        If True and `copy_binary=True`, saves the binary file `dat_path` in `params.py` relative to `output_folder` (i.e. `dat_path=r"recording.dat"`);
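+
+    Notes
+    -----
+    In addition to the standard Phy output, this writes the IBL-specific files
+    (`spikes.times.npy`, `spikes.amps.npy`, `spikes.depths.npy`, `clusters.waveforms.npy`,
+    `clusters.metrics.csv`, ...) and the `_iblqc_` RMS and spectral-density files.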
+        if `copy_binary=False`, uses a path relative to `output_folder`
+        If False, uses an absolute path in `params.py` (i.e. `dat_path=r"path/to/the/recording.dat"`)
+    {}
+
+    """
+
+    try:
+        from scipy.signal import welch
+    except ImportError as e:
+        raise ImportError("Please install scipy to use the export_to_ibl function.") from e
+
+    # Output folder checks
+    if isinstance(output_folder, str):
+        output_folder = Path(output_folder)
+    output_folder = Path(output_folder).absolute()
+    if output_folder.is_dir():
+        if remove_if_exists:
+            shutil.rmtree(output_folder)
+        else:
+            raise FileExistsError(f"{output_folder} already exists")
+    else:
+        # don't make the output dir yet; export_to_phy will create it for us.
+        pass
+
+    if verbose:
+        print("Exporting recording to IBL format...")
+
+    # Compute any missing extensions
+    available_extension_names = analyzer.get_saved_extension_names()
+    required_exts = [
+        "templates",
+        "template_similarity",
+        "spike_locations",
+        "noise_levels",
+        "quality_metrics",
+    ]
+    required_qms = ["amplitude_median", "isi_violations_ratio", "amplitude_cutoff"]
+    for ext in required_exts:
+        if ext not in available_extension_names:
+            if ext == "quality_metrics":
+                kwargs = {"skip_pc_metrics": not compute_pc_features}
+            else:
+                kwargs = {}
+            analyzer.compute(ext, verbose=verbose, **kwargs)
+        elif ext == "quality_metrics":
+            qm = analyzer.get_extension("quality_metrics").get_data()
+            for rqm in required_qms:
+                if rqm not in qm:
+                    analyzer.compute(
+                        "quality_metrics",
+                        metric_names=[rqm],
+                        verbose=verbose,
+                    )
+
+    # Start by exporting to phy
+    if not only_ibl_specific_steps:
+        if verbose:
+            print("Doing phy-like export...")
+        export_to_phy(
+            analyzer,
+            output_folder,
+            compute_amplitudes=compute_amplitudes,
+            compute_pc_features=compute_pc_features,
+            sparsity=sparsity,
+            copy_binary=copy_binary,
+            remove_if_exists=remove_if_exists,
+            template_mode=template_mode,
+            dtype=dtype,
+            verbose=verbose,
+            use_relative_path=use_relative_path,
+            **job_kwargs,
+        )
+
+    if verbose:
+        print("Running IBL-specific steps...")
+
+    # Now we need to add the extra IBL-specific files
+    (channel_inds,) = np.isin(analyzer.recording.channel_ids, analyzer.channel_ids).nonzero()
+
+    ### Run spectral density and rms ###
+    fs_ap = analyzer.recording.sampling_frequency
+    rms_win_length_samples_ap = 2 ** np.ceil(np.log2(fs_ap * rms_win_length_sec))
+    total_samples_ap = int(np.min([fs_ap * total_secs, analyzer.recording.get_num_samples()]))
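+    # For reference, with a 30 kHz AP band and the default rms_win_length_sec=3:
+    #   fs_ap * rms_win_length_sec = 30000 * 3 = 90000 samples
+    #   2 ** ceil(log2(90000)) = 2 ** 17 = 131072 samples (~4.4 s per window),
+    # i.e. the RMS window is rounded up to the next power of two.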
+
+    # the window generator yields window indices
+    wingen = WindowGenerator(ns=total_samples_ap, nswin=rms_win_length_samples_ap, overlap=0)
+    win = {
+        "TRMS": np.zeros((wingen.nwin, analyzer.recording.get_num_channels())),
+        "nsamples": np.zeros((wingen.nwin,)),
+        "fscale": fscale(welch_win_length_samples, 1 / fs_ap, one_sided=True),
+        "tscale": wingen.tscale(fs=fs_ap),
+    }
+    win["spectral_density"] = np.zeros((len(win["fscale"]), analyzer.recording.get_num_channels()))
+
+    # TODO: this could be dramatically sped up if we employ SpikeInterface parallelization
+    with tqdm(total=wingen.nwin) as pbar:
+        for first, last in wingen.firstlast:
+            D = analyzer.recording.get_traces(start_frame=first, end_frame=last).T
+            # remove low frequency noise below 1 Hz
+            D = hp(D, 1 / fs_ap, [0, 1])
+            iw = wingen.iw
+            win["TRMS"][iw, :] = rms(D)
+            win["nsamples"][iw] = D.shape[1]
+
+            # the last window may be smaller than what is needed for welch
+            if last - first < welch_win_length_samples:
+                continue
+
+            # compute a smoothed spectrum using the welch method
+            _, w = welch(
+                D,
+                fs=fs_ap,
+                window="hann",
+                nperseg=welch_win_length_samples,
+                detrend="constant",
+                return_onesided=True,
+                scaling="density",
+                axis=-1,
+            )
+            win["spectral_density"] += w.T
+            # update the progress bar at least every 20 windows
+            if (iw % min(20, max(int(np.floor(wingen.nwin / 75)), 1))) == 0:
+                pbar.update(iw - pbar.n)
+
+    win["TRMS"] = win["TRMS"][:, channel_inds]
+    win["spectral_density"] = win["spectral_density"][:, channel_inds]
+
+    alf_object_time = "ephysTimeRmsAP"
+    alf_object_freq = "ephysSpectralDensityAP"
+
+    tdict = {
+        "rms": win["TRMS"].astype(np.single),
+        "timestamps": win["tscale"].astype(np.single),
+    }
+    save_object_npy(output_folder, object=alf_object_time, dico=tdict, namespace="iblqc")
+
+    fdict = {
+        "power": win["spectral_density"].astype(np.single),
+        "freqs": win["fscale"].astype(np.single),
+    }
+    save_object_npy(output_folder, object=alf_object_freq, dico=fdict, namespace="iblqc")
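+    # Following the ALF naming rules in to_ibl_utils.to_alf, the two calls above write:
+    #   _iblqc_ephysTimeRmsAP.rms.npy, _iblqc_ephysTimeRmsAP.timestamps.npy
+    #   _iblqc_ephysSpectralDensityAP.power.npy, _iblqc_ephysSpectralDensityAP.freqs.npy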
+
+    ### Save spike info ###
+
+    spike_locations = analyzer.load_extension("spike_locations").get_data()
+    spike_depths = spike_locations["y"]
+
+    # convert clusters and squeeze
+    clusters = np.load(output_folder / "spike_clusters.npy")
+    np.save(output_folder / "spike_clusters.npy", np.squeeze(clusters.astype("uint32")))
+
+    # convert times from samples to seconds and squeeze
+    # (use the recording's sampling rate rather than assuming 30 kHz)
+    times = np.load(output_folder / "spike_times.npy")
+    np.save(output_folder / "spike_times.npy", np.squeeze(times / fs_ap).astype("float64"))
+
+    # convert amplitudes (assumed uV -> V, sign-flipped per IBL convention) and squeeze
+    amps = np.load(output_folder / "amplitudes.npy")
+    np.save(output_folder / "amplitudes.npy", np.squeeze(-amps / 1e6).astype("float64"))
+
+    # save depths and channel inds
+    np.save(output_folder / "spike_depths.npy", spike_depths)
+    np.save(output_folder / "channel_inds.npy", np.arange(len(channel_inds), dtype="int"))
+
+    # save templates
+    cluster_channels = []
+    cluster_peakToTrough = []
+    cluster_waveforms = []
+    templates = analyzer.get_extension("templates").get_data()
+    extremum_channel_indices = get_template_extremum_channel(analyzer, outputs="index")
+
+    for unit_idx, unit_id in enumerate(analyzer.unit_ids):
+        waveform = templates[unit_idx, :, :]
+        extremum_channel_index = extremum_channel_indices[unit_id]
+        peak_waveform = waveform[:, extremum_channel_index]
+        peakToTrough = (np.argmax(peak_waveform) - np.argmin(peak_waveform)) / analyzer.sampling_frequency
+        cluster_channels.append(
+            extremum_channel_index
+        )  # see: https://github.com/SpikeInterface/spikeinterface/issues/2843#issuecomment-2148164870
+        cluster_peakToTrough.append(peakToTrough)
+        cluster_waveforms.append(waveform)
+
+    np.save(output_folder / "cluster_peakToTrough.npy", np.array(cluster_peakToTrough))
+    np.save(output_folder / "cluster_waveforms.npy", np.stack(cluster_waveforms))
+    np.save(output_folder / "cluster_channels.npy", np.array(cluster_channels))
+
+    # rename files
+    _FILE_RENAMES = [  # file_in, file_out
+        ("channel_positions.npy", "channels.localCoordinates.npy"),
+        ("channel_inds.npy", "channels.rawInd.npy"),
+        ("cluster_peakToTrough.npy", "clusters.peakToTrough.npy"),
+        ("cluster_channels.npy", "clusters.channels.npy"),
+        ("cluster_waveforms.npy", "clusters.waveforms.npy"),
+        ("spike_clusters.npy", "spikes.clusters.npy"),
+        ("amplitudes.npy", "spikes.amps.npy"),
+        ("spike_depths.npy", "spikes.depths.npy"),
+        ("spike_times.npy", "spikes.times.npy"),
+    ]
+
+    for names in _FILE_RENAMES:
+        old_name = output_folder / names[0]
+        new_name = output_folder / names[1]
+        os.rename(old_name, new_name)
+
+    # save quality metrics
+    qm = analyzer.load_extension("quality_metrics")
+    qm_data = qm.get_data()
+    qm_data.index.name = "cluster_id"
+    qm_data["cluster_id.1"] = qm_data.index.values  # duplicate the index as a regular column
+    good_ibl = (  # rough approximation of IBL quality standards
+        (qm_data["amplitude_median"] > 50)
+        & (qm_data["isi_violations_ratio"] < 0.2)
+        & (qm_data["amplitude_cutoff"] < 0.1)
+    )
+    qm_data["label"] = good_ibl.astype("int")
+    qm_data.to_csv(output_folder / "clusters.metrics.csv")
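For context, a minimal usage sketch of the new exporter (the `recording`/`sorting` objects and the output path here are hypothetical; `export_to_ibl` computes any missing extensions itself, as the loop above shows):

from spikeinterface.core import create_sorting_analyzer
from spikeinterface.exporters import export_to_ibl

# hypothetical recording/sorting pair; any SpikeInterface objects work here
analyzer = create_sorting_analyzer(sorting, recording, sparse=True)
export_to_ibl(analyzer, "/path/to/ibl_export", remove_if_exists=True, n_jobs=8)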
diff --git a/src/spikeinterface/exporters/to_ibl_utils.py b/src/spikeinterface/exporters/to_ibl_utils.py
new file mode 100644
index 0000000000..a263222b4f
--- /dev/null
+++ b/src/spikeinterface/exporters/to_ibl_utils.py
@@ -0,0 +1,435 @@
+"""
+Low-level functions to work in the frequency domain for n-dim arrays, plus minimal ALF
+file-naming and saving helpers.
+
+Copied from https://github.com/int-brain-lab/ibl-neuropixel/ on 2/1/2024
+
+"""
+
+import numpy as np
+from pathlib import Path
+import re
+
+
+def _dromedary(string) -> str:
+    """
+    Convert a string to camel case. Acronyms/initialisms are preserved.
+
+    Parameters
+    ----------
+    string : str
+        To be converted to camel case
+
+    Returns
+    -------
+    str
+        The string in camel case
+
+    Examples
+    --------
+    >>> _dromedary('Hello world') == 'helloWorld'
+    >>> _dromedary('motion_energy') == 'motionEnergy'
+    >>> _dromedary('passive_RFM') == 'passiveRFM'
+    >>> _dromedary('FooBarBaz') == 'fooBarBaz'
+    """
+
+    def _capitalize(x):
+        return x if x.isupper() else x.capitalize()
+
+    if not string:  # short circuit on None and ''
+        return string
+    first, *other = re.split(r"[_\s]", string)
+    if len(other) == 0:
+        # Already camel/Pascal case, ensure first letter lower case
+        return first[0].lower() + first[1:]
+    # Convert to camel case, preserving all-uppercase elements
+    first = first if first.isupper() else first.casefold()
+    return "".join([first, *map(_capitalize, other)])
+
+
+def to_alf(object, attribute, extension, namespace=None, timescale=None, extra=None):
+    """
+    Given a set of ALF file parts, return a valid ALF file name. Essential periods and
+    underscores are added by the function.
+
+    Parameters
+    ----------
+    object : str
+        The ALF object name
+    attribute : str
+        The ALF object attribute name
+    extension : str
+        The file extension
+    namespace : str
+        An optional namespace
+    timescale : str, tuple
+        An optional timescale
+    extra : str, tuple
+        One or more optional extra ALF attributes
+
+    Returns
+    -------
+    str
+        A file name string built from the ALF parts
+
+    Examples
+    --------
+    >>> to_alf('spikes', 'times', 'ssv')
+    'spikes.times.ssv'
+    >>> to_alf('spikes', 'times', 'ssv', namespace='ibl')
+    '_ibl_spikes.times.ssv'
+    >>> to_alf('spikes', 'times', 'ssv', namespace='ibl', timescale='ephysClock')
+    '_ibl_spikes.times_ephysClock.ssv'
+    >>> to_alf('spikes', 'times', 'ssv', namespace='ibl', timescale=('ephys clock', 'minutes'))
+    '_ibl_spikes.times_ephysClock_minutes.ssv'
+    >>> to_alf('spikes', 'times', 'npy', namespace='ibl', timescale='ephysClock', extra='raw')
+    '_ibl_spikes.times_ephysClock.raw.npy'
+    >>> to_alf('wheel', 'timestamps', 'npy', 'ibl', 'bpod', ('raw', 'v12'))
+    '_ibl_wheel.timestamps_bpod.raw.v12.npy'
+    """
+    # Validate inputs
+    if not extension:
+        raise TypeError("An extension must be provided")
+    elif extension.startswith("."):
+        extension = extension[1:]
+    if any(pt is not None and "." in pt for pt in (object, attribute, namespace, extension, timescale)):
+        raise ValueError("ALF parts must not contain a period (`.`)")
+    if "_" in (namespace or ""):
+        raise ValueError("Namespace must not contain extra underscores")
+    if object[0] == "_":
+        raise ValueError("Objects must not contain underscores; use namespace arg instead")
+    # Ensure parts are camel case (converts whitespace and snake case)
+    if timescale:
+        timescale = filter(None, [timescale] if isinstance(timescale, str) else timescale)
+        timescale = "_".join(map(_dromedary, timescale))
+    # Convert attribute to camel case, leaving '_times', etc. intact
+    times_re = re.search("_(times|timestamps|intervals)$", attribute)
+    idx = times_re.start() if times_re else len(attribute)
+    attribute = _dromedary(attribute[:idx]) + attribute[idx:]
+    object = _dromedary(object)
+
+    # Optional extras may be provided as string or tuple of strings
+    if not extra:
+        extra = ()
+    elif isinstance(extra, str):
+        extra = extra.split(".")
+
+    # Construct ALF file
+    parts = (
+        ("_%s_" % namespace if namespace else "") + object,
+        attribute + ("_%s" % timescale if timescale else ""),
+        *extra,
+        extension,
+    )
+    return ".".join(parts)
+
+
+def save_object_npy(alfpath, dico, object, parts=None, namespace=None, timescale=None) -> list:
+    """
+    Saves a dictionary in `ALF format`_ using object as object name and dictionary keys as
+    attribute names. Dimensions have to be consistent.
+
+    Simplified ALF example: _namespace_object.attribute.part1.part2.extension
+
+    Parameters
+    ----------
+    alfpath : str, pathlib.Path
+        Path of the folder to save data to
+    dico : dict
+        Dictionary to save to npy; keys correspond to ALF attributes
+    object : str
+        Name of the object to save
+    parts : str, list, None
+        Extra parts to the ALF name
+    namespace : str, None
+        The optional namespace of the object
+    timescale : str, None
+        The optional timescale of the object
+
+    Returns
+    -------
+    list
+        List of written files
+
+    Examples
+    --------
+    >>> spikes = {'times': np.arange(50), 'depths': np.random.random(50)}
+    >>> files = save_object_npy('/path/to/my/alffolder/', spikes, 'spikes')
+
+    .. _ALF format:
+        https://int-brain-lab.github.io/ONE/alf_intro.html
+    """
+    alfpath = Path(alfpath)
+    status = check_dimensions(dico)
+    if status != 0:
+        raise ValueError(
+            "Dimensions are not consistent to save all arrays in ALF format: "
+            + str([(k, v.shape) for k, v in dico.items()])
+        )
+    out_files = []
+    for k, v in dico.items():
+        out_file = alfpath / to_alf(object, k, "npy", extra=parts, namespace=namespace, timescale=timescale)
+        np.save(out_file, v)
+        out_files.append(out_file)
+    return out_files
+
+
+def check_dimensions(dico):
+    """
+    Test for consistency of dimensions as per ALF specs in a dictionary.
+
+    ALF broadcasting rules only accept consistent dimensions for a given axis:
+    a first dimension is consistent with another if it is empty, 1, or equal to the others;
+    e.g. dims [a, 1], [1, b] and [a, b] are all consistent, while [c, 1] is not.
+
+    Parameters
+    ----------
+    dico : ALFBunch, dict
+        Dictionary containing data
+
+    Returns
+    -------
+    int
+        Status 0 for consistent dimensions, 1 for inconsistent dimensions
+    """
+    # data types that have a shape attribute; SpikeInterface only passes numpy arrays here
+    # (the IBL original also accepts pandas DataFrames)
+    supported = (np.ndarray,)
+    shapes = [dico[lab].shape for lab in dico if isinstance(dico[lab], supported) and not lab.startswith("timestamps")]
+    first_shapes = [sh[0] for sh in shapes]
+    # Continuous timeseries are permitted to be a (2, 2)
+    timeseries = [k for k, v in dico.items() if k.startswith("timestamps") and isinstance(v, np.ndarray)]
+    if any(timeseries):
+        for key in timeseries:
+            if dico[key].ndim == 1 or (dico[key].ndim == 2 and dico[key].shape[1] == 1):
+                # Should be vector with same length as other attributes
+                first_shapes.append(dico[key].shape[0])
+            elif dico[key].ndim > 1 and dico[key].shape != (2, 2):
+                return 1  # ts not a (2, 2) arr or a vector
+
+    ok = len(first_shapes) == 0 or set(first_shapes).issubset({max(first_shapes), 1})
+    return int(ok is False)
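+
+
+# For illustration (hypothetical arrays): consistent first dimensions pass,
+# mismatched ones do not:
+#   check_dimensions({"times": np.zeros(50), "depths": np.zeros((50, 1))})  # -> 0
+#   check_dimensions({"times": np.zeros(50), "depths": np.zeros(40)})  # -> 1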
+
+
+def rms(x, axis=-1):
+    """
+    Root mean square of array along axis
+
+    :param x: array on which to compute RMS
+    :param axis: (optional, -1)
+    :return: numpy array
+    """
+    return np.sqrt(np.mean(x**2, axis=axis))
+
+
+def _fcn_extrap(x, f, bounds):
+    """
+    Extrapolates a flat value before and after bounds
+    x: array to be filtered
+    f: function to be applied between bounds (cf. fcn_cosine below)
+    bounds: 2 elements list or np.array
+    """
+    y = f(x)
+    y[x < bounds[0]] = f(bounds[0])
+    y[x > bounds[1]] = f(bounds[1])
+    return y
+
+
+def fcn_cosine(bounds, gpu=False):
+    """
+    Returns a soft thresholding function with a cosine taper:
+    values <= bounds[0]: 0
+    bounds[0] < values < bounds[1]: cosine taper rising from 0 to 1
+    values >= bounds[1]: 1
+    :param bounds: 2 elements list or np.array
+    :param gpu: bool
+    :return: lambda function
+    """
+    if gpu:
+        import cupy as gp
+    else:
+        gp = np
+
+    def _cos(x):
+        return (1 - gp.cos((x - bounds[0]) / (bounds[1] - bounds[0]) * gp.pi)) / 2
+
+    func = lambda x: _fcn_extrap(x, _cos, bounds)  # noqa
+    return func
+
+
+def fscale(ns, si=1, one_sided=False):
+    """
+    numpy.fft.fftfreq returns Nyquist as a negative frequency, so we propose this instead
+
+    :param ns: number of samples
+    :param si: sampling interval in seconds
+    :param one_sided: if True, returns only positive frequencies
+    :return: fscale: numpy vector containing frequencies in Hertz
+    """
+    fsc = np.arange(0, np.floor(ns / 2) + 1) / ns / si  # sample the frequency scale
+    if one_sided:
+        return fsc
+    else:
+        return np.concatenate((fsc, -fsc[slice(-2 + (ns % 2), 0, -1)]), axis=0)
+
+
+def bp(ts, si, b, axis=None):
+    """
+    Band-pass filter in frequency domain
+
+    :param ts: time series
+    :param si: sampling interval in seconds
+    :param b: cutoff frequencies: 4 elements vector or list
+    :param axis: axis along which to perform reduction (last axis by default)
+    :return: filtered time series
+    """
+    return _freq_filter(ts, si, b, axis=axis, typ="bp")
+
+
+def lp(ts, si, b, axis=None):
+    """
+    Low-pass filter in frequency domain
+
+    :param ts: time series
+    :param si: sampling interval in seconds
+    :param b: cutoff frequencies: 2 elements vector or list
+    :param axis: axis along which to perform reduction (last axis by default)
+    :return: filtered time series
+    """
+    return _freq_filter(ts, si, b, axis=axis, typ="lp")
+
+
+def hp(ts, si, b, axis=None):
+    """
+    High-pass filter in frequency domain
+
+    :param ts: time series
+    :param si: sampling interval in seconds
+    :param b: cutoff frequencies: 2 elements vector or list
+    :param axis: axis along which to perform reduction (last axis by default)
+    :return: filtered time series
+    """
+    return _freq_filter(ts, si, b, axis=axis, typ="hp")
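+
+
+# For example, export_to_ibl uses hp() to remove drift below ~1 Hz from each window
+# of traces (shape (n_channels, n_samples)) before computing the RMS:
+#   D = hp(D, si=1 / fs_ap, b=[0, 1])
+# b=[0, 1] tapers the gain from 0 at 0 Hz up to 1 at 1 Hz.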
+
+
+def _freq_filter(ts, si, b, axis=None, typ="lp"):
+    """
+    Wrapper for hp/lp/bp filters
+    """
+    if axis is None:
+        axis = ts.ndim - 1
+    ns = ts.shape[axis]
+    f = fscale(ns, si=si, one_sided=True)
+    if typ == "bp":
+        filc = _freq_vector(f, b[0:2], typ="hp") * _freq_vector(f, b[2:4], typ="lp")
+    else:
+        filc = _freq_vector(f, b, typ=typ)
+    if axis < (ts.ndim - 1):
+        filc = filc[:, np.newaxis]
+    return np.real(np.fft.ifft(np.fft.fft(ts, axis=axis) * fexpand(filc, ns, axis=0), axis=axis))
+
+
+def _freq_vector(f, b, typ="lp"):
+    """
+    Returns a frequency modulated vector for filtering
+
+    :param f: frequency vector, uniform and monotonic
+    :param b: 2 bounds array
+    :return: amplitude modulated frequency vector
+    """
+    filc = fcn_cosine(b)(f)
+    if typ.lower() in ["hp", "highpass"]:
+        return filc
+    elif typ.lower() in ["lp", "lowpass"]:
+        return 1 - filc
+
+
+def fexpand(x, ns=1, axis=None):
+    """
+    Reconstructs the full spectrum from positive frequencies.
+    Works on the last dimension (contiguous in c-stored array)
+
+    :param x: numpy.ndarray
+    :param ns: number of samples of the full (two-sided) spectrum
+    :param axis: axis along which to perform reduction (last axis by default)
+    :return: numpy.ndarray
+    """
+    if axis is None:
+        axis = x.ndim - 1
+    ilast = int((ns + (ns % 2)) / 2)
+    xcomp = np.conj(np.flip(np.take(x, np.arange(1, ilast), axis=axis), axis=axis))
+    return np.concatenate((x, xcomp), axis=axis)
+
+
+class WindowGenerator(object):
+    """
+    `wg = WindowGenerator(ns, nswin, overlap)`
+
+    Provides a sliding-window indices generator for signal processing applications.
+    For straightforward spectrogram / periodogram implementations, prefer scipy methods!
+    """
+
+    def __init__(self, ns, nswin, overlap):
+        """
+        :param ns: number of samples of the signal along the direction to be windowed
+        :param nswin: number of samples of the window
+        :param overlap: number of samples of overlap between consecutive windows
+        :return: dsp.WindowGenerator object
+        """
+        self.ns = int(ns)
+        self.nswin = int(nswin)
+        self.overlap = int(overlap)
+        self.nwin = int(np.ceil(float(ns - nswin) / float(nswin - overlap))) + 1
+        self.iw = None
+
+    @property
+    def firstlast(self):
+        """
+        Generator that yields first and last index of windows
+
+        :return: tuple of [first_index, last_index] of the window
+        """
+        self.iw = 0
+        first = 0
+        while True:
+            last = first + self.nswin
+            last = min(last, self.ns)
+            yield (first, last)
+            if last == self.ns:
+                break
+            first += self.nswin - self.overlap
+            self.iw += 1
+
+    @property
+    def slice(self):
+        """
+        Generator that yields slices of windows
+
+        :return: a slice of the window
+        """
+        for first, last in self.firstlast:
+            yield slice(first, last)
+
+    def slice_array(self, sig, axis=-1):
+        """
+        Provided an array or sliceable object, generator that yields
+        slices corresponding to windows. Especially useful when working on memmaps
+
+        :param sig: array
+        :param axis: (optional, -1) dimension along which to provide the slice
+        :return: array slice Generator
+        """
+        for first, last in self.firstlast:
+            yield np.take(sig, np.arange(first, last), axis=axis)
+
+    def tscale(self, fs):
+        """
+        Returns the time scale associated with window slicing (middle of window)
+        :param fs: sampling frequency (Hz)
+        :return: time axis scale
+        """
+        return np.array([(first + (last - first - 1) / 2) / fs for first, last in self.firstlast])
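As used in `export_to_ibl` above, the generator yields non-overlapping sample windows; a minimal sketch (array shapes are hypothetical):

import numpy as np
from spikeinterface.exporters.to_ibl_utils import WindowGenerator, rms

traces = np.random.randn(100_000, 4)  # (n_samples, n_channels), hypothetical
wingen = WindowGenerator(ns=traces.shape[0], nswin=30_000, overlap=0)
trms = np.zeros((wingen.nwin, traces.shape[1]))
for first, last in wingen.firstlast:
    # RMS per channel within the current window (wingen.iw tracks the window index)
    trms[wingen.iw, :] = rms(traces[first:last, :].T)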
diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py
index 06041da231..4b3b914733 100644
--- a/src/spikeinterface/exporters/to_phy.py
+++ b/src/spikeinterface/exporters/to_phy.py
@@ -194,7 +194,7 @@ def export_to_phy(
         templates[unit_ind, :, :][:, : len(chan_inds)] = template
         templates_ind[unit_ind, : len(chan_inds)] = chan_inds
 
-    if not sorting_analyzer.has_extension("template_similarity"):
+    if sorting_analyzer.get_extension("template_similarity") is None:
         sorting_analyzer.compute("template_similarity")
     template_similarity = sorting_analyzer.get_extension("template_similarity").get_data()
 
@@ -215,14 +215,14 @@ def export_to_phy(
     np.save(str(output_folder / "channel_groups.npy"), channel_groups)
 
     if compute_amplitudes:
-        if not sorting_analyzer.has_extension("spike_amplitudes"):
+        if sorting_analyzer.get_extension("spike_amplitudes") is None:
             sorting_analyzer.compute("spike_amplitudes", **job_kwargs)
         amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data()
         amplitudes = amplitudes[:, np.newaxis]
         np.save(str(output_folder / "amplitudes.npy"), amplitudes)
 
     if compute_pc_features:
-        if not sorting_analyzer.has_extension("principal_components"):
+        if sorting_analyzer.get_extension("principal_components") is None:
             sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)
 
         pca_extension = sorting_analyzer.get_extension("principal_components")
@@ -250,7 +250,7 @@ def export_to_phy(
         channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups})
         channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False)
 
-    if sorting_analyzer.has_extension("quality_metrics") and add_quality_metrics:
+    if sorting_analyzer.get_extension("quality_metrics") is not None and add_quality_metrics:
         qm_data = sorting_analyzer.get_extension("quality_metrics").get_data()
         for column_name in qm_data.columns:
             # already computed by phy
@@ -259,7 +259,7 @@ def export_to_phy(
                 {"cluster_id": [i for i in range(len(unit_ids))], column_name: qm_data[column_name].values}
             )
             metric.to_csv(output_folder / f"cluster_{column_name}.tsv", sep="\t", index=False)
-    if sorting_analyzer.has_extension("template_metrics") and add_template_metrics:
+    if sorting_analyzer.get_extension("template_metrics") is not None and add_template_metrics:
         tm_data = sorting_analyzer.get_extension("template_metrics").get_data()
         for column_name in tm_data.columns:
             metric = pd.DataFrame(