Skip to content

Commit affca49

Browse files
committed
Changes to Ephys Schema
1) map channel to electrode function 2) Fix bug in trace extraction (currently doesn't account for file boundaries 3) Add impedance measurement tables 4) Add STTC tables
1 parent 997bdbb commit affca49

1 file changed

Lines changed: 204 additions & 6 deletions

File tree

element_array_ephys/ephys_no_curation.py

Lines changed: 204 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
1010
from scipy import signal
1111
import intanrhdreader
12+
import neo
13+
import quantities as pq
14+
from elephant.spike_train_correlation import spike_time_tiling_coefficient
1215

1316
from . import ephys_report, probe
1417
from .readers import kilosort, openephys, spikeglx
@@ -115,7 +118,36 @@ def get_processed_root_data_dir() -> str:
115118
else:
116119
return get_ephys_root_data_dir()[0]
117120

121+
def map_channel_to_electrode(probe_type="A1x32-6mm-100-177-H32_21mm", input_indices=None, electrode_to_channel=False):
122+
"""
123+
Maps channel indices from recording controller to specific probe geometry as defined in probe.ElectrodeConfig.Electrode.
124+
125+
Args:
126+
probe_type (str): Name of the probe used in the recording session. See probe.ProbeType() for inserted probes.
127+
electrode_to_channel (bool): If True, maps from electrode indices to channel indices. If False, maps from channel indices to electrode indices. Default is False.
128+
129+
Returns:
130+
electrodes (array-like): Array of electrode indices corresponding to the input channel indices.
131+
If electrode_to_channel is False, the output will be electrode indices. If electrode_to_channel is True, the output will be channel indices.
132+
"""
118133

134+
# get electrode and channel info
135+
num_electrodes = len(probe.ProbeType.Electrode & f"probe_type='{probe_type}'")
136+
electrode_mapping, channel_mapping = probe.ElectrodeConfig.Electrode.fetch("electrode", "channel_idx")
137+
138+
# create lookup to convert
139+
lookup = np.empty(num_electrodes, dtype=int)
140+
if electrode_to_channel:
141+
lookup[electrode_mapping] = channel_mapping
142+
else:
143+
lookup[channel_mapping] = electrode_mapping
144+
145+
# correctly map electrode indices
146+
if input_indices is None:
147+
input_indices = np.arange(num_electrodes)
148+
149+
electrode_ids = lookup[input_indices]
150+
return electrode_ids
119151
# ----------------------------- Table declarations ----------------------
120152

121153

@@ -152,7 +184,6 @@ class EphysRawFile(dj.Manual):
152184
filename_prefix : varchar(64) # filename prefix, if any, excluding the datetime information
153185
"""
154186

155-
156187
@schema
157188
class EphysSession(dj.Manual):
158189
definition = """ # User defined ephys session for downstream analysis.
@@ -224,7 +255,6 @@ def make(self, key):
224255
]
225256
)
226257

