diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index f8f34fe3..ea9778a9 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -251,36 +251,18 @@ def key_source(self): & 'session_type IN ("lfp", "both")' ) - def make(self, key): - """Compute broadband LFP signals for each electrode. - - Args: - key (dict): EphysSession primary key. + TARGET_SAMPLING_RATE = 2500 # Hz + POWERLINE_NOISE_FREQ = 60 # Hz + MAX_DURATION_MINUTES = 30 # Minutes - Raises: - ValueError: If the trace duration is not within the expected range. - OSError: If there is an error when loading the file. - - Logic: - - Fetch the probe information for the given ephys session. - - Fetch the electrode configuration for the given probe. - - Fetch the raw data files for the given ephys session. - - Check for missing files or short trace durations in min - - Design notch filter to remove powerline noise that contaminates the LFP - - Downsample the signal with `decimate` and apply an anti-aliasing FIR filter - """ + def make_fetch(self, key): execution_time = datetime.now(timezone.utc) - # Define constants - TARGET_SAMPLING_RATE = 2500 # Hz - POWERLINE_NOISE_FREQ = 60 # Hz - MAX_DURATION_MINUTES = 30 # Minutes - # Check if the trace duration is within the expected range duration = (key["end_time"] - key["start_time"]).total_seconds() / 60 # minutes assert ( - duration <= MAX_DURATION_MINUTES - ), f"LFP session duration {duration} min > max session duration {MAX_DURATION_MINUTES} min" + duration <= self.MAX_DURATION_MINUTES + ), f"LFP session duration {duration} min > max session duration {self.MAX_DURATION_MINUTES} min" # Fetch the raw data files for the given ephys session query = ( @@ -305,11 +287,36 @@ def make(self, key): if probe_info["used_electrodes"]: electrode_query &= f"electrode IN {tuple(probe_info['used_electrodes'])}" + lfp_indices = np.array(electrode_query.fetch("channel_idx"), dtype=int) + + electrode_df = electrode_query.fetch(format="frame").reset_index() + + file_paths = query.fetch("file_path", order_by="file_time") + + return file_paths, lfp_indices, probe_info, electrode_df, execution_time + + def make_compute(self, key, file_paths, lfp_indices, probe_info, electrode_df, execution_time): + """Compute broadband LFP signals for each electrode. + + Args: + key (dict): EphysSession primary key. + + Raises: + ValueError: If the trace duration is not within the expected range. + OSError: If there is an error when loading the file. + + Logic: + - Fetch the probe information for the given ephys session. + - Fetch the electrode configuration for the given probe. + - Fetch the raw data files for the given ephys session. + - Check for missing files or short trace durations in min + - Design notch filter to remove powerline noise that contaminates the LFP + - Downsample the signal with `decimate` and apply an anti-aliasing FIR filter + """ header = {} lfp_concat = [] - # Iterate over the raw data files for the given ephys session to load the data - for file_relpath in query.fetch("file_path", order_by="file_time"): + for file_relpath in file_paths: file = find_full_path(get_ephys_root_data_dir(), file_relpath) try: data = intanrhdreader.load_file(file) @@ -320,11 +327,11 @@ def make(self, key): header = data.pop("header") lfp_sampling_rate = header["sample_rate"] powerline_noise_freq = ( - header["notch_filter_frequency"] or POWERLINE_NOISE_FREQ + header["notch_filter_frequency"] or self.POWERLINE_NOISE_FREQ ) # in Hz # Calculate downsampling factor - true_ratio = lfp_sampling_rate / TARGET_SAMPLING_RATE + true_ratio = lfp_sampling_rate / self.TARGET_SAMPLING_RATE downsample_factor = int(np.round(true_ratio)) # Check if the ratio is within 1% of an integer (1% tolerance) @@ -334,7 +341,6 @@ def make(self, key): ) # Get LFP indices (row index of the LFP matrix to be used) - lfp_indices = np.array(electrode_query.fetch("channel_idx"), dtype=int) port_indices = np.array( [ ind @@ -344,8 +350,6 @@ def make(self, key): ) lfp_indices = np.sort(port_indices[lfp_indices]) - self.insert1({**key, "lfp_sampling_rate": TARGET_SAMPLING_RATE}) - # Get LFP channels channels = np.array( [ @@ -356,7 +360,6 @@ def make(self, key): )[lfp_indices] # Get channel to electrode mapping - electrode_df = electrode_query.fetch(format="frame").reset_index() channel_to_electrode_map = dict( zip(electrode_df["channel_idx"], electrode_df["electrode"]) ) @@ -383,27 +386,22 @@ def make(self, key): w0=powerline_noise_freq, Q=30, fs=lfp_sampling_rate ) + all_lfps = [] for ch_idx, raw_lfp in zip(channels, full_lfp): - # Apply notch filter lfp = signal.filtfilt(notch_b, notch_a, raw_lfp) # Downsample the signal with `decimate` lfp = signal.decimate(lfp, downsample_factor, ftype="fir", zero_phase=True) + all_lfps.append(lfp) - self.Trace.insert1( - { - **key, - "electrode_config_hash": electrode_df["electrode_config_hash"][0], - "probe_type": electrode_df["probe_type"][0], - "electrode": channel_to_electrode_map[ch_idx], - "lfp": lfp, - } - ) + return all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time - self.update1( + def make_insert(self, key, all_lfps, channels, electrode_df, channel_to_electrode_map, execution_time): + self.insert1( { **key, + "lfp_sampling_rate": self.TARGET_SAMPLING_RATE, "execution_duration": ( datetime.now(timezone.utc) - execution_time ).total_seconds() @@ -411,6 +409,17 @@ def make(self, key): } ) + for ch_idx, lfp in zip(channels, all_lfps): + self.Trace.insert1( + { + **key, + "electrode_config_hash": electrode_df["electrode_config_hash"][0], + "probe_type": electrode_df["probe_type"][0], + "electrode": channel_to_electrode_map[ch_idx], + "lfp": lfp, + } + ) + # ------------ Clustering --------------