Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 236 additions & 6 deletions element_array_ephys/ephys_no_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,68 @@ def get_processed_root_data_dir() -> str:
else:
return get_ephys_root_data_dir()[0]

def map_channel_to_electrode(probe_type="A1x32-6mm-100-177-H32_21mm", input_indices=None, electrode_to_channel=False):
"""
Maps channel indices from recording controller to specific probe geometry as defined in probe.ElectrodeConfig.Electrode.

Args:
probe_type (str): Name of the probe used in the recording session. See probe.ProbeType() for inserted probes.
electrode_to_channel (bool): If True, maps from electrode indices to channel indices. If False, maps from channel indices to electrode indices. Default is False.

Returns:
electrodes (array-like): Array of electrode indices corresponding to the input channel indices.
If electrode_to_channel is False, the output will be electrode indices. If electrode_to_channel is True, the output will be channel indices.
"""

# get electrode and channel info
# probe_type is part of ElectrodeConfig.Electrode's PK via -> ProbeType.Electrode
num_electrodes = len(probe.ProbeType.Electrode & {"probe_type": probe_type})
electrode_mapping, channel_mapping = (
probe.ElectrodeConfig.Electrode & {"probe_type": probe_type}
).fetch("electrode", "channel_idx")

if len(electrode_mapping) == 0:
raise ValueError(
f"No electrode configuration found for probe_type='{probe_type}'. "
"Ensure an ElectrodeConfig has been inserted for this probe type."
)

# multiple electrode configs may exist for the same probe_type;
# deduplicate unique (electrode, channel_idx) pairs and raise if they conflict
pairs = np.unique(np.column_stack([electrode_mapping, channel_mapping]), axis=0)
if len(pairs) != len(np.unique(pairs[:, 0])):
raise ValueError(
f"Conflicting channel-electrode mappings found for probe_type='{probe_type}'. "
"Multiple electrode configurations exist with inconsistent channel assignments. "
"Cannot determine a unique mapping."
)
electrode_mapping = pairs[:, 0]
channel_mapping = pairs[:, 1]

# create lookup to convert; -1 marks indices not covered by this config
lookup = np.full(num_electrodes, -1, dtype=int)
if electrode_to_channel:
lookup[electrode_mapping] = channel_mapping
else:
lookup[channel_mapping] = electrode_mapping

# correctly map electrode indices
if input_indices is None:
input_indices = np.arange(num_electrodes)

electrode_ids = lookup[input_indices]
return electrode_ids

def get_probe_type(ephys_key):
"""
Gets the probe type for a given ephys session key. EphysSessionProbe needs an entry along with the EphysSession for ephys_key
"""
probe_type = set((EphysSessionProbe * probe.Probe & ephys_key).fetch('probe_type'))
if len(probe_type) != 1:
raise ValueError(
f"Couldn't identify probe type for {ephys_key} - expected one, found {len(probe_type)}"
)
return probe_type.pop()
# ----------------------------- Table declarations ----------------------


Expand Down Expand Up @@ -152,7 +213,6 @@ class EphysRawFile(dj.Manual):
filename_prefix : varchar(64) # filename prefix, if any, excluding the datetime information
"""


@schema
class EphysSession(dj.Manual):
definition = """ # User defined ephys session for downstream analysis.
Expand Down Expand Up @@ -224,7 +284,6 @@ def make(self, key):
]
)


