Skip to content

Commit 05cd8d9

Browse files
Merge pull request #59 from dj-sciops/dev_three-part-make-lfp
feat: 3-part make for LFP
2 parents 1ce2ef1 + c155f47 commit 05cd8d9

1 file changed

Lines changed: 52 additions & 43 deletions

File tree

element_array_ephys/ephys_no_curation.py

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -251,36 +251,18 @@ def key_source(self):
251251
& 'session_type IN ("lfp", "both")'
252252
)
253253

254-
def make(self, key):
255-
"""Compute broadband LFP signals for each electrode.
256-
257-
Args:
258-
key (dict): EphysSession primary key.
254+
TARGET_SAMPLING_RATE = 2500 # Hz
255+
POWERLINE_NOISE_FREQ = 60 # Hz
256+
MAX_DURATION_MINUTES = 30 # Minutes
259257

260-
Raises:
261-
ValueError: If the trace duration is not within the expected range.
262-
OSError: If there is an error when loading the file.
263-
264-
Logic:
265-
- Fetch the probe information for the given ephys session.
266-
- Fetch the electrode configuration for the given probe.
267-
- Fetch the raw data files for the given ephys session.
268-
- Check for missing files or short trace durations in min
269-
- Design notch filter to remove powerline noise that contaminates the LFP
270-
- Downsample the signal with `decimate` and apply an anti-aliasing FIR filter
271-
"""
258+
def make_fetch(self, key):
272259
execution_time = datetime.now(timezone.utc)
273260

274-
# Define constants
275-
TARGET_SAMPLING_RATE = 2500 # Hz
276-
POWERLINE_NOISE_FREQ = 60 # Hz
277-
MAX_DURATION_MINUTES = 30 # Minutes
278-
279261
# Check if the trace duration is within the expected range
280262
duration = (key["end_time"] - key["start_time"]).total_seconds() / 60 # minutes
281263
assert (
282-
duration <= MAX_DURATION_MINUTES
283-
), f"LFP session duration {duration} min > max session duration {MAX_DURATION_MINUTES} min"
264+
duration <= self.MAX_DURATION_MINUTES
265+
), f"LFP session duration {duration} min > max session duration {self.MAX_DURATION_MINUTES} min"
284266

285267
# Fetch the raw data files for the given ephys session
286268
query = (
@@ -305,11 +287,36 @@ def make(self, key):
305287
if probe_info["used_electrodes"]:
306288
electrode_query &= f"electrode IN {tuple(probe_info['used_electrodes'])}"
307289

290+
lfp_indices = np.array(electrode_query.fetch("channel_idx"), dtype=int)
291+
292+
electrode_df = electrode_query.fetch(format="frame").reset_index()
293+
294+
file_paths = query.fetch("file_path", order_by="file_time")
295+
296+
return file_paths, lfp_indices, probe_info, electrode_df, execution_time
297+
298+
def make_compute(self, key, file_paths, lfp_indices, probe_info, electrode_df, execution_time):
299+
"""Compute broadband LFP signals for each electrode.
300+
301+
Args:
302+
key (dict): EphysSession primary key.
303+
304+
Raises:
305+
ValueError: If the trace duration is not within the expected range.
306+
OSError: If there is an error when loading the file.
307+
308+
Logic:
309+
- Fetch the probe information for the given ephys session.
310+
- Fetch the electrode configuration for the given probe.
311+
- Fetch the raw data files for the given ephys session.
312+
- Check for missing files or short trace durations in min
313+
- Design notch filter to remove powerline noise that contaminates the LFP
314+
- Downsample the signal with `decimate` and apply an anti-aliasing FIR filter
315+
"""
308316
header = {}
309317
lfp_concat = []
310-
311318
# Iterate over the raw data files for the given ephys session to load the data
312-
for file_relpath in query.fetch("file_path", order_by="file_time"):
319+
for file_relpath in file_paths:
313320
file = find_full_path(get_ephys_root_data_dir(), file_relpath)
314321
try:
315322
data = intanrhdreader.load_file(file)
@@ -320,11 +327,11 @@ def make(self, key):
320327
header = data.pop("header")
321328
lfp_sampling_rate = header["sample_rate"]
322329
powerline_noise_freq = (
323-
header["notch_filter_frequency"] or POWERLINE_NOISE_FREQ
330+
header["notch_filter_frequency"] or self.POWERLINE_NOISE_FREQ
324331
) # in Hz
325332

326333
# Calculate downsampling factor
327-
true_ratio = lfp_sampling_rate / TARGET_SAMPLING_RATE
334+
true_ratio = lfp_sampling_rate / self.TARGET_SAMPLING_RATE
328335
downsample_factor = int(np.round(true_ratio))
329336

330337
# Check if the ratio is within 1% of an integer (1% tolerance)
@@ -334,7 +341,6 @@ def make(self, key):
334341
)
335342

336343
# Get LFP indices (row index of the LFP matrix to be used)
337-
lfp_indices = np.array(electrode_query.fetch("channel_idx"), dtype=int)
338344
port_indices = np.array(
339345
[
340346
ind
@@ -344,8 +350,6 @@ def make(self, key):
344350
)
345351
lfp_indices = np.sort(port_indices[lfp_indices])
346352

347-
self.insert1({**key, "lfp_sampling_rate": TARGET_SAMPLING_RATE})
348-
349353
# Get LFP channels
350354
channels = np.array(
351355
[
@@ -356,7 +360,6 @@ def make(self, key):
356360
)[lfp_indices]
357361

358362
# Get channel to electrode mapping
359-
electrode_df = electrode_query.fetch(format="frame").reset_index()
360363
channel_to_electrode_map = dict(
361364
zip(electrode_df["channel_idx"], electrode_df["electrode"])
362365
)
@@ -383,34 +386,40 @@ def make(self, key):
383386
w0=powerline_noise_freq, Q=30, fs=lfp_sampling_rate
384387
)
385388

389+
all_lfps = []
386390
for ch_idx, raw_lfp in zip(channels, full_lfp):
387-
388391
# Apply notch filter
389392
lfp = signal.filtfilt(notch_b, notch_a, raw_lfp)
390393

391394
# Downsample the signal with `decimate`
392395
lfp = signal.decimate(lfp, downsample_factor, ftype="fir", zero_phase=True)
396+
all_lfps.append(lfp)
393397

394-
self.Trace.insert1(
395-
{
396-
**key,
397-
"electrode_config_hash": electrode_df["electrode_config_hash"][0],
398-
"probe_type": electrode_df["probe_type"][0],
399-
"electrode": channel_to_electrode_map[ch_idx],
400-
"lfp": lfp,
401-
}
402-
)
398+
return all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time
403399

404-
self.update1(
400+
def make_insert(self, key, all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time):
401+
self.insert1(
405402
{
406403
**key,
404+
"lfp_sampling_rate": self.TARGET_SAMPLING_RATE,
407405
"execution_duration": (
408406
datetime.now(timezone.utc) - execution_time
409407
).total_seconds()
410408
/ 3600,
411409
}
412410
)
413411

412+
for ch_idx, lfp in zip(channels, all_lfps):
413+
self.Trace.insert1(
414+
{
415+
**key,
416+
"electrode_config_hash": electrode_df["electrode_config_hash"][0],
417+
"probe_type": electrode_df["probe_type"][0],
418+
"electrode": channel_to_electrode_map[ch_idx],
419+
"lfp": lfp,
420+
}
421+
)
422+
414423

415424
# ------------ Clustering --------------
416425

0 commit comments

Comments
 (0)