@@ -91,7 +91,10 @@ def install_maxwell_plugin(self, force_download=False):
9191 auto_install_maxwell_hdf5_compression_plugin (force_download = False )
9292
9393
94- _maxwell_event_dtype = np .dtype ([("id" , "int8" ), ("frame" , "uint32" ), ("time" , "float64" ), ("state" , "uint32" ), ("message" , "object" )])
94+ _maxwell_event_dtype = np .dtype (
95+ [("id" , "int8" ), ("frame" , "uint32" ), ("time" , "float64" ), ("state" , "uint32" ), ("message" , "object" )]
96+ )
97+
9598
9699class MaxwellEventExtractor (BaseEvent ):
97100 """
@@ -105,40 +108,43 @@ def __init__(self, file_path):
105108 h5_file = h5py .File (self .file_path , mode = "r" )
106109 version = int (h5_file ["version" ][0 ].decode ())
107110 fs = 20000
108-
111+
109112 if version < 20190530 :
110113 raise NotImplementedError (f"Version { self .version } not supported" )
111-
114+
112115 # get ttl events
113116 bits = h5_file ["bits" ]
114-
115- channel_ids = np .zeros ((0 ),dtype = np .int8 )
116- if len (bits )> 0 :
117+
118+ channel_ids = np .zeros ((0 ), dtype = np .int8 )
119+ if len (bits ) > 0 :
117120 bit_state = bits ["bits" ]
118121 channel_ids = np .int8 (np .unique (bit_state [bit_state != 0 ]))
119122 if - 1 in channel_ids or 1 in channel_ids :
120123 raise ValueError ("TTL bits cannot be -1 or 1." )
121-
124+
122125 # access data_store from h5_file
123126 data_store_keys = [x for x in h5_file ["data_store" ].keys ()]
124- data_store_keys_id = [("events" in h5_file ["data_store" ][x ].keys ()) and ("groups" in h5_file ["data_store" ][x ].keys ()) for x in data_store_keys ]
127+ data_store_keys_id = [
128+ ("events" in h5_file ["data_store" ][x ].keys ()) and ("groups" in h5_file ["data_store" ][x ].keys ())
129+ for x in data_store_keys
130+ ]
125131 data_store = data_store_keys [data_store_keys_id .index (True )]
126132
127133 # get stim events
128134 event_raw = h5_file ["data_store" ][data_store ]["events" ]
129135 channel_ids_stim = np .int8 (np .unique ([x [1 ] for x in event_raw ]))
130136 if - 1 in channel_ids_stim or 0 in channel_ids_stim :
131137 raise ValueError ("Stimulation bits cannot be -1 or 0." )
132- if len (channel_ids )> 0 :
138+ if len (channel_ids ) > 0 :
133139 if set (channel_ids ) & set (channel_ids_stim ):
134140 raise ValueError ("TTL and stimulation bits overlap." )
135141 channel_ids = np .concatenate ((channel_ids , channel_ids_stim ), dtype = np .int8 )
136142
137143 # set spike events channel == -1
138144 spike_raw = h5_file ["data_store" ][data_store ]["spikes" ]
139- if len (spike_raw )> 0 :
145+ if len (spike_raw ) > 0 :
140146 channel_ids = np .concatenate ((channel_ids , [- 1 ]), dtype = np .int8 )
141-
147+
142148 BaseEvent .__init__ (self , channel_ids , structured_dtype = _maxwell_event_dtype )
143149 event_segment = MaxwellEventSegment (h5_file , version , fs )
144150 self .add_event_segment (event_segment )
@@ -151,30 +157,33 @@ def __init__(self, h5_file, version, fs):
151157 self .version = version
152158 self .bits = self .h5_file ["bits" ]
153159 self .fs = fs
154-
160+
155161 def get_events (self , channel_id , start_time , end_time ):
156162 bits = self .bits
157-
163+
158164 # get ttl events
159- channel_ids = np .zeros ((0 ),dtype = np .int8 )
160- bit_channel = np .zeros ((0 ),dtype = np .int8 )
161- bit_frameno = np .zeros ((0 ),dtype = np .uint32 )
162- bit_state = np .zeros ((0 ),dtype = np .uint32 )
163- bit_message = np .zeros ((0 ),dtype = object )
164- if len (bits )> 0 :
165+ channel_ids = np .zeros ((0 ), dtype = np .int8 )
166+ bit_channel = np .zeros ((0 ), dtype = np .int8 )
167+ bit_frameno = np .zeros ((0 ), dtype = np .uint32 )
168+ bit_state = np .zeros ((0 ), dtype = np .uint32 )
169+ bit_message = np .zeros ((0 ), dtype = object )
170+ if len (bits ) > 0 :
165171 good_idx = np .where (bits ["bits" ] != 0 )[0 ]
166- channel_ids = np .concatenate ((channel_ids ,np .int8 (np .unique (bits ["bits" ][good_idx ]))))
172+ channel_ids = np .concatenate ((channel_ids , np .int8 (np .unique (bits ["bits" ][good_idx ]))))
167173 if 1 in channel_ids :
168174 raise ValueError ("TTL bits cannot be 1." )
169- bit_channel = np .concatenate ((bit_channel ,np .uint8 (bits ["bits" ][good_idx ])))
170- bit_frameno = np .concatenate ((bit_frameno ,np .uint32 (bits ["frameno" ][good_idx ])))
171- bit_state = np .concatenate ((bit_state ,np .uint32 (bits ["bits" ][good_idx ])))
172- bit_message = np .concatenate ((bit_message ,[ b' {}\n ' ] * len (bit_state )),dtype = object )
173-
175+ bit_channel = np .concatenate ((bit_channel , np .uint8 (bits ["bits" ][good_idx ])))
176+ bit_frameno = np .concatenate ((bit_frameno , np .uint32 (bits ["frameno" ][good_idx ])))
177+ bit_state = np .concatenate ((bit_state , np .uint32 (bits ["bits" ][good_idx ])))
178+ bit_message = np .concatenate ((bit_message , [ b" {}\n " ] * len (bit_state )), dtype = object )
179+
174180 # access data_store from h5_file
175181 h5_file = self .h5_file
176182 data_store_keys = [x for x in h5_file ["data_store" ].keys ()]
177- data_store_keys_id = [("events" in h5_file ["data_store" ][x ].keys ()) and ("groups" in h5_file ["data_store" ][x ].keys ()) for x in data_store_keys ]
183+ data_store_keys_id = [
184+ ("events" in h5_file ["data_store" ][x ].keys ()) and ("groups" in h5_file ["data_store" ][x ].keys ())
185+ for x in data_store_keys
186+ ]
178187 data_store = data_store_keys [data_store_keys_id .index (True )]
179188
180189 # get stim events
@@ -185,15 +194,15 @@ def get_events(self, channel_id, start_time, end_time):
185194 bit_frameno_stim = stim_arr ["frameno" ]
186195 bit_state_stim = stim_arr ["eventid" ]
187196 bit_message_stim = stim_arr ["eventmessage" ]
188-
197+
189198 # get spike events
190199 spike_raw = h5_file ["data_store" ][data_store ]["spikes" ]
191- if len (spike_raw )> 0 :
200+ if len (spike_raw ) > 0 :
192201 channel_ids_spike = np .int8 ([- 1 ])
193202 spike_arr = np .array (spike_raw )
194- bit_channel_spike = - np .ones (len (spike_arr ),dtype = np .int8 )
203+ bit_channel_spike = - np .ones (len (spike_arr ), dtype = np .int8 )
195204 bit_frameno_spike = spike_arr ["frameno" ]
196- bit_state_spike = spike_arr ["channel" ]
205+ bit_state_spike = spike_arr ["channel" ]
197206 bit_message_spike = spike_arr ["amplitude" ]
198207
199208 # final array in order: spikes, stims, ttl
0 commit comments