Skip to content

Commit e8a9cf4

Browse files
committed
apply_preprocessing_pipeline and BaseRecording
1 parent 7452278 commit e8a9cf4

2 files changed

Lines changed: 21 additions & 17 deletions

File tree

src/spikeinterface/preprocessing/pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _apply(self, recording, apply_precomputed_kwargs=False):
9797
9898
Returns
9999
-------
100-
preprocessed_recording : RecordingExtractor
100+
preprocessed_recording : BaseRecording
101101
Preprocessed recording
102102
103103
"""
@@ -118,7 +118,7 @@ def _apply(self, recording, apply_precomputed_kwargs=False):
118118
return recording
119119

120120

121-
def apply_pipeline(
121+
def apply_preprocessing_pipeline(
122122
recording: BaseRecording, pipeline_or_dict: PreprocessingPipeline | dict, apply_precomputed_kwargs=True
123123
):
124124
"""
@@ -127,7 +127,7 @@ def apply_pipeline(
127127
128128
Parameters
129129
----------
130-
recording : RecordingExtractor
130+
recording : BaseRecording
131131
The initial recording
132132
pipeline_or_dict : PreprocessingPipeline | dict
133133
Dictionary containing preprocessing steps and their kwargs, or a pipeline object.
@@ -139,7 +139,7 @@ def apply_pipeline(
139139
140140
Returns
141141
-------
142-
preprocessed_recording : RecordingExtractor
142+
preprocessed_recording : BaseRecording
143143
Preprocessed recording
144144
145145
Examples
@@ -150,7 +150,7 @@ def apply_pipeline(
150150
>>> from spikeinterface.generation import generate_recording
151151
>>> recording = generate_recording()
152152
>>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
153-
>>> preprocessed_recording = apply_pipeline(recording, preprocessor_dict)
153+
>>> preprocessed_recording = apply_preprocessing_pipeline(recording, preprocessor_dict)
154154
"""
155155

156156
if isinstance(pipeline_or_dict, PreprocessingPipeline):
@@ -218,7 +218,7 @@ def get_preprocessing_dict_from_analyzer(analyzer_folder, format="auto", backend
218218

219219
def get_preprocessing_dict_from_file(recording_dictionary_path):
220220
"""
221-
Generates a preprocessing dict, passable to `apply_pipeline` function and
221+
Generates a preprocessing dict, passable to `apply_preprocessing_pipeline` function and
222222
`PreprocessPipeline` class, from a recording dictionary.
223223
224224
Only extracts preprocessing steps which can be applied "globally" to any recording.

src/spikeinterface/preprocessing/tests/test_pipeline.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from spikeinterface.generation import generate_recording, generate_ground_truth_recording
22
from spikeinterface.preprocessing import (
3-
apply_pipeline,
3+
apply_preprocessing_pipeline,
44
preprocessor_dict,
55
bandpass_filter,
66
common_reference,
@@ -83,7 +83,7 @@ def test_pipeline_equiv_to_step():
8383
else:
8484
pp_rec_from_class = pp_class(rec)
8585

86-
pp_rec_from_pipeline = apply_pipeline(rec, pp_dict)
86+
pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pp_dict)
8787

8888
if isinstance(pp_rec_from_pipeline, dict):
8989
check_recordings_equal(pp_rec_from_pipeline[0], pp_rec_from_class[0])
@@ -106,7 +106,7 @@ def test_three_preprocessing_steps():
106106
"whiten": {"seed": 1205},
107107
}
108108

109-
pp_rec_from_pipeline = apply_pipeline(rec, pipeline_dict)
109+
pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pipeline_dict)
110110
pp_rec_from_functions = whiten(bandpass_filter(common_reference(rec)), seed=1205)
111111

112112
check_recordings_equal(pp_rec_from_pipeline, pp_rec_from_functions)
@@ -115,7 +115,7 @@ def test_three_preprocessing_steps():
115115
rec_groups.set_property(key="group", values=[0, 1])
116116
dict_of_recs = rec_groups.split_by("group")
117117

118-
pp_dict_of_recs_from_pipeline = apply_pipeline(dict_of_recs, pipeline_dict)
118+
pp_dict_of_recs_from_pipeline = apply_preprocessing_pipeline(dict_of_recs, pipeline_dict)
119119
pp_dict_of_recs_from_functions = whiten(bandpass_filter(common_reference(dict_of_recs)), seed=1205)
120120

121121
check_recordings_equal(pp_dict_of_recs_from_pipeline[0], pp_dict_of_recs_from_functions[0])
@@ -131,14 +131,14 @@ def test_kwargs_are_propagated():
131131
rec = generate_recording(durations=[1])
132132
pipeline_dict = {"bandpass_filter": {}}
133133

134-
bp_rec_default = apply_pipeline(rec, pipeline_dict)
134+
bp_rec_default = apply_preprocessing_pipeline(rec, pipeline_dict)
135135

136136
kwargs = bp_rec_default._kwargs
137137
assert kwargs["freq_min"] == 300.0
138138

139139
pipeline_dict_non_default = {"bandpass_filter": {"freq_min": 500.0}}
140140

141-
bp_rec_non_default = apply_pipeline(rec, pipeline_dict_non_default)
141+
bp_rec_non_default = apply_preprocessing_pipeline(rec, pipeline_dict_non_default)
142142
non_default_kwargs = bp_rec_non_default._kwargs
143143

144144
assert non_default_kwargs["freq_min"] == 500.0
@@ -162,8 +162,12 @@ def test_loading_provenance(create_cache_folder):
162162

163163
loaded_pp_dict = get_preprocessing_dict_from_file(cache_folder / "provenance.pkl")
164164

165-
pipeline_rec_applying_precomputed_kwargs = apply_pipeline(rec, loaded_pp_dict, apply_precomputed_kwargs=True)
166-
pipeline_rec_ignoring_precomputed_kwargs = apply_pipeline(rec, loaded_pp_dict, apply_precomputed_kwargs=False)
165+
pipeline_rec_applying_precomputed_kwargs = apply_preprocessing_pipeline(
166+
rec, loaded_pp_dict, apply_precomputed_kwargs=True
167+
)
168+
pipeline_rec_ignoring_precomputed_kwargs = apply_preprocessing_pipeline(
169+
rec, loaded_pp_dict, apply_precomputed_kwargs=False
170+
)
167171

168172
check_recordings_equal(pipeline_rec_applying_precomputed_kwargs, pp_rec)
169173
check_recordings_equal(pipeline_rec_ignoring_precomputed_kwargs, pp_rec)
@@ -182,18 +186,18 @@ def test_loading_from_analyzer(create_cache_folder):
182186
recording, sorting = generate_ground_truth_recording()
183187

184188
preprocessing_dict = {"common_reference": {}, "highpass_filter": {"freq_min": 301.0}}
185-
pp_recording = apply_pipeline(recording, preprocessing_dict)
189+
pp_recording = apply_preprocessing_pipeline(recording, preprocessing_dict)
186190

187191
analyzer_binary_folder = cache_folder / "binary_format"
188192
_ = create_sorting_analyzer(
189193
sorting=sorting, recording=pp_recording, format="binary_folder", folder=analyzer_binary_folder
190194
)
191195
pp_dict_from_binary = get_preprocessing_dict_from_analyzer(analyzer_binary_folder)
192-
pp_recording_from_binary = apply_pipeline(recording, pp_dict_from_binary)
196+
pp_recording_from_binary = apply_preprocessing_pipeline(recording, pp_dict_from_binary)
193197
check_recordings_equal(pp_recording, pp_recording_from_binary)
194198

195199
analyzer_zarr_folder = cache_folder / "zarr_format.zarr"
196200
_ = create_sorting_analyzer(sorting=sorting, recording=pp_recording, format="zarr", folder=analyzer_zarr_folder)
197201
pp_dict_from_zarr = get_preprocessing_dict_from_analyzer(analyzer_zarr_folder)
198-
pp_recording_from_zarr = apply_pipeline(recording, pp_dict_from_zarr)
202+
pp_recording_from_zarr = apply_preprocessing_pipeline(recording, pp_dict_from_zarr)
199203
check_recordings_equal(pp_recording, pp_recording_from_zarr)

0 commit comments

Comments
 (0)