Skip to content

Commit 5a485e1

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0d35e2f commit 5a485e1

2 files changed

Lines changed: 46 additions & 37 deletions

File tree

src/spikeinterface/extractors/neoextractors/maxwell.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ def install_maxwell_plugin(self, force_download=False):
9191
auto_install_maxwell_hdf5_compression_plugin(force_download=False)
9292

9393

94-
_maxwell_event_dtype = np.dtype([("id", "int8"), ("frame", "uint32"), ("time", "float64"), ("state", "uint32"), ("message", "object")])
94+
_maxwell_event_dtype = np.dtype(
95+
[("id", "int8"), ("frame", "uint32"), ("time", "float64"), ("state", "uint32"), ("message", "object")]
96+
)
97+
9598

9699
class MaxwellEventExtractor(BaseEvent):
97100
"""
@@ -105,40 +108,43 @@ def __init__(self, file_path):
105108
h5_file = h5py.File(self.file_path, mode="r")
106109
version = int(h5_file["version"][0].decode())
107110
fs = 20000
108-
111+
109112
if version < 20190530:
110113
raise NotImplementedError(f"Version {self.version} not supported")
111-
114+
112115
# get ttl events
113116
bits = h5_file["bits"]
114-
115-
channel_ids = np.zeros((0),dtype=np.int8)
116-
if len(bits)>0:
117+
118+
channel_ids = np.zeros((0), dtype=np.int8)
119+
if len(bits) > 0:
117120
bit_state = bits["bits"]
118121
channel_ids = np.int8(np.unique(bit_state[bit_state != 0]))
119122
if -1 in channel_ids or 1 in channel_ids:
120123
raise ValueError("TTL bits cannot be -1 or 1.")
121-
124+
122125
# access data_store from h5_file
123126
data_store_keys = [x for x in h5_file["data_store"].keys()]
124-
data_store_keys_id = [("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys()) for x in data_store_keys]
127+
data_store_keys_id = [
128+
("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
129+
for x in data_store_keys
130+
]
125131
data_store = data_store_keys[data_store_keys_id.index(True)]
126132

127133
# get stim events
128134
event_raw = h5_file["data_store"][data_store]["events"]
129135
channel_ids_stim = np.int8(np.unique([x[1] for x in event_raw]))
130136
if -1 in channel_ids_stim or 0 in channel_ids_stim:
131137
raise ValueError("Stimulation bits cannot be -1 or 0.")
132-
if len(channel_ids)>0:
138+
if len(channel_ids) > 0:
133139
if set(channel_ids) & set(channel_ids_stim):
134140
raise ValueError("TTL and stimulation bits overlap.")
135141
channel_ids = np.concatenate((channel_ids, channel_ids_stim), dtype=np.int8)
136142

137143
# set spike events channel == -1
138144
spike_raw = h5_file["data_store"][data_store]["spikes"]
139-
if len(spike_raw)>0:
145+
if len(spike_raw) > 0:
140146
channel_ids = np.concatenate((channel_ids, [-1]), dtype=np.int8)
141-
147+
142148
BaseEvent.__init__(self, channel_ids, structured_dtype=_maxwell_event_dtype)
143149
event_segment = MaxwellEventSegment(h5_file, version, fs)
144150
self.add_event_segment(event_segment)
@@ -151,30 +157,33 @@ def __init__(self, h5_file, version, fs):
151157
self.version = version
152158
self.bits = self.h5_file["bits"]
153159
self.fs = fs
154-
160+
155161
def get_events(self, channel_id, start_time, end_time):
156162
bits = self.bits
157-
163+
158164
# get ttl events
159-
channel_ids = np.zeros((0),dtype=np.int8)
160-
bit_channel = np.zeros((0),dtype=np.int8)
161-
bit_frameno = np.zeros((0),dtype=np.uint32)
162-
bit_state = np.zeros((0),dtype=np.uint32)
163-
bit_message = np.zeros((0),dtype=object)
164-
if len(bits)>0:
165+
channel_ids = np.zeros((0), dtype=np.int8)
166+
bit_channel = np.zeros((0), dtype=np.int8)
167+
bit_frameno = np.zeros((0), dtype=np.uint32)
168+
bit_state = np.zeros((0), dtype=np.uint32)
169+
bit_message = np.zeros((0), dtype=object)
170+
if len(bits) > 0:
165171
good_idx = np.where(bits["bits"] != 0)[0]
166-
channel_ids = np.concatenate((channel_ids,np.int8(np.unique(bits["bits"][good_idx]))))
172+
channel_ids = np.concatenate((channel_ids, np.int8(np.unique(bits["bits"][good_idx]))))
167173
if 1 in channel_ids:
168174
raise ValueError("TTL bits cannot be 1.")
169-
bit_channel = np.concatenate((bit_channel,np.uint8(bits["bits"][good_idx])))
170-
bit_frameno = np.concatenate((bit_frameno,np.uint32(bits["frameno"][good_idx])))
171-
bit_state = np.concatenate((bit_state,np.uint32(bits["bits"][good_idx])))
172-
bit_message = np.concatenate((bit_message,[b'{}\n']*len(bit_state)),dtype=object)
173-
175+
bit_channel = np.concatenate((bit_channel, np.uint8(bits["bits"][good_idx])))
176+
bit_frameno = np.concatenate((bit_frameno, np.uint32(bits["frameno"][good_idx])))
177+
bit_state = np.concatenate((bit_state, np.uint32(bits["bits"][good_idx])))
178+
bit_message = np.concatenate((bit_message, [b"{}\n"] * len(bit_state)), dtype=object)
179+
174180
# access data_store from h5_file
175181
h5_file = self.h5_file
176182
data_store_keys = [x for x in h5_file["data_store"].keys()]
177-
data_store_keys_id = [("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys()) for x in data_store_keys]
183+
data_store_keys_id = [
184+
("events" in h5_file["data_store"][x].keys()) and ("groups" in h5_file["data_store"][x].keys())
185+
for x in data_store_keys
186+
]
178187
data_store = data_store_keys[data_store_keys_id.index(True)]
179188

180189
# get stim events
@@ -185,15 +194,15 @@ def get_events(self, channel_id, start_time, end_time):
185194
bit_frameno_stim = stim_arr["frameno"]
186195
bit_state_stim = stim_arr["eventid"]
187196
bit_message_stim = stim_arr["eventmessage"]
188-
197+
189198
# get spike events
190199
spike_raw = h5_file["data_store"][data_store]["spikes"]
191-
if len(spike_raw)>0:
200+
if len(spike_raw) > 0:
192201
channel_ids_spike = np.int8([-1])
193202
spike_arr = np.array(spike_raw)
194-
bit_channel_spike = -np.ones(len(spike_arr),dtype=np.int8)
203+
bit_channel_spike = -np.ones(len(spike_arr), dtype=np.int8)
195204
bit_frameno_spike = spike_arr["frameno"]
196-
bit_state_spike= spike_arr["channel"]
205+
bit_state_spike = spike_arr["channel"]
197206
bit_message_spike = spike_arr["amplitude"]
198207

199208
# final array in order: spikes, stims, ttl

src/spikeinterface/extractors/neoextractors/neobaseextractor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,22 +232,22 @@ def __init__(
232232
mask_id = np.argwhere(mask).flatten().tolist()
233233

234234
# remove all duplicate channel-to-electrode assignments
235-
dupid = np.where(pd.DataFrame(signal_channels[mask]['name']).duplicated(keep='first'))
235+
dupid = np.where(pd.DataFrame(signal_channels[mask]["name"]).duplicated(keep="first"))
236236
for i in dupid[0]:
237237
mask[mask_id.pop(i)] = False
238-
238+
239239
# remove all duplicate channel assigments corresponding to different electrodes (channel is a mix of mulitple electrode signals)
240-
signal_channels_chan,_ = map(list, zip(*(x.split(' ') for x in signal_channels[mask]['name'])))
240+
signal_channels_chan, _ = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
241241
dupid = np.where(pd.DataFrame(signal_channels_chan).duplicated(keep=False))
242242
for i in dupid[0]:
243243
mask[mask_id.pop(i)] = False
244-
244+
245245
# remove subsequent duplicated electrodes (single electrode saved to multiple channels)
246-
_,signal_channels_elec = map(list, zip(*(x.split(' ') for x in signal_channels[mask]['name'])))
247-
dupid = np.where(pd.DataFrame(signal_channels_elec).duplicated(keep='first'))
246+
_, signal_channels_elec = map(list, zip(*(x.split(" ") for x in signal_channels[mask]["name"])))
247+
dupid = np.where(pd.DataFrame(signal_channels_elec).duplicated(keep="first"))
248248
for i in dupid[0]:
249249
mask[mask_id.pop(i)] = False
250-
250+
251251
signal_channels = signal_channels[mask]
252252

253253
if use_names_as_ids:

0 commit comments

Comments
 (0)