|
12 | 12 | from spikeinterface.core.sparsity import ChannelSparsity |
13 | 13 | from spikeinterface.core.template import Templates |
14 | 14 | from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer |
15 | | -from spikeinterface.core.job_tools import split_job_kwargs |
| 15 | +from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs |
16 | 16 | from spikeinterface.core.sortinganalyzer import create_sorting_analyzer |
17 | 17 | from spikeinterface.core.sparsity import ChannelSparsity |
18 | 18 | from spikeinterface.core.analyzer_extension_core import ComputeTemplates |
@@ -249,19 +249,144 @@ def check_probe_for_drift_correction(recording, dist_x_max=60): |
249 | 249 | return True |
250 | 250 |
|
251 | 251 |
|
def _set_optimal_chunk_size(recording, job_kwargs, memory_limit=0.5, total_memory=None):
    """
    Set the optimal chunk size for a job given the memory budget and the number of jobs.

    Parameters
    ----------
    recording : Recording
        The recording object
    job_kwargs : dict
        The job kwargs
    memory_limit : float
        The memory limit as a fraction of available memory (used when
        ``total_memory`` is None; requires psutil)
    total_memory : str or int, default: None
        The total memory to use for the job, in bytes (or a string such as
        "1G" accepted by ``convert_string_to_bytes``)

    Returns
    -------
    job_kwargs : dict
        The updated job kwargs, with ``chunk_duration`` sized so that
        ``n_jobs`` simultaneous chunks fit in the memory budget
    """
    job_kwargs = fix_job_kwargs(job_kwargs)
    n_jobs = job_kwargs["n_jobs"]

    # Resolve the memory budget in bytes.
    if total_memory is None:
        if not HAVE_PSUTIL:
            import warnings

            warnings.warn("psutil is required to use only a fraction of available memory")
            return job_kwargs
        assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
        memory_usage = memory_limit * psutil.virtual_memory().available
    else:
        from spikeinterface.core.job_tools import convert_string_to_bytes

        memory_usage = convert_string_to_bytes(total_memory)

    # One sample costs (num_channels * dtype_size_bytes) bytes, and n_jobs chunks
    # are held at once, so the per-job chunk size is the budget divided by both.
    # NOTE: this fixes the previous total_memory branch, which computed the
    # reciprocal ((num_channels * dtype_size_bytes) * n_jobs / total_memory).
    num_channels = recording.get_num_channels()
    dtype_size_bytes = recording.get_dtype().itemsize
    chunk_size = memory_usage / ((num_channels * dtype_size_bytes) * n_jobs)
    chunk_duration = chunk_size / recording.get_sampling_frequency()
    job_kwargs.update(dict(chunk_duration=f"{chunk_duration}s"))
    job_kwargs = fix_job_kwargs(job_kwargs)
    return job_kwargs
def _get_optimal_n_jobs(job_kwargs, ram_requested, memory_limit=0.25):
    """
    Cap the number of jobs so that the total RAM requested fits in the memory budget.

    Parameters
    ----------
    job_kwargs : dict
        The job kwargs
    ram_requested : int
        The amount of RAM (in bytes) requested per job
    memory_limit : float
        The memory limit as a fraction of available memory (requires psutil)

    Returns
    -------
    job_kwargs : dict
        The updated job kwargs, with ``n_jobs`` reduced (never below 1) so that
        ``n_jobs * ram_requested`` stays within the memory budget
    """
    job_kwargs = fix_job_kwargs(job_kwargs)
    n_jobs = job_kwargs["n_jobs"]
    if HAVE_PSUTIL:
        assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1["
        memory_usage = memory_limit * psutil.virtual_memory().available
        # How many jobs of size ram_requested fit in the budget; keep at least 1
        # and never increase beyond the requested n_jobs.
        n_jobs = max(1, int(min(n_jobs, memory_usage // ram_requested)))
        job_kwargs.update(dict(n_jobs=n_jobs))
    else:
        import warnings

        warnings.warn("psutil is required to use only a fraction of available memory")
    return job_kwargs
| 336 | + |
| 337 | +def cache_preprocessing( |
| 338 | + recording, mode="memory", memory_limit=0.5, total_memory=None, delete_cache=True, **extra_kwargs |
| 339 | +): |
| 340 | + """ |
| 341 | + Cache the preprocessing of a recording object |
| 342 | +
|
| 343 | + Parameters |
| 344 | + ---------- |
| 345 | +
|
| 346 | + recording: Recording |
| 347 | + The recording object |
| 348 | + mode: str |
| 349 | + The mode to cache the preprocessing, can be 'memory', 'folder', 'zarr' or 'no-cache' |
| 350 | + memory_limit: float |
| 351 | + The memory limit in fraction of available memory |
| 352 | + total_memory: str, Default None |
| 353 | + The total memory to use for the job in bytes |
| 354 | + delete_cache: bool |
| 355 | + If True, delete the cache after the job |
| 356 | + **extra_kwargs: dict |
| 357 | + The extra kwargs for the job |
| 358 | +
|
| 359 | + Returns |
| 360 | + ------- |
| 361 | +
|
| 362 | + recording: Recording |
| 363 | + The cached recording object |
| 364 | + """ |
| 365 | + |
| 366 | + save_kwargs, job_kwargs = split_job_kwargs(extra_kwargs) |
| 367 | + |
| 368 | + if mode == "memory": |
| 369 | + if total_memory is None: |
| 370 | + if HAVE_PSUTIL: |
| 371 | + assert 0 < memory_limit < 1, "memory_limit should be in ]0, 1[" |
| 372 | + memory_usage = memory_limit * psutil.virtual_memory().available |
| 373 | + if recording.get_total_memory_size() < memory_usage: |
| 374 | + recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) |
| 375 | + else: |
| 376 | + import warnings |
| 377 | + |
| 378 | + warnings.warn("Recording too large to be preloaded in RAM...") |
261 | 379 | else: |
262 | | - print("Recording too large to be preloaded in RAM...") |
| 380 | + import warnings |
| 381 | + |
| 382 | + warnings.warn("psutil is required to preload in memory given only a fraction of available memory") |
263 | 383 | else: |
264 | | - print("psutil is required to preload in memory") |
| 384 | + if recording.get_total_memory_size() < total_memory: |
| 385 | + recording = recording.save_to_memory(format="memory", shared=True, **job_kwargs) |
| 386 | + else: |
| 387 | + import warnings |
| 388 | + |
| 389 | + warnings.warn("Recording too large to be preloaded in RAM...") |
265 | 390 | elif mode == "folder": |
266 | 391 | recording = recording.save_to_folder(**extra_kwargs) |
267 | 392 | elif mode == "zarr": |
|
0 commit comments