99from element_interface .utils import dict_to_uuid , find_full_path , find_root_directory
1010from scipy import signal
1111import intanrhdreader
12+ import neo
13+ import quantities as pq
14+ from elephant .spike_train_correlation import spike_time_tiling_coefficient
1215
1316from . import ephys_report , probe
1417from .readers import kilosort , openephys , spikeglx
@@ -115,7 +118,36 @@ def get_processed_root_data_dir() -> str:
115118 else :
116119 return get_ephys_root_data_dir ()[0 ]
117120
121+ def map_channel_to_electrode (probe_type = "A1x32-6mm-100-177-H32_21mm" , input_indices = None , electrode_to_channel = False ):
122+ """
123+ Maps channel indices from recording controller to specific probe geometry as defined in probe.ElectrodeConfig.Electrode.
124+
125+ Args:
126+ probe_type (str): Name of the probe used in the recording session. See probe.ProbeType() for inserted probes.
127+ electrode_to_channel (bool): If True, maps from electrode indices to channel indices. If False, maps from channel indices to electrode indices. Default is False.
128+
129+ Returns:
130+ electrodes (array-like): Array of electrode indices corresponding to the input channel indices.
131+ If electrode_to_channel is False, the output will be electrode indices. If electrode_to_channel is True, the output will be channel indices.
132+ """
118133
134+ # get electrode and channel info
135+ num_electrodes = len (probe .ProbeType .Electrode & f"probe_type='{ probe_type } '" )
136+ electrode_mapping , channel_mapping = probe .ElectrodeConfig .Electrode .fetch ("electrode" , "channel_idx" )
137+
138+ # create lookup to convert
139+ lookup = np .empty (num_electrodes , dtype = int )
140+ if electrode_to_channel :
141+ lookup [electrode_mapping ] = channel_mapping
142+ else :
143+ lookup [channel_mapping ] = electrode_mapping
144+
145+ # correctly map electrode indices
146+ if input_indices is None :
147+ input_indices = np .arange (num_electrodes )
148+
149+ electrode_ids = lookup [input_indices ]
150+ return electrode_ids
119151# ----------------------------- Table declarations ----------------------
120152
121153
@@ -152,7 +184,6 @@ class EphysRawFile(dj.Manual):
152184 filename_prefix : varchar(64) # filename prefix, if any, excluding the datetime information
153185 """
154186
155-
156187@schema
157188class EphysSession (dj .Manual ):
158189 definition = """ # User defined ephys session for downstream analysis.
@@ -224,7 +255,6 @@ def make(self, key):
224255 ]
225256 )
226257
227-
228258@schema
229259class LFP (dj .Imported ):
230260 definition = """ # Store pre-processed LFP traces per electrode. Only the LFPs collected from a pre-defined recording session.
@@ -384,6 +414,28 @@ def make_compute(
384414 }
385415
386416 lfps = data .pop ("amplifier_data" )[lfp_indices ]
417+
418+ # account for boundaries
419+ fs = header ["sample_rate" ]
420+ if file_relpath == file_paths [0 ]:
421+ file_start = datetime .strptime (
422+ "_" .join (file_relpath .split ("_" )[3 :5 ]).removesuffix (".rhd" ),
423+ "%y%m%d_%H%M%S" ,
424+ )
425+ start_idx = int ((key ['start_time' ] - file_start ).total_seconds () * fs )
426+
427+ # trim lfps to start boundary
428+ lfps = lfps [:, start_idx :]
429+ elif file_relpath == file_paths [- 1 ]:
430+ file_start = datetime .strptime (
431+ "_" .join (file_relpath .split ("_" )[3 :5 ]).removesuffix (".rhd" ),
432+ "%y%m%d_%H%M%S" ,
433+ )
434+ end_idx = int ((key ['end_time' ] - file_start ).total_seconds () * fs )
435+
436+ # trim lfps to end boundary
437+ lfps = lfps [:, :end_idx ]
438+
387439 lfp_concat .append (lfps )
388440
389441 full_lfp = np .hstack (lfp_concat )
@@ -408,10 +460,11 @@ def make_compute(
408460 # Downsample the signal with `decimate`
409461 lfp = signal .decimate (lfp , downsample_factor , ftype = "fir" , zero_phase = True )
410462 all_lfps .append (lfp )
411-
412- execution_duration = (
413- datetime .now (timezone .utc ) - execution_time
414- ).total_seconds () / 3600
463+
464+ execution_duration = ((
465+ datetime .now (timezone .utc ) - execution_time
466+ ).total_seconds ()
467+ / 3600 )
415468 return (
416469 all_lfps ,
417470 channels ,
@@ -448,7 +501,89 @@ def make_insert(
448501 }
449502 )
450503
504+ @schema
505+ class ImpedanceFile (dj .Manual ):
506+ definition = """ # Insert files and organoid_id for impedance measurements
507+ -> ephys.EphysRawFile
508+ organoid_id : varchar(4) # e.g. O17
509+ """
510+
511+ @schema
512+ class ImpedanceMeasurements (dj .Imported ):
513+ definition = """ # Store impedance measurements per channel
514+ -> ImpedanceFile
515+ ---
516+ port_id: char(2) # Port ID of the Intan acquisition system
517+ """
518+
519+ class Channel (dj .Part ):
520+ definition = """
521+ -> master
522+ channel_idx: int # channel index
523+ ---
524+ channel_id: varchar(64) # channel id
525+ impedance_magnitude: float # in Ohms
526+ impedance_phase: float # in Degrees
527+ """
451528
529+ def make (self , key ):
530+ # fetch file path from ephysrawfile entry
531+ file_path = (EphysRawFile & key ).fetch1 ("file_path" )
532+
533+ # import file
534+ file = find_full_path (get_ephys_root_data_dir (), file_path )
535+ try :
536+ data = intanrhdreader .load_file (file )
537+ except OSError :
538+ raise OSError (f"OS error occurred when loading file { file .name } " )
539+
540+ # extract amplifier channels
541+ amplifier_channels = data ['header' ].pop ("amplifier_channels" )
542+
543+ # Figure out `Port ID` from the existing EphysSessionProbe
544+ port_id = set ((EphysSessionProbe & key ).fetch ("port_id" ))
545+
546+ # Figure out `Port ID` from the existing EphysSession
547+ if not (EphysSessionProbe & key ):
548+ raise ValueError (
549+ f"No EphysSessionProbe found for the { key } - cannot determine the port ID"
550+ )
551+
552+ # Check if there are multiple port IDs for the same experiment, if so, it needs to be fixed in the EphysSessionProbe table
553+ if len (port_id ) > 1 :
554+ raise ValueError (
555+ f"Multiple Port IDs found for the { key } - cannot determine the port ID"
556+ )
557+ port_id = port_id .pop ()
558+
559+ # get channels for the correct port
560+ port_channels = [channel for channel in amplifier_channels if channel ['port_prefix' ] == port_id ]
561+
562+ # insert into master
563+ self .insert1 (
564+ {
565+ ** key ,
566+ "port_id" : port_id ,
567+ }
568+ )
569+
570+ # loop through channels and insert impedance data
571+ for channel in port_channels :
572+
573+ channel_idx = channel ['custom_order' ]
574+ channel_id = channel ['custom_channel_name' ]
575+ impedance_magnitude = channel ['electrode_impedance_magnitude' ]
576+ impedance_phase = channel ['electrode_impedance_phase' ]
577+
578+ self .Channel .insert1 (
579+ {
580+ ** key ,
581+ "channel_idx" : channel_idx ,
582+ "channel_id" : channel_id ,
583+ "impedance_magnitude" : impedance_magnitude ,
584+ "impedance_phase" : impedance_phase ,
585+ }
586+ )
452587# ------------ Clustering --------------
453588
454589
@@ -1101,3 +1236,66 @@ def make(self, key):
11011236 self .insert1 (key )
11021237 self .Cluster .insert (metrics_list , ignore_extra_fields = True )
11031238 self .Waveform .insert (metrics_list , ignore_extra_fields = True )
1239+
1240+ """
1241+ Functional Connectivity (STTC)
1242+ """
1243+ @schema
1244+ class STTC (dj .Computed ):
1245+ """
1246+ Spike Time Tiling Coefficient (STTC) between unit pairs. Automatically computed within ephys sessions (spike sorting).
1247+ Based on the method described in Sharf et al. (2022) Nature Communications.
1248+ """
1249+
1250+ definition = """
1251+ -> ephys.CuratedClustering
1252+ unit_a: int # First unit in the pair
1253+ unit_b: int # Second unit in the pair
1254+ ---
1255+ sttc: float # STTC value between unit pairs
1256+ spike_time_latencies: longblob # Latencies (ms) of spikes from unit A to nearest spike in unit B during (limited to +/- dt)
1257+ """
1258+
1259+ def make (self , key ):
1260+
1261+ # define parameters
1262+ dt = 20 # ms
1263+
1264+ # fetch spike times for all units in the clustering
1265+ unit_ids , spike_times = (CuratedClustering .Unit & key ).fetch ('unit' , 'spike_times' , order_by = 'unit' )
1266+
1267+ num_units = len (unit_ids )
1268+ t_stop = (key ['end_time' ] - key ['start_time' ]) / timedelta (milliseconds = 1 ) # in ms
1269+
1270+ # REMOVE LATER
1271+ spike_times = np .array ([st [st <= (key ['end_time' ] - key ['start_time' ]).total_seconds ()] for st in spike_times ], dtype = object )
1272+
1273+ # loop through unit pairs and calculate STTC
1274+ for i in range (num_units - 1 ):
1275+ for j in range (i + 1 , num_units ):
1276+
1277+ # get spike times (convert from seconds to miliseconds)
1278+ spikes_A = (spike_times [i ] * (timedelta (seconds = 1 ) / timedelta (milliseconds = 1 ))).astype (int )
1279+ spikes_B = (spike_times [j ] * (timedelta (seconds = 1 ) / timedelta (milliseconds = 1 ))).astype (int )
1280+
1281+ # convert to spike trains (neo)
1282+ spiketrain_A = neo .SpikeTrain (spikes_A , units = 'ms' , t_stop = t_stop )
1283+ spiketrain_B = neo .SpikeTrain (spikes_B , units = 'ms' , t_stop = t_stop )
1284+
1285+ # calculate STTC
1286+ sttc = spike_time_tiling_coefficient (spiketrain_A , spiketrain_B , dt = dt * pq .ms )
1287+
1288+ # calculate spike time latencies
1289+ diff_matrix = np .abs (np .subtract .outer (spikes_A , spikes_B ))
1290+ closest_spikes = np .min (diff_matrix , axis = 1 ) # closest spike in B for each spike in A
1291+ spike_time_latencies = closest_spikes [closest_spikes <= dt ]
1292+
1293+ self .insert1 (
1294+ {
1295+ ** key ,
1296+ 'unit_a' : unit_ids [i ],
1297+ 'unit_b' : unit_ids [j ],
1298+ 'sttc' : sttc ,
1299+ 'spike_time_latencies' : spike_time_latencies ,
1300+ }
1301+ )
0 commit comments