Skip to content

Commit d8e9460

Browse files
samuelgarciaJoeZiminskipre-commit-ci[bot]alejoe91
authored
Refactor artifacts detection + add saturation detection (#4297)
Co-authored-by: JoeZiminski <joseph.j.ziminski@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 7eeb0b1 commit d8e9460

File tree

12 files changed

+1016
-288
lines changed

12 files changed

+1016
-288
lines changed

conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@ def create_cache_folder(tmp_path_factory):
88
return cache_folder
99

1010

11+
@pytest.fixture(scope="module")
12+
def debug_plots(request):
13+
"""Return True if debug plots should be shown."""
14+
return request.config.getoption("--debug-plots")
15+
16+
17+
def pytest_addoption(parser):
18+
parser.addoption(
19+
"--debug-plots",
20+
action="store_true",
21+
default=False,
22+
help="Enable debug plots during tests",
23+
)
24+
25+
1126
def pytest_collection_modifyitems(config, items):
1227
"""
1328
This function marks (in the pytest sense) the tests according to their name and file_path location

doc/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ spikeinterface.preprocessing
213213
.. autofunction:: detect_bad_channels
214214
.. autofunction:: detect_and_interpolate_bad_channels
215215
.. autofunction:: detect_and_remove_bad_channels
216+
.. autofunction:: detect_artifact_periods
217+
.. autofunction:: detect_artifact_periods_by_envelope
218+
.. autofunction:: detect_saturation_periods
216219
.. autofunction:: directional_derivative
217220
.. autofunction:: filter
218221
.. autofunction:: gaussian_filter

src/spikeinterface/core/node_pipeline.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -495,16 +495,17 @@ def find_parents_of_type(list_of_parents, parent_type):
495495
return parents
496496

497497

498-
def check_graph(nodes):
498+
def check_graph(nodes, check_for_peak_source=True):
499499
"""
500500
Check that node list is orderd in a good (parents are before children)
501501
"""
502502

503-
node0 = nodes[0]
504-
if not isinstance(node0, PeakSource):
505-
raise ValueError(
506-
"Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever"
507-
)
503+
if check_for_peak_source:
504+
node0 = nodes[0]
505+
if not isinstance(node0, PeakSource):
506+
raise ValueError(
507+
"Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever"
508+
)
508509

509510
for i, node in enumerate(nodes):
510511
assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode"
@@ -532,6 +533,7 @@ def run_node_pipeline(
532533
verbose=False,
533534
skip_after_n_peaks=None,
534535
recording_slices=None,
536+
check_for_peak_source=True,
535537
):
536538
"""
537539
Machinery to compute in parallel operations on peaks and traces.
@@ -587,6 +589,8 @@ def run_node_pipeline(
587589
Optionaly give a list of slices to run the pipeline only on some chunks of the recording.
588590
It must be a list of (segment_index, frame_start, frame_stop).
589591
If None (default), the function iterates over the entire duration of the recording.
592+
check_for_peak_source : bool, default True
593+
Whether to check that the first node is a PeakSource (PeakDetector or PeakRetriever or
590594
591595
Returns
592596
-------
@@ -595,7 +599,7 @@ def run_node_pipeline(
595599
If squeeze_output=True and only one output then directly np.array.
596600
"""
597601

598-
check_graph(nodes)
602+
check_graph(nodes, check_for_peak_source=check_for_peak_source)
599603

600604
job_kwargs = fix_job_kwargs(job_kwargs)
601605
assert all(isinstance(node, PipelineNode) for node in nodes)

src/spikeinterface/preprocessing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
PreprocessingPipeline,
2121
)
2222

23+
from .detect_artifacts import detect_artifact_periods, detect_artifact_periods_by_envelope, detect_saturation_periods
24+
2325
# for snippets
2426
from .align_snippets import AlignSnippets
2527
from warnings import warn

0 commit comments

Comments
 (0)