Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions element_array_ephys/ephys_no_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,36 +251,18 @@ def key_source(self):
& 'session_type IN ("lfp", "both")'
)

def make(self, key):
"""Compute broadband LFP signals for each electrode.

Args:
key (dict): EphysSession primary key.
TARGET_SAMPLING_RATE = 2500 # Hz
POWERLINE_NOISE_FREQ = 60 # Hz
MAX_DURATION_MINUTES = 30 # Minutes

Raises:
ValueError: If the trace duration is not within the expected range.
OSError: If there is an error when loading the file.

Logic:
- Fetch the probe information for the given ephys session.
- Fetch the electrode configuration for the given probe.
- Fetch the raw data files for the given ephys session.
- Check for missing files or short trace durations in min
- Design notch filter to remove powerline noise that contaminates the LFP
- Downsample the signal with `decimate` and apply an anti-aliasing FIR filter
"""
def make_fetch(self, key):
execution_time = datetime.now(timezone.utc)

# Define constants
TARGET_SAMPLING_RATE = 2500 # Hz
POWERLINE_NOISE_FREQ = 60 # Hz
MAX_DURATION_MINUTES = 30 # Minutes

# Check if the trace duration is within the expected range
duration = (key["end_time"] - key["start_time"]).total_seconds() / 60 # minutes
assert (
duration <= MAX_DURATION_MINUTES
), f"LFP session duration {duration} min > max session duration {MAX_DURATION_MINUTES} min"
duration <= self.MAX_DURATION_MINUTES
), f"LFP session duration {duration} min > max session duration {self.MAX_DURATION_MINUTES} min"

# Fetch the raw data files for the given ephys session
query = (
Expand All @@ -305,11 +287,36 @@ def make(self, key):
if probe_info["used_electrodes"]:
electrode_query &= f"electrode IN {tuple(probe_info['used_electrodes'])}"

lfp_indices = np.array(electrode_query.fetch("channel_idx"), dtype=int)

electrode_df = electrode_query.fetch(format="frame").reset_index()

file_paths = query.fetch("file_path", order_by="file_time")

return file_paths, lfp_indices, probe_info, electrode_df, execution_time

def make_compute(self, key, file_paths, lfp_indices, probe_info, electrode_df, execution_time):
"""Compute broadband LFP signals for each electrode.

Args:
key (dict): EphysSession primary key.

Raises:
ValueError: If the trace duration is not within the expected range.
OSError: If there is an error when loading the file.

Logic:
- Fetch the probe information for the given ephys session.
- Fetch the electrode configuration for the given probe.
- Fetch the raw data files for the given ephys session.
- Check for missing files or short trace durations in min
- Design notch filter to remove powerline noise that contaminates the LFP
- Downsample the signal with `decimate` and apply an anti-aliasing FIR filter
"""
header = {}
lfp_concat = []

# Iterate over the raw data files for the given ephys session to load the data
for file_relpath in query.fetch("file_path", order_by="file_time"):
for file_relpath in file_paths:
file = find_full_path(get_ephys_root_data_dir(), file_relpath)
try:
data = intanrhdreader.load_file(file)
Expand All @@ -320,11 +327,11 @@ def make(self, key):
header = data.pop("header")
lfp_sampling_rate = header["sample_rate"]
powerline_noise_freq = (
header["notch_filter_frequency"] or POWERLINE_NOISE_FREQ
header["notch_filter_frequency"] or self.POWERLINE_NOISE_FREQ
) # in Hz

# Calculate downsampling factor
true_ratio = lfp_sampling_rate / TARGET_SAMPLING_RATE
true_ratio = lfp_sampling_rate / self.TARGET_SAMPLING_RATE
downsample_factor = int(np.round(true_ratio))

# Check if the ratio is within 1% of an integer (1% tolerance)
Expand All @@ -334,7 +341,6 @@ def make(self, key):
)

# Get LFP indices (row index of the LFP matrix to be used)
lfp_indices = np.array(electrode_query.fetch("channel_idx"), dtype=int)
port_indices = np.array(
[
ind
Expand All @@ -344,8 +350,6 @@ def make(self, key):
)
lfp_indices = np.sort(port_indices[lfp_indices])

self.insert1({**key, "lfp_sampling_rate": TARGET_SAMPLING_RATE})

# Get LFP channels
channels = np.array(
[
Expand All @@ -356,7 +360,6 @@ def make(self, key):
)[lfp_indices]

# Get channel to electrode mapping
electrode_df = electrode_query.fetch(format="frame").reset_index()
channel_to_electrode_map = dict(
zip(electrode_df["channel_idx"], electrode_df["electrode"])
)
Expand All @@ -383,34 +386,40 @@ def make(self, key):
w0=powerline_noise_freq, Q=30, fs=lfp_sampling_rate
)

all_lfps = []
for ch_idx, raw_lfp in zip(channels, full_lfp):

# Apply notch filter
lfp = signal.filtfilt(notch_b, notch_a, raw_lfp)

# Downsample the signal with `decimate`
lfp = signal.decimate(lfp, downsample_factor, ftype="fir", zero_phase=True)
all_lfps.append(lfp)

self.Trace.insert1(
{
**key,
"electrode_config_hash": electrode_df["electrode_config_hash"][0],
"probe_type": electrode_df["probe_type"][0],
"electrode": channel_to_electrode_map[ch_idx],
"lfp": lfp,
}
)
return all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time

self.update1(
def make_insert(self, key, all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time):
self.insert1(
{
**key,
"lfp_sampling_rate": self.TARGET_SAMPLING_RATE,
"execution_duration": (
datetime.now(timezone.utc) - execution_time
).total_seconds()
/ 3600,
}
)

for ch_idx, lfp in zip(channels, all_lfps):
self.Trace.insert1(
{
**key,
"electrode_config_hash": electrode_df["electrode_config_hash"][0],
"probe_type": electrode_df["probe_type"][0],
"electrode": channel_to_electrode_map[ch_idx],
"lfp": lfp,
}
)


# ------------ Clustering --------------

Expand Down
Loading