227-
228258
@schema
229259
class LFP(dj.Imported):
230260
definition = """ # Store pre-processed LFP traces per electrode. Only the LFPs collected from a pre-defined recording session.
@@ -384,6 +414,28 @@ def make_compute(
384414
}
385415

386416
lfps = data.pop("amplifier_data")[lfp_indices]
417+
418+
# account for boundaries
419+
fs = header["sample_rate"]
420+
if file_relpath == file_paths[0]:
421+
file_start = datetime.strptime(
422+
"_".join(file_relpath.split("_")[3:5]).removesuffix(".rhd"),
423+
"%y%m%d_%H%M%S",
424+
)
425+
start_idx = int((key['start_time'] - file_start).total_seconds() * fs)
426+
427+
# trim lfps to start boundary
428+
lfps = lfps[:, start_idx:]
429+
elif file_relpath == file_paths[-1]:
430+
file_start = datetime.strptime(
431+
"_".join(file_relpath.split("_")[3:5]).removesuffix(".rhd"),
432+
"%y%m%d_%H%M%S",
433+
)
434+
end_idx = int((key['end_time'] - file_start).total_seconds() * fs)
435+
436+
# trim lfps to end boundary
437+
lfps = lfps[:, :end_idx]
438+
387439
lfp_concat.append(lfps)
388440

389441
full_lfp = np.hstack(lfp_concat)
@@ -408,10 +460,11 @@ def make_compute(
408460
# Downsample the signal with `decimate`
409461
lfp = signal.decimate(lfp, downsample_factor, ftype="fir", zero_phase=True)
410462
all_lfps.append(lfp)
411-
412-
execution_duration = (
413-
datetime.now(timezone.utc) - execution_time
414-
).total_seconds() / 3600
463+
464+
execution_duration = ((
465+
datetime.now(timezone.utc) - execution_time
466+
).total_seconds()
467+
/ 3600)
415468
return (
416469
all_lfps,
417470
channels,
@@ -448,7 +501,89 @@ def make_insert(
448501
}
449502
)
450503

504+
@schema
505+
class ImpedanceFile(dj.Manual):
506+
definition = """ # Insert files and organoid_id for impedance measurements
507+
-> ephys.EphysRawFile
508+
organoid_id : varchar(4) # e.g. O17
509+
"""
510+
511+
@schema
512+
class ImpedanceMeasurements(dj.Imported):
513+
definition = """ # Store impedance measurements per channel
514+
-> ImpedanceFile
515+
---
516+
port_id: char(2) # Port ID of the Intan acquisition system
517+
"""
518+
519+
class Channel(dj.Part):
520+
definition = """
521+
-> master
522+
channel_idx: int # channel index
523+
---
524+
channel_id: varchar(64) # channel id
525+
impedance_magnitude: float # in Ohms
526+
impedance_phase: float # in Degrees
527+
"""
451528

529+
def make(self, key):
530+
# fetch file path from ephysrawfile entry
531+
file_path = (EphysRawFile & key).fetch1("file_path")
532+
533+
# import file
534+
file = find_full_path(get_ephys_root_data_dir(), file_path)
535+
try:
536+
data = intanrhdreader.load_file(file)
537+
except OSError:
538+
raise OSError(f"OS error occurred when loading file {file.name}")
539+
540+
# extract amplifier channels
541+
amplifier_channels = data['header'].pop("amplifier_channels")
542+
543+
# Figure out `Port ID` from the existing EphysSessionProbe
544+
port_id = set((EphysSessionProbe & key).fetch("port_id"))
545+
546+
# Figure out `Port ID` from the existing EphysSession
547+
if not (EphysSessionProbe & key):
548+
raise ValueError(
549+
f"No EphysSessionProbe found for the {key} - cannot determine the port ID"
550+
)
551+
552+
# Check if there are multiple port IDs for the same experiment, if so, it needs to be fixed in the EphysSessionProbe table
553+
if len(port_id) > 1:
554+
raise ValueError(
555+
f"Multiple Port IDs found for the {key} - cannot determine the port ID"
556+
)
557+
port_id = port_id.pop()
558+
559+
# get channels for the correct port
560+
port_channels = [channel for channel in amplifier_channels if channel['port_prefix'] == port_id]
561+
562+
# insert into master
563+
self.insert1(
564+
{
565+
**key,
566+
"port_id": port_id,
567+
}
568+
)
569+
570+
# loop through channels and insert impedance data
571+
for channel in port_channels:
572+
573+
channel_idx = channel['custom_order']
574+
channel_id = channel['custom_channel_name']
575+
impedance_magnitude = channel['electrode_impedance_magnitude']
576+
impedance_phase = channel['electrode_impedance_phase']
577+
578+
self.Channel.insert1(
579+
{
580+
**key,
581+
"channel_idx": channel_idx,
582+
"channel_id": channel_id,
583+
"impedance_magnitude": impedance_magnitude,
584+
"impedance_phase": impedance_phase,
585+
}
586+
)
452587
# ------------ Clustering --------------
453588

454589

@@ -1101,3 +1236,66 @@ def make(self, key):
11011236
self.insert1(key)
11021237
self.Cluster.insert(metrics_list, ignore_extra_fields=True)
11031238
self.Waveform.insert(metrics_list, ignore_extra_fields=True)
1239+
1240+
"""
1241+
Functional Connectivity (STTC)
1242+
"""
1243+
@schema
1244+
class STTC(dj.Computed):
1245+
"""
1246+
Spike Time Tiling Coefficient (STTC) between unit pairs. Automatically computed within ephys sessions (spike sorting).
1247+
Based on the method described in Sharf et al. (2022) Nature Communications.
1248+
"""
1249+
1250+
definition = """
1251+
-> ephys.CuratedClustering
1252+
unit_a: int # First unit in the pair
1253+
unit_b: int # Second unit in the pair
1254+
---
1255+
sttc: float # STTC value between unit pairs
1256+
spike_time_latencies: longblob # Latencies (ms) of spikes from unit A to nearest spike in unit B during (limited to +/- dt)
1257+
"""
1258+
1259+
def make(self, key):
1260+
1261+
# define parameters
1262+
dt = 20 # ms
1263+
1264+
# fetch spike times for all units in the clustering
1265+
unit_ids, spike_times = (CuratedClustering.Unit & key).fetch('unit', 'spike_times', order_by='unit')
1266+
1267+
num_units = len(unit_ids)
1268+
t_stop = (key['end_time'] - key['start_time']) / timedelta(milliseconds=1) # in ms
1269+
1270+
# REMOVE LATER
1271+
spike_times = np.array([st[st <= (key['end_time'] - key['start_time']).total_seconds()] for st in spike_times], dtype=object)
1272+
1273+
# loop through unit pairs and calculate STTC
1274+
for i in range(num_units - 1):
1275+
for j in range(i + 1, num_units):
1276+
1277+
# get spike times (convert from seconds to miliseconds)
1278+
spikes_A = (spike_times[i] * (timedelta(seconds=1) / timedelta(milliseconds=1))).astype(int)
1279+
spikes_B = (spike_times[j] * (timedelta(seconds=1) / timedelta(milliseconds=1))).astype(int)
1280+
1281+
# convert to spike trains (neo)
1282+
spiketrain_A = neo.SpikeTrain(spikes_A, units='ms', t_stop=t_stop)
1283+
spiketrain_B = neo.SpikeTrain(spikes_B, units='ms', t_stop=t_stop)
1284+
1285+
# calculate STTC
1286+
sttc = spike_time_tiling_coefficient(spiketrain_A, spiketrain_B, dt=dt*pq.ms)
1287+
1288+
# calculate spike time latencies
1289+
diff_matrix = np.abs(np.subtract.outer(spikes_A, spikes_B))
1290+
closest_spikes = np.min(diff_matrix, axis=1) # closest spike in B for each spike in A
1291+
spike_time_latencies = closest_spikes[closest_spikes <= dt]
1292+
1293+
self.insert1(
1294+
{
1295+
**key,
1296+
'unit_a': unit_ids[i],
1297+
'unit_b': unit_ids[j],
1298+
'sttc': sttc,
1299+
'spike_time_latencies': spike_time_latencies,
1300+
}
1301+
)

0 commit comments

Comments
 (0)