@schema
class LFP(dj.Imported):
definition = """ # Store pre-processed LFP traces per electrode. Only the LFPs collected from a pre-defined recording session.
Expand Down Expand Up @@ -384,6 +443,29 @@ def make_compute(
}

lfps = data.pop("amplifier_data")[lfp_indices]

# account for boundaries
fs = header["sample_rate"]
start_idx = 0
end_idx = lfps.shape[1]

is_first_file = file_relpath == file_paths[0]
is_last_file = file_relpath == file_paths[-1]

if is_first_file or is_last_file:
# parse file_start once; applies to both boundary conditions
file_start = datetime.strptime(
"_".join(file_relpath.split("_")[3:5]).removesuffix(".rhd"),
"%y%m%d_%H%M%S",
)
if is_first_file:
start_idx = int((key['start_time'] - file_start).total_seconds() * fs)
if is_last_file:
end_idx = int((key['end_time'] - file_start).total_seconds() * fs)

# trim lfps to session boundaries (handles single-file and multi-file sessions)
lfps = lfps[:, start_idx:end_idx]

lfp_concat.append(lfps)

full_lfp = np.hstack(lfp_concat)
Expand All @@ -408,10 +490,11 @@ def make_compute(
# Downsample the signal with `decimate`
lfp = signal.decimate(lfp, downsample_factor, ftype="fir", zero_phase=True)
all_lfps.append(lfp)

execution_duration = (
datetime.now(timezone.utc) - execution_time
).total_seconds() / 3600

execution_duration = ((
datetime.now(timezone.utc) - execution_time
).total_seconds()
/ 3600)
return (
all_lfps,
channels,
Expand Down Expand Up @@ -448,7 +531,88 @@ def make_insert(
}
)

@schema
class ImpedanceFile(dj.Manual):
definition = """ # Insert files and organoid_id for impedance measurements
-> ephys.EphysRawFile
organoid_id : varchar(4) # e.g. O17
"""

@schema
class ImpedanceMeasurements(dj.Imported):
definition = """ # Store impedance measurements per channel
-> ImpedanceFile
---
port_id: char(2) # Port ID of the Intan acquisition system
"""

class Channel(dj.Part):
definition = """
-> master
channel_idx: int # channel index
---
channel_id: varchar(64) # channel id
impedance_magnitude: float # in Ohms
impedance_phase: float # in Degrees
"""

def make(self, key):
# fetch file path from ephysrawfile entry
file_path = (EphysRawFile & key).fetch1("file_path")

# import file
file = find_full_path(get_ephys_root_data_dir(), file_path)
try:
data = intanrhdreader.load_file(file)
except OSError:
raise OSError(f"OS error occurred when loading file {file.name}")

# extract amplifier channels
amplifier_channels = data['header'].pop("amplifier_channels")

# Figure out `Port ID` from the existing EphysSessionProbe
if not (EphysSessionProbe & key):
raise ValueError(
f"No EphysSessionProbe found for the {key} - cannot determine the port ID"
)

port_id = set((EphysSessionProbe & key).fetch("port_id"))

# Check if there are multiple port IDs for the same experiment, if so, it needs to be fixed in the EphysSessionProbe table
if len(port_id) > 1:
raise ValueError(
f"Multiple Port IDs found for the {key} - cannot determine the port ID"
)
port_id = port_id.pop()

# get channels for the correct port
port_channels = [channel for channel in amplifier_channels if channel['port_prefix'] == port_id]

# insert into master
self.insert1(
{
**key,
"port_id": port_id,
}
)

# loop through channels and insert impedance data
for channel in port_channels:

channel_idx = channel['custom_order']
channel_id = channel['custom_channel_name']
impedance_magnitude = channel['electrode_impedance_magnitude']
impedance_phase = channel['electrode_impedance_phase']

self.Channel.insert1(
{
**key,
"channel_idx": channel_idx,
"channel_id": channel_id,
"impedance_magnitude": impedance_magnitude,
"impedance_phase": impedance_phase,
}
)
# ------------ Clustering --------------


Expand Down Expand Up @@ -1101,3 +1265,69 @@ def make(self, key):
self.insert1(key)
self.Cluster.insert(metrics_list, ignore_extra_fields=True)
self.Waveform.insert(metrics_list, ignore_extra_fields=True)

"""
Functional Connectivity (STTC)
"""
@schema
class STTC(dj.Computed):
"""
Spike Time Tiling Coefficient (STTC) between unit pairs. Automatically computed within ephys sessions (spike sorting).
Based on the method described in Sharf et al. (2022) Nature Communications.
"""

definition = """
-> ephys.CuratedClustering
unit_a: int # First unit in the pair
unit_b: int # Second unit in the pair
---
sttc: float # STTC value between unit pairs
spike_time_latencies: longblob # Latencies (ms) of spikes from unit A to nearest spike in unit B during (limited to +/- dt)
"""

def make(self, key):
import neo
import quantities as pq
from elephant.spike_train_correlation import spike_time_tiling_coefficient

# define parameters
dt = 20 # ms

# fetch spike times for all units in the clustering
unit_ids, spike_times = (CuratedClustering.Unit & key).fetch('unit', 'spike_times', order_by='unit')

num_units = len(unit_ids)
t_stop = (key['end_time'] - key['start_time']) / timedelta(milliseconds=1) # in ms

# clip spike times to session duration to guard against edge-case spikes beyond end_time
spike_times = np.array([st[st <= (key['end_time'] - key['start_time']).total_seconds()] for st in spike_times], dtype=object)

# loop through unit pairs and calculate STTC
for i in range(num_units - 1):
for j in range(i + 1, num_units):

# get spike times (convert from seconds to miliseconds)
spikes_A = (spike_times[i] * (timedelta(seconds=1) / timedelta(milliseconds=1))).astype(int)
spikes_B = (spike_times[j] * (timedelta(seconds=1) / timedelta(milliseconds=1))).astype(int)

# convert to spike trains (neo)
spiketrain_A = neo.SpikeTrain(spikes_A, units='ms', t_stop=t_stop)
spiketrain_B = neo.SpikeTrain(spikes_B, units='ms', t_stop=t_stop)

# calculate STTC
sttc = spike_time_tiling_coefficient(spiketrain_A, spiketrain_B, dt=dt*pq.ms)

# calculate spike time latencies
diff_matrix = np.abs(np.subtract.outer(spikes_A, spikes_B))
closest_spikes = np.min(diff_matrix, axis=1) # closest spike in B for each spike in A
spike_time_latencies = closest_spikes[closest_spikes <= dt]

self.insert1(
{
**key,
'unit_a': unit_ids[i],
'unit_b': unit_ids[j],
'sttc': sttc,
'spike_time_latencies': spike_time_latencies,
}
)
Loading