Skip to content

Commit 95de48f

Browse files
committed
refactor
1 parent 26eac1f commit 95de48f

3 files changed

Lines changed: 112 additions & 120 deletions

File tree

src/spikeinterface/preprocessing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .detect_bad_channels import detect_bad_channels
1414
from .correct_lsb import correct_lsb
1515

16-
from .pipeline import create_preprocessed, PreprocessingPipeline, get_preprocessing_dict_from_json
16+
from .pipeline import apply_pipeline, PreprocessingPipeline
1717

1818
# for snippets
1919
from .align_snippets import AlignSnippets

src/spikeinterface/preprocessing/pipeline.py

Lines changed: 74 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
import re
21
import json
32
import inspect
4-
from copy import deepcopy
53
from spikeinterface.core.core_tools import is_dict_extractor
6-
from spikeinterface.preprocessing.preprocessinglist import pp_function_to_class, preprocesser_dict
4+
from spikeinterface.core import BaseRecording
5+
from spikeinterface.preprocessing.preprocessinglist import preprocessor_dict, _all_preprocesser_dict
6+
7+
pp_names_to_functions = {preprocessor.__name__: preprocessor for preprocessor in preprocessor_dict.values()}
8+
pp_names_to_classes = {pp_function.__name__: pp_class for pp_class, pp_function in _all_preprocesser_dict.items()}
79

810

911
class PreprocessingPipeline:
@@ -24,15 +26,18 @@ class PreprocessingPipeline:
2426
>>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
2527
>>> my_pipeline = PreprocessingPipeline(preprocessor_dict)
2628
PreprocessingPipeline: Raw Recording → bandpass_filter → common_reference → Preprocessed Recording
27-
>>> my_pipeline.apply_to(recording)
29+
>>> my_pipeline.apply(recording)
2830
2931
"""
3032

3133
def __init__(self, preprocessor_dict):
3234
for preprocessor in preprocessor_dict:
33-
assert _is_genuine_preprocessor(
34-
preprocessor
35-
), f"'{preprocessor}' is not a preprocessing step in spikeinterface. To see the full list run:\n\t>>> from spikeinterface.preprocessing import pp_function_to_class\n\t>>> print(pp_function_to_class.keys())"
35+
if preprocessor not in pp_names_to_functions.keys():
36+
raise TypeError(
37+
f"'{preprocessor}' is not supported by the `PreprocessingPipeline`. \
38+
To see the list of supported steps, run:\n\t>>> from spikeinterface.preprocessing \
39+
import preprocessor_dict\n\t>>> print(preprocessor_dict.keys())"
40+
)
3641

3742
self.preprocessor_dict = preprocessor_dict
3843

@@ -47,47 +52,40 @@ def _repr_html_(self):
4752

4853
all_kwargs = _get_all_kwargs_and_values(self)
4954

50-
num_titles = len(all_kwargs) + 3
51-
colors = [
52-
[255 - (a * 15) // num_titles, 255 - (a * 82) // num_titles, 255 - (a * 30) // num_titles]
53-
for a in range(0, num_titles + 1)
54-
]
55-
5655
html_text = "<div'>"
57-
58-
html_text += "<strong>PreprocessingPipeline:</strong>"
59-
60-
html_text += f"<div style='border:1px solid #ddd; padding:10px;'><strong>Initial Recording</strong></div>"
61-
62-
html_text += "<div style='margin: auto'>&#x2193;</div>"
56+
html_text += "<strong>PreprocessingPipeline</strong>"
57+
html_text += "<div style='border:1px solid #ccc; padding:10px;'><strong>Initial Recording</strong></div>"
58+
html_text += "<div style='margin: auto; text-indent: 30px;'>&#x2193;</div>"
6359

6460
for a, (preprocessor, kwargs) in enumerate(all_kwargs.items()):
65-
html_text += "<details>"
66-
html_text += (
67-
f"<summary style='border:1px solid #eee; padding:5px;'><strong>{preprocessor}</strong></summary>"
68-
)
61+
html_text += "<details style='border:1px solid #ddd; padding:5px;'>"
62+
html_text += f"<summary><strong>{preprocessor}</strong></summary>"
63+
6964
html_text += "<ul>"
7065
for kwarg, value in kwargs.items():
7166
html_text += f"<li><strong>{kwarg}</strong>: {value}</li>"
7267
html_text += "</ul>"
7368
html_text += "</details>"
74-
html_text += """<div">&#x2193;</div>"""
75-
76-
html_text += f"<div style='border:1px solid #ddd; padding:10px;'><strong>Preprocessed Recording</strong></div>"
7769

