Skip to content

Commit 9db0cf9

Browse files
authored
Merge pull request #3721 from yger/total_memory
Handle automatic chunks duration for SC2
2 parents 0bffc1e + a93016a commit 9db0cf9

5 files changed

Lines changed: 157 additions & 20 deletions

File tree

src/spikeinterface/core/recording_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def get_random_recording_slices(
512512
chunk_duration : str | float | None, default "500ms"
513513
The duration of each chunk in 's' or 'ms'
514514
chunk_size : int | None
515-
Size of a chunk in number of frames. This is ued only if chunk_duration is None.
515+
Size of a chunk in number of frames. This is used only if chunk_duration is None.
516516
This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead.
517517
concatenated : bool, default: True
518518
If True chunk are concatenated along time axis

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
cache_preprocessing,
1414
get_prototype_and_waveforms_from_recording,
1515
get_shuffled_recording_slices,
16+
_set_optimal_chunk_size,
1617
)
1718
from spikeinterface.core.basesorting import minimum_spike_dtype
1819
from spikeinterface.core.sparsity import compute_sparsity
@@ -39,6 +40,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
3940
"apply_preprocessing": True,
4041
"templates_from_svd": True,
4142
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
43+
"chunk_preprocessing": {"memory_limit": None},
4244
"multi_units_only": False,
4345
"job_kwargs": {"n_jobs": 0.75},
4446
"seed": 42,
@@ -66,6 +68,9 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
6668
"matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
6769
"cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
6870
memory_limit will control how much RAM can be used. In case of folder or zarr, delete_cache controls if cache is cleaned after sorting",
71+
"chunk_preprocessing": "How much RAM (approximately) should be devoted to load all data chunks (given n_jobs).\
72+
memory_limit will control how much RAM can be used as a fraction of available memory. Otherwise, use total_memory to fix a hard limit, with\
73+
a string syntax (e.g. '1G', '500M')",
6974
"multi_units_only": "Boolean to get only multi units activity (i.e. one template per electrode)",
7075
"job_kwargs": "A dictionary to specify how many jobs and which parameters they should used",
7176
"seed": "An int to control how chunks are shuffled while detecting peaks",
@@ -100,8 +105,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
100105

101106
job_kwargs = fix_job_kwargs(params["job_kwargs"])
102107
job_kwargs.update({"progress_bar": verbose})
103-
104108
recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
109+
if params["chunk_preprocessing"].get("memory_limit", None) is not None:
110+
job_kwargs = _set_optimal_chunk_size(recording, job_kwargs, **params["chunk_preprocessing"])
105111

106112
sampling_frequency = recording.get_sampling_frequency()
107113
num_channels = recording.get_num_channels()
@@ -401,7 +407,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
401407
# np.save(fitting_folder / "amplitudes", guessed_amplitudes)
402408

403409
if sorting.get_non_empty_unit_ids().size > 0:
404-
sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params, **job_kwargs)
410+
final_analyzer = final_cleaning_circus(
411+
recording_w, sorting, templates, **merging_params, **job_kwargs
412+
)
413+
final_analyzer.save_as(format="binary_folder", folder=sorter_output_folder / "final_analyzer")
414+
415+
sorting = final_analyzer.sorting
405416

406417
if verbose:
407418
print(f"Kept {len(sorting.unit_ids)} units after final merging")
@@ -460,4 +471,5 @@ def final_cleaning_circus(
460471
sparsity_overlap=sparsity_overlap,
461472
**job_kwargs,
462473
)
463-
return final_sa.sorting
474+
475+
return final_sa

src/spikeinterface/sortingcomponents/clustering/circus.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,12 @@
1414

1515
import random, string
1616
from spikeinterface.core import get_global_tmp_folder
17-
from spikeinterface.core.basesorting import minimum_spike_dtype
18-
from spikeinterface.core.waveform_tools import estimate_templates
1917
from .clustering_tools import remove_duplicates_via_matching
2018
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
2119
from spikeinterface.sortingcomponents.peak_selection import select_peaks
22-
from spikeinterface.core.template import Templates
2320
from spikeinterface.core.sparsity import compute_sparsity
24-
from spikeinterface.sortingcomponents.tools import remove_empty_templates
21+
from spikeinterface.sortingcomponents.tools import remove_empty_templates, _get_optimal_n_jobs
2522
from spikeinterface.sortingcomponents.clustering.peak_svd import extract_peaks_svd
26-
27-
2823
from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel
2924

3025

@@ -62,6 +57,7 @@ class CircusClustering:
6257
"noise_levels": None,
6358
"tmp_folder": None,
6459
"verbose": True,
60+
"memory_limit": 0.25,
6561
"debug": False,
6662
}
6763

@@ -162,13 +158,17 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
162158
if not params["templates_from_svd"]:
163159
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_recording
164160

161+
job_kwargs_local = job_kwargs.copy()
162+
unit_ids = np.unique(peak_labels)
163+
ram_requested = recording.get_num_channels() * (nbefore + nafter) * len(unit_ids) * 4
164+
job_kwargs_local = _get_optimal_n_jobs(job_kwargs_local, ram_requested, params["memory_limit"])
165165
templates = get_templates_from_peaks_and_recording(
166166
recording,
167167
peaks,
168168
peak_labels,
169169
ms_before,
170170
ms_after,
171-
**job_kwargs,
171+
**job_kwargs_local,
172172
)
173173
else:
174174
from spikeinterface.sortingcomponents.clustering.tools import get_templates_from_peaks_and_svd

src/spikeinterface/sortingcomponents/tools.py

