1- import re
21import json
32import inspect
4- from copy import deepcopy
53from spikeinterface .core .core_tools import is_dict_extractor
6- from spikeinterface .preprocessing .preprocessinglist import pp_function_to_class , preprocesser_dict
4+ from spikeinterface .core import BaseRecording
5+ from spikeinterface .preprocessing .preprocessinglist import preprocessor_dict , _all_preprocesser_dict
6+
7+ pp_names_to_functions = {preprocessor .__name__ : preprocessor for preprocessor in preprocessor_dict .values ()}
8+ pp_names_to_classes = {pp_function .__name__ : pp_class for pp_class , pp_function in _all_preprocesser_dict .items ()}
79
810
911class PreprocessingPipeline :
@@ -24,15 +26,18 @@ class PreprocessingPipeline:
2426 >>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
2527 >>> my_pipeline = PreprocessingPipeline(preprocessor_dict)
2628 PreprocessingPipeline: Raw Recording → bandpass_filter → common_reference → Preprocessed Recording
27- >>> my_pipeline.apply_to (recording)
29+ >>> my_pipeline.apply (recording)
2830
2931 """
3032
3133 def __init__ (self , preprocessor_dict ):
3234 for preprocessor in preprocessor_dict :
33- assert _is_genuine_preprocessor (
34- preprocessor
35- ), f"'{ preprocessor } ' is not a preprocessing step in spikeinterface. To see the full list run:\n \t >>> from spikeinterface.preprocessing import pp_function_to_class\n \t >>> print(pp_function_to_class.keys())"
35+ if preprocessor not in pp_names_to_functions .keys ():
36+ raise TypeError (
37+ f"'{ preprocessor } ' is not supported by the `PreprocessingPipeline`. \
38+ To see the list of supported steps, run:\n \t >>> from spikeinterface.preprocessing \
39+ import preprocessor_dict\n \t >>> print(preprocessor_dict.keys())"
40+ )
3641
3742 self .preprocessor_dict = preprocessor_dict
3843
@@ -47,47 +52,40 @@ def _repr_html_(self):
4752
4853 all_kwargs = _get_all_kwargs_and_values (self )
4954
50- num_titles = len (all_kwargs ) + 3
51- colors = [
52- [255 - (a * 15 ) // num_titles , 255 - (a * 82 ) // num_titles , 255 - (a * 30 ) // num_titles ]
53- for a in range (0 , num_titles + 1 )
54- ]
55-
5655 html_text = "<div'>"
57-
58- html_text += "<strong>PreprocessingPipeline:</strong>"
59-
60- html_text += f"<div style='border:1px solid #ddd; padding:10px;'><strong>Initial Recording</strong></div>"
61-
62- html_text += "<div style='margin: auto'>↓</div>"
56+ html_text += "<strong>PreprocessingPipeline</strong>"
57+ html_text += "<div style='border:1px solid #ccc; padding:10px;'><strong>Initial Recording</strong></div>"
58+ html_text += "<div style='margin: auto; text-indent: 30px;'>↓</div>"
6359
6460 for a , (preprocessor , kwargs ) in enumerate (all_kwargs .items ()):
65- html_text += "<details>"
66- html_text += (
67- f"<summary style='border:1px solid #eee; padding:5px;'><strong>{ preprocessor } </strong></summary>"
68- )
61+ html_text += "<details style='border:1px solid #ddd; padding:5px;'>"
62+ html_text += f"<summary><strong>{ preprocessor } </strong></summary>"
63+
6964 html_text += "<ul>"
7065 for kwarg , value in kwargs .items ():
7166 html_text += f"<li><strong>{ kwarg } </strong>: { value } </li>"
7267 html_text += "</ul>"
7368 html_text += "</details>"
74- html_text += """<div">↓</div>"""
75-
76- html_text += f"<div style='border:1px solid #ddd; padding:10px;'><strong>Preprocessed Recording</strong></div>"
7769
70+ html_text += """<div style='margin: auto; text-indent: 30px;'>↓</div>"""
71+ html_text += "<div style='border:1px solid #ccc; padding:10px;'><strong>Preprocessed Recording</strong></div>"
7872 html_text += "</div>"
7973
8074 return html_text
8175
82- def apply_to (self , recording ):
76+ def apply (self , recording , ignore_precomputed_kwargs = True ):
8377 """
84- Creates a preprocessed recording by applying the PreprocessingPipeline to
78+ Creates a preprocessed recording by applying the ` PreprocessingPipeline` to
8579 `recording`.
8680
8781 Parameters
8882 ----------
8983 recording : RecordingExtractor
9084 The initial recording
85+ ignore_precomputed_kwargs : bool, default: True
86+ Some preprocessing steps (e.g. Whitening) contain arguments which are computed
87+ during preprocessing. If True, these precomputed kwargs are dropped before each
88+ preprocessor is called. If False, they are passed to the preprocessor as given.
9189
9290 Returns
9391 -------
@@ -96,28 +94,25 @@ def apply_to(self, recording):
9694
9795 """
9896
99- preprocessor_dict = self .preprocessor_dict
97+ for preprocessor_name , kwargs in self .preprocessor_dict . items ():
10098
101- for preprocessor , kwargs in preprocessor_dict . items ():
99+ dont_include_kwargs = [ "recording" , "parent_recording" ]
102100
103- kwargs .pop ("recording" , kwargs )
104- kwargs .pop ("parent_recording" , kwargs )
105-
106- using_class_name = bool (re .search ("Recording" , preprocessor ))
107- if using_class_name is True :
108- pp_output = preprocesser_dict [preprocessor .split ("." )[- 1 ]](recording , ** kwargs )
109- else :
110- pp_output = pp_function_to_class [preprocessor .split ("." )[- 1 ]](recording , ** kwargs )
111-
112- if preprocessor == "motion_correct" :
113- pp_output = pp_output [0 ]
101+ if ignore_precomputed_kwargs :
102+ preprocessor_class = pp_names_to_classes [preprocessor_name ]
103+ precomputable_kwarg_names = preprocessor_class ._precomputable_kwarg_names
104+ dont_include_kwargs += precomputable_kwarg_names
114105
106+ non_rec_kwargs = {key : value for key , value in kwargs .items () if key not in dont_include_kwargs }
107+ pp_output = pp_names_to_functions [preprocessor_name ](recording , ** non_rec_kwargs )
115108 recording = pp_output
116109
117110 return recording
118111
119112
120- def create_preprocessed (recording , preprocessor_dict = None ):
113+ def apply_pipeline (
114+ recording : BaseRecording , pipeline_or_dict : dict | PreprocessingPipeline = {}, ignore_precomputed_kwargs = True
115+ ):
121116 """
122117 Creates a preprocessed recording by applying the preprocessing steps in
123118 `pipeline_or_dict` to `recording`.
@@ -126,8 +121,13 @@ def create_preprocessed(recording, preprocessor_dict=None):
126121 ----------
127122 recording : RecordingExtractor
128123 The initial recording
129- preprocessor_dict : dict
130- Dictionary containing preprocessing steps and their kwargs
124+ pipeline_or_dict : dict | PreprocessingPipeline = {}
125+ Dictionary containing preprocessing steps and their kwargs, or a pipeline object.
126+ If empty (the default), the original recording is returned unchanged.
127+ ignore_precomputed_kwargs : bool, default: True
128+ Some preprocessing steps (e.g. Whitening) contain arguments which are computed
129+ during preprocessing. If True, these precomputed kwargs are dropped before each
130+ preprocessor is called. If False, they are passed to the preprocessor as given.
131131
132132 Returns
133133 -------
@@ -142,13 +142,15 @@ def create_preprocessed(recording, preprocessor_dict=None):
142142 >>> from spikeinterface.generation import generate_recording
143143 >>> recording = generate_recording()
144144 >>> preprocessor_dict = {'bandpass_filter': {'freq_max': 3000}, 'common_reference': {}}
145- >>> preprocessed_recording = create_preprocessed(recording, preprocessor_dict)
146-
147-
145+ >>> preprocessed_recording = apply_pipeline(recording, preprocessor_dict)
148146 """
149147
150- pipeline = PreprocessingPipeline (preprocessor_dict )
151- preprocessed_recording = pipeline .apply_to (recording )
148+ if isinstance (pipeline_or_dict , PreprocessingPipeline ):
149+ pipeline = pipeline_or_dict
150+ else :
151+ pipeline = PreprocessingPipeline (pipeline_or_dict )
152+
153+ preprocessed_recording = pipeline .apply (recording , ignore_precomputed_kwargs )
152154 return preprocessed_recording
153155
154156
@@ -177,49 +179,34 @@ def get_preprocessing_dict_from_json(recording_json_path):
177179 """
178180 recording_json = json .load (open (recording_json_path ))
179181
180- initial_preprocessor_dict = {}
181- _load_pp_from_dict (recording_json , initial_preprocessor_dict )
182-
183- preprocessor_dict = deepcopy (initial_preprocessor_dict )
184- for preprocessor in initial_preprocessor_dict :
185- preprocessor_name = preprocessor .split ("." )[- 1 ]
186-
187- if not _is_genuine_preprocessor (preprocessor_name ):
188- preprocessor_dict .pop (preprocessor , preprocessor_dict )
189- continue
190-
191- # remove recording details
192- preprocessor_dict [preprocessor ].pop ("recording" , preprocessor_dict [preprocessor ])
193- preprocessor_dict [preprocessor ].pop ("parent_recording" , preprocessor_dict [preprocessor ])
194-
195- # rename keys to be the class names
196- preprocessor_dict [preprocessor_name ] = preprocessor_dict [preprocessor ]
197- preprocessor_dict .pop (preprocessor )
182+ pp_from_json = {}
183+ _load_pp_from_dict (recording_json , pp_from_json )
198184
199- preprocessor_dict = dict (reversed (preprocessor_dict .items ()))
185+ pipeline_dict = {}
186+ for preprocessor in reversed (pp_from_json ):
200187
201- return preprocessor_dict
188+ preprocessor_class_name = preprocessor . split ( "." )[ - 1 ]
202189
190+ preprocessor_function = preprocessor_dict .get (preprocessor_class_name )
191+ if preprocessor_function is None :
192+ continue
203193
204- def _is_genuine_preprocessor ( preprocessor ):
205- """
206- Check is string 'preprocessor' is in the list of preprocessors from
207- `pp_function_to_class`.
208- """
194+ pp_kwargs = {
195+ key : value
196+ for key , value in pp_from_json [ preprocessor ]. items ()
197+ if key not in [ "recording" , "parent_recording" ]
198+ }
209199
210- using_class_name = bool (re .search ("Recording" , preprocessor ))
211- if using_class_name :
212- genuine_preprocessor = preprocessor in preprocesser_dict .keys ()
213- else :
214- genuine_preprocessor = preprocessor in pp_function_to_class .keys ()
200+ pipeline_dict [preprocessor_function .__name__ ] = pp_kwargs
215201
216- return genuine_preprocessor
202+ return pipeline_dict
217203
218204
219205def _load_pp_from_dict (prov_dict , kwargs_dict ):
220206 """
221207 Recursive function used to iterate through recording provenance dictionary, and
222- extract preprocessing steps and their kwargs.
208+ extract preprocessing steps and their kwargs. Based on `_load_extractor_from_dict`
209+ from spikeinterface.core.base.
223210 """
224211 new_kwargs = dict ()
225212 transform_dict_to_extractor = lambda x : _load_pp_from_dict (x ) if is_dict_extractor (x ) else x
@@ -238,18 +225,21 @@ def _load_pp_from_dict(prov_dict, kwargs_dict):
238225
239226
240227def _get_all_kwargs_and_values (my_pipeline ):
228+ """
229+ Get all keyword arguments and their values from a pipeline,
230+ including the default values.
231+ """
241232
242233 all_kwargs = {}
243234 for preprocessor in my_pipeline .preprocessor_dict :
244235
245236 preprocessor_name = preprocessor .split ("." )[- 1 ]
246- # preprocessor_name = preprocessor.split(".")[-1]
247- pp_function = pp_function_to_class [preprocessor .split ("." )[- 1 ]]
237+ pp_function = pp_names_to_functions [preprocessor .split ("." )[- 1 ]]
248238 signature = inspect .signature (pp_function )
249239
250240 all_kwargs [preprocessor_name ] = {}
251241
252- for parameter , value in signature .parameters .items ():
242+ for _ , value in signature .parameters .items ():
253243 par_name = str (value ).split ("=" )[0 ].split (":" )[0 ]
254244 if par_name != "recording" :
255245 try :
0 commit comments