70+
html_text += """<div style='margin: auto; text-indent: 30px;'>&#x2193;</div>"""
71+
html_text += "<div style='border:1px solid #ccc; padding:10px;'><strong>Preprocessed Recording</strong></div>"
7872
html_text += "</div>"
7973

8074
return html_text
8175

82-
def apply_to(self, recording):
76+
def apply(self, recording, ignore_precomputed_kwargs=True):
8377
"""
84-
Creates a preprocessed recording by applying the PreprocessingPipeline to
78+
Creates a preprocessed recording by applying the `PreprocessingPipeline` to
8579
`recording`.
8680
8781
Parameters
8882
----------
8983
recording : RecordingExtractor
9084
The initial recording
85+
ignore_precomputed_kwargs : Bool
86+
Some preprocessing steps (e.g. Whitening) contain arguments which are computed
87+
during preprocessing. If True, we ignore these precomputed steps. If False, we
88+
compute when we apply the preprocessors.
9189
9290
Returns
9391
-------
@@ -96,28 +94,25 @@ def apply_to(self, recording):
9694
9795
"""
9896

99-
preprocessor_dict = self.preprocessor_dict
97+
for preprocessor_name, kwargs in self.preprocessor_dict.items():
10098

101-
for preprocessor, kwargs in preprocessor_dict.items():
99+
dont_include_kwargs = ["recording", "parent_recording"]
102100

103-
kwargs.pop("recording", kwargs)
104-
kwargs.pop("parent_recording", kwargs)
105-
106-
using_class_name = bool(re.search("Recording", preprocessor))
107-
if using_class_name is True:
108-
pp_output = preprocesser_dict[preprocessor.split(".")[-1]](recording, **kwargs)
109-
else:
110-
pp_output = pp_function_to_class[preprocessor.split(".")[-1]](recording, **kwargs)
111-
112-
if preprocessor == "motion_correct":
113-
pp_output = pp_output[0]
101+
if ignore_precomputed_kwargs:
102+
preprocessor_class = pp_names_to_classes[preprocessor_name]
103+
precomputable_kwarg_names = preprocessor_class._precomputable_kwarg_names
104+
dont_include_kwargs += precomputable_kwarg_names
114105

106+
non_rec_kwargs = {key: value for key, value in kwargs.items() if key not in dont_include_kwargs}
107+
pp_output = pp_names_to_functions[preprocessor_name](recording, **non_rec_kwargs)
115108
recording = pp_output
116109

117110
return recording
118111

119112

120-
def create_preprocessed(recording, preprocessor_dict=None):
113+
def apply_pipeline(
114+
recording: BaseRecording, pipeline_or_dict: dict | PreprocessingPipeline = {}, ignore_precomputed_kwargs=True
115+
):
121116
"""
122117
Creates a preprocessed recording by applying the preprocessing steps in
123118
`preprocessor_dict` to `recording`.
@@ -126,8 +121,13 @@ def create_preprocessed(recording, preprocessor_dict=None):
126121
----------
127122
recording : RecordingExtractor
128123
The initial recording
129-
preprocessor_dict : dict
130-
Dictionary containing preprocessing steps and their kwargs
124+
preprocessor_dict : dict | PreprocessingPipeline = {}
125+
Dictionary containing preprocessing steps and their kwargs, or a pipeline object.
126+
If None, the original recording is returned.
127+
ignore_precomputed_kwargs : Bool
128+
Some preprocessing steps (e.g. Whitening) contain arguments which are computed
129+
during preprocessing. If True, we ignore these precomputed steps. If False, we
130+
compute when we apply the preprocessors.
131131
132132
Returns
133133
-------
@@ -142,13 +142,15 @@ def create_preprocessed(recording, preprocessor_dict=None):
142142
>>> from spikeinterface.generation import generate_recording
143143
>>> recording = generate_recording()
144144
>>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
145-
>>> preprocessed_recording = create_preprocessed(recording, preprocessor_dict)
146-
147-
145+
>>> preprocessed_recording = apply_pipeline(recording, preprocessor_dict)
148146
"""
149147