Lines changed: 133 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from spikeinterface.core.sparsity import ChannelSparsity
1313
from spikeinterface.core.template import Templates
1414
from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer
15-
from spikeinterface.core.job_tools import split_job_kwargs
15+
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
1616
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
1717
from spikeinterface.core.sparsity import ChannelSparsity
1818
from spikeinterface.core.analyzer_extension_core import ComputeTemplates
@@ -249,19 +249,144 @@ def check_probe_for_drift_correction(recording, dist_x_max=60):
249249
return True
250250

251251

252-
def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache=True, **extra_kwargs):
253-
save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)
252+
def _set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
253+
"""
254+
Set the optimal chunk size for a job given the memory_limit and the number of jobs
254255
255-
if mode == "memory":
256+
Parameters
257+
----------
258+
259+
recording: Recording
260+
The recording object
261+
job_kwargs: dict
262+
The job kwargs
263+
memory_limit: float
264+
The memory limit in fraction of available memory
265+
total_memory: str, Default None
266+
The total memory to use for the job, given as a string (e.g. '1G', '500M')
267+
268+
Returns
269+
-------
270+
271+
job_kwargs: dict
272+
The updated job kwargs
273+
"""
274+
job_kwargs = fix_job_kwargs(job_kwargs)
275+
n_jobs = job_kwargs["n_jobs"]
276+
if total_memory is None:
256277
if HAVE_PSUTIL:
257278
assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
258279
memory_usage = memory_limit * psutil.virtual_memory().available
259-
if recording.get_total_memory_size() < memory_usage:
260-
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
280+
num_channels = recording.get_num_channels()
281+
dtype_size_bytes = recording.get_dtype().itemsize
282+
chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)
283+
chunk_duration = chunk_size / recording.get_sampling_frequency()
284+
job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s"))
285+
job_kwargs = fix_job_kwargs(job_kwargs)
286+
else:
287+
import warnings
288+
289+
warnings.warn("psutil is required to use only a fraction of available memory")
290+
else:
291+
from spikeinterface.core.job_tools import convert_string_to_bytes
292+
293+
total_memory = convert_string_to_bytes(total_memory)
294+
num_channels = recording.get_num_channels()
295+
dtype_size_bytes = recording.get_dtype().itemsize
296+
chunk_size = (num_channels * dtype_size_bytes) * n_jobs / total_memory
297+
chunk_duration = chunk_size / recording.get_sampling_frequency()
298+
job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s"))
299+
job_kwargs = fix_job_kwargs(job_kwargs)
300+
return job_kwargs
301+
302+
303+
def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
304+
"""
305+
Set the optimal number of jobs given the memory_limit and the amount of RAM requested
306+
307+
Parameters
308+
----------
309+
310+
job_kwargs: dict
311+
The job kwargs
312+
ram_requested: int
313+
The amount of RAM (in bytes) requested for the job
314+
memory_limit: float
315+
The memory limit in fraction of available memory
316+
317+
Returns
318+
-------
319+
320+
job_kwargs: dict
321+
The updated job kwargs
322+
"""
323+
job_kwargs = fix_job_kwargs(job_kwargs)
324+
n_jobs = job_kwargs["n_jobs"]
325+
if HAVE_PSUTIL:
326+
assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
327+
memory_usage = memory_limit * psutil.virtual_memory().available
328+
n_jobs = max(1, int(min(n_jobs, memory_usage // ram_requested)))
329+
job_kwargs.update(dict(n_jobs=n_jobs))
330+
else:
331+
import warnings
332+
333+
warnings.warn("psutil is required to use only a fraction of available memory")
334+
return job_kwargs
335+
336+
337+
def cache_preprocessing(
338+
recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs
339+
):
340+
"""
341+
Cache the preprocessing of a recording object
342+
343+
Parameters
344+
----------
345+
346+
recording: Recording
347+
The recording object
348+
mode: str
349+
The mode to cache the preprocessing, can be 'memory', 'folder', 'zarr' or 'no-cache'
350+
memory_limit: float
351+
The memory limit in fraction of available memory
352+
total_memory: str, Default None
353+
The total memory to use for the job, given as a string (e.g. '1G', '500M')
354+
delete_cache: bool
355+
If True, delete the cache after the job
356+
**extra_kwargs: dict
357+
The extra kwargs for the job
358+
359+
Returns
360+
-------
361+
362+
recording: Recording
363+
The cached recording object
364+
"""
365+
366+
save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs)
367+
368+
if mode == "memory":
369+
if total_memory is None:
370+
if HAVE_PSUTIL:
371+
assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
372+
memory_usage = memory_limit * psutil.virtual_memory().available
373+
if recording.get_total_memory_size() < memory_usage:
374+
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
375+
else:
376+
import warnings
377+
378+
warnings.warn("Recording too large to be preloaded in RAM...")
261379
else:
262-
print("Recording too large to be preloaded in RAM...")
380+
import warnings
381+
382+
warnings.warn("psutil is required to preload in memory given only a fraction of available memory")
263383
else:
264-
print("psutil is required to preload in memory")
384+
if recording.get_total_memory_size() < total_memory:
385+
recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs)
386+
else:
387+
import warnings
388+
389+
warnings.warn("Recording too large to be preloaded in RAM...")
265390
elif mode == "folder":
266391
recording = recording.save_to_folder(**extra_kwargs)
267392
elif mode == "zarr":

src/spikeinterface/widgets/crosscorrelograms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
120120

121121
if i < len(self.axes) - 1:
122122
self.axes[i, j].set_xticks([], [])
123-
plt.tight_layout()
123+
self.figure.tight_layout()
124124

125125
for i, unit_id in enumerate(unit_ids):
126126
self.axes[0, i].set_title(str(unit_id))

0 commit comments

Comments
 (0)