150-
pipeline = PreprocessingPipeline(preprocessor_dict)
151-
preprocessed_recording = pipeline.apply_to(recording)
148+
if isinstance(pipeline_or_dict, PreprocessingPipeline):
149+
pipeline = pipeline_or_dict
150+
else:
151+
pipeline = PreprocessingPipeline(pipeline_or_dict)
152+
153+
preprocessed_recording = pipeline.apply(recording, ignore_precomputed_kwargs)
152154
return preprocessed_recording
153155

154156

@@ -177,49 +179,34 @@ def get_preprocessing_dict_from_json(recording_json_path):
177179
"""
178180
recording_json = json.load(open(recording_json_path))
179181

180-
initial_preprocessor_dict = {}
181-
_load_pp_from_dict(recording_json, initial_preprocessor_dict)
182-
183-
preprocessor_dict = deepcopy(initial_preprocessor_dict)
184-
for preprocessor in initial_preprocessor_dict:
185-
preprocessor_name = preprocessor.split(".")[-1]
186-
187-
if not _is_genuine_preprocessor(preprocessor_name):
188-
preprocessor_dict.pop(preprocessor, preprocessor_dict)
189-
continue
190-
191-
# remove recording details
192-
preprocessor_dict[preprocessor].pop("recording", preprocessor_dict[preprocessor])
193-
preprocessor_dict[preprocessor].pop("parent_recording", preprocessor_dict[preprocessor])
194-
195-
# rename keys to be the class names
196-
preprocessor_dict[preprocessor_name] = preprocessor_dict[preprocessor]
197-
preprocessor_dict.pop(preprocessor)
182+
pp_from_json = {}
183+
_load_pp_from_dict(recording_json, pp_from_json)
198184

199-
preprocessor_dict = dict(reversed(preprocessor_dict.items()))
185+
pipeline_dict = {}
186+
for preprocessor in reversed(pp_from_json):
200187

201-
return preprocessor_dict
188+
preprocessor_class_name = preprocessor.split(".")[-1]
202189

190+
preprocessor_function = preprocessor_dict.get(preprocessor_class_name)
191+
if preprocessor_function is None:
192+
continue
203193

204-
def _is_genuine_preprocessor(preprocessor):
205-
"""
206-
Check is string 'preprocessor' is in the list of preprocessors from
207-
`pp_function_to_class`.
208-
"""
194+
pp_kwargs = {
195+
key: value
196+
for key, value in pp_from_json[preprocessor].items()
197+
if key not in ["recording", "parent_recording"]
198+
}
209199

210-
using_class_name = bool(re.search("Recording", preprocessor))
211-
if using_class_name:
212-
genuine_preprocessor = preprocessor in preprocesser_dict.keys()
213-
else:
214-
genuine_preprocessor = preprocessor in pp_function_to_class.keys()
200+
pipeline_dict[preprocessor_function.__name__] = pp_kwargs
215201

216-
return genuine_preprocessor
202+
return pipeline_dict
217203

218204

219205
def _load_pp_from_dict(prov_dict, kwargs_dict):
220206
"""
221207
Recursive function used to iterate through recording provenance dictionary, and
222-
extract preprocessing steps and their kwargs.
208+
extract preprocessing steps and their kwargs. Based on `_load_extractor_from_dict`
209+
from spikeinterface.core.base.
223210
"""
224211
new_kwargs = dict()
225212
transform_dict_to_extractor = lambda x: _load_pp_from_dict(x) if is_dict_extractor(x) else x
@@ -238,18 +225,21 @@ def _load_pp_from_dict(prov_dict, kwargs_dict):
238225

239226

240227
def _get_all_kwargs_and_values(my_pipeline):
228+
"""
229+
Get all keyword arguments and their values from a pipeline,
230+
including the default values.
231+
"""
241232

242233
all_kwargs = {}
243234
for preprocessor in my_pipeline.preprocessor_dict:
244235

245236
preprocessor_name = preprocessor.split(".")[-1]
246-
# preprocessor_name = preprocessor.split(".")[-1]
247-
pp_function = pp_function_to_class[preprocessor.split(".")[-1]]
237+
pp_function = pp_names_to_functions[preprocessor.split(".")[-1]]
248238
signature = inspect.signature(pp_function)
249239

250240
all_kwargs[preprocessor_name] = {}
251241

252-
for parameter, value in signature.parameters.items():
242+
for _, value in signature.parameters.items():
253243
par_name = str(value).split("=")[0].split(":")[0]
254244
if par_name != "recording":
255245
try:

src/spikeinterface/preprocessing/preprocessinglist.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,46 +45,48 @@
4545
from .astype import AstypeRecording, astype
4646
from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed
4747

48-
from .motion import correct_motion
49-
50-
pp_function_to_class = {
48+
_all_preprocesser_dict = {
5149
# filter stuff
52-
"filter": FilterRecording,
53-
"bandpass_filter": BandpassFilterRecording,
54-
"notch_filter": NotchFilterRecording,
55-
"highpass_filter": HighpassFilterRecording,
56-
"gaussian_filter": GaussianFilterRecording,
50+
FilterRecording: filter,
51+
BandpassFilterRecording: bandpass_filter,
52+
HighpassFilterRecording: highpass_filter,
53+
NotchFilterRecording: notch_filter,
54+
GaussianFilterRecording: gaussian_filter,
5755
# gain offset stuff
58-
"normalize_by_quantile": NormalizeByQuantileRecording,
59-
"scale": ScaleRecording,
60-
"zscore": ZScoreRecording,
61-
"center": CenterRecording,
56+
NormalizeByQuantileRecording: normalize_by_quantile,
57+
ScaleRecording: scale,
58+
CenterRecording: center,
59+
ZScoreRecording: zscore,
6260
# decorrelation stuff
63-
"whiten": WhitenRecording,
61+
WhitenRecording: whiten,
6462
# re-reference
65-
"common_reference": CommonReferenceRecording,
66-
"phase_shift": PhaseShiftRecording,
63+
CommonReferenceRecording: common_reference,
64+
PhaseShiftRecording: phase_shift,
6765
# misc
68-
"rectify": RectifyRecording,
69-
"clip": ClipRecording,
70-
"blank_staturation": BlankSaturationRecording,
71-
"silence_periods": SilencedPeriodsRecording,
72-
"remove_artifacts": RemoveArtifactsRecording,
73-
"zero_channel_pad": ZeroChannelPaddedRecording,
74-
"deepinterpolate": DeepInterpolatedRecording,
75-
"resample": ResampleRecording,
76-
"decimate": DecimateRecording,
77-
"highpass_spatial_filter": HighpassSpatialFilterRecording,
78-
"interpolate_bad_channels": InterpolateBadChannelsRecording,
79-
"depth_order": DepthOrderRecording,
80-
"average_across_direction": AverageAcrossDirectionRecording,
81-
"directional_derivative": DirectionalDerivativeRecording,
82-
"astype": AstypeRecording,
83-
"unsigned_to_signed": UnsignedToSignedRecording,
84-
"correct_motion": correct_motion,
66+
RectifyRecording: rectify,
67+
ClipRecording: clip,
68+
BlankSaturationRecording: blank_saturation,
69+
SilencedPeriodsRecording: silence_periods,
70+
RemoveArtifactsRecording: remove_artifacts,
71+
ZeroChannelPaddedRecording: zero_channel_pad,
72+
DeepInterpolatedRecording: deepinterpolate,
73+
ResampleRecording: resample,
74+
DecimateRecording: decimate,
75+
HighpassSpatialFilterRecording: highpass_spatial_filter,
76+
InterpolateBadChannelsRecording: interpolate_bad_channels,
77+
DepthOrderRecording: depth_order,
78+
AverageAcrossDirectionRecording: average_across_direction,
79+
DirectionalDerivativeRecording: directional_derivative,
80+
AstypeRecording: astype,
81+
UnsignedToSignedRecording: unsigned_to_signed,
8582
}
83+
# we control import in the preprocessing init by setting an __all__
8684

85+
# pp_function.__name__ gives the name of the function that users should use
86+
__all__ = [pp_function.__name__ for pp_function in _all_preprocesser_dict.values()]
87+
__all__.extend(
88+
[scale_to_uV.__name__, compute_whitening_matrix.__name__, train_deepinterpolation.__name__, causal_filter.__name__]
89+
)
8790

88-
preprocessers_full_list = pp_function_to_class.values()
89-
90-
preprocesser_dict = {pp_class.__name__: pp_class for pp_class in preprocessers_full_list}
91+
preprocessor_dict = {pp_class.__name__: pp_function for pp_class, pp_function in _all_preprocesser_dict.items()}
92+
__all__.append("preprocessor_dict")

0 commit comments

Comments
 (0)