99from element_interface .utils import dict_to_uuid , find_full_path , find_root_directory
1010from scipy import signal
1111import intanrhdreader
12- import neo
13- import quantities as pq
14- from elephant .spike_train_correlation import spike_time_tiling_coefficient
1512
1613from . import ephys_report , probe
1714from .readers import kilosort , openephys , spikeglx
@@ -132,8 +129,23 @@ def map_channel_to_electrode(probe_type="A1x32-6mm-100-177-H32_21mm", input_indi
132129 """
133130
134131 # get electrode and channel info
135- num_electrodes = len (probe .ProbeType .Electrode & f"probe_type='{ probe_type } '" )
136- electrode_mapping , channel_mapping = probe .ElectrodeConfig .Electrode .fetch ("electrode" , "channel_idx" )
132+ # probe_type is part of ElectrodeConfig.Electrode's PK via -> ProbeType.Electrode
133+ num_electrodes = len (probe .ProbeType .Electrode & {"probe_type" : probe_type })
134+ electrode_mapping , channel_mapping = (
135+ probe .ElectrodeConfig .Electrode & {"probe_type" : probe_type }
136+ ).fetch ("electrode" , "channel_idx" )
137+
138+ # multiple electrode configs may exist for the same probe_type;
139+ # deduplicate unique (electrode, channel_idx) pairs and raise if they conflict
140+ pairs = np .unique (np .column_stack ([electrode_mapping , channel_mapping ]), axis = 0 )
141+ if len (pairs ) != len (np .unique (pairs [:, 0 ])):
142+ raise ValueError (
143+ f"Conflicting channel-electrode mappings found for probe_type='{ probe_type } '. "
144+ "Multiple electrode configurations exist with inconsistent channel assignments. "
145+ "Cannot determine a unique mapping."
146+ )
147+ electrode_mapping = pairs [:, 0 ]
148+ channel_mapping = pairs [:, 1 ]
137149
138150 # create lookup to convert
139151 lookup = np .empty (num_electrodes , dtype = int )
@@ -428,24 +440,25 @@ def make_compute(
428440
429441 # account for boundaries
430442 fs = header ["sample_rate" ]
431- if file_relpath == file_paths [0 ]:
432- file_start = datetime .strptime (
433- "_" .join (file_relpath .split ("_" )[3 :5 ]).removesuffix (".rhd" ),
434- "%y%m%d_%H%M%S" ,
435- )
436- start_idx = int ((key ['start_time' ] - file_start ).total_seconds () * fs )
443+ start_idx = 0
444+ end_idx = lfps .shape [1 ]
437445
438- # trim lfps to start boundary
439- lfps = lfps [:, start_idx :]
440- elif file_relpath == file_paths [- 1 ]:
446+ is_first_file = file_relpath == file_paths [0 ]
447+ is_last_file = file_relpath == file_paths [- 1 ]
448+
449+ if is_first_file or is_last_file :
450+ # parse file_start once; applies to both boundary conditions
441451 file_start = datetime .strptime (
442452 "_" .join (file_relpath .split ("_" )[3 :5 ]).removesuffix (".rhd" ),
443453 "%y%m%d_%H%M%S" ,
444454 )
445- end_idx = int ((key ['end_time' ] - file_start ).total_seconds () * fs )
455+ if is_first_file :
456+ start_idx = int ((key ['start_time' ] - file_start ).total_seconds () * fs )
457+ if is_last_file :
458+ end_idx = int ((key ['end_time' ] - file_start ).total_seconds () * fs )
446459
447- # trim lfps to end boundary
448- lfps = lfps [:, :end_idx ]
460+ # trim lfps to session boundaries (handles single-file and multi-file sessions)
461+ lfps = lfps [:, start_idx :end_idx ]
449462
450463 lfp_concat .append (lfps )
451464
@@ -552,14 +565,13 @@ def make(self, key):
552565 amplifier_channels = data ['header' ].pop ("amplifier_channels" )
553566
554567 # Figure out `Port ID` from the existing EphysSessionProbe
555- port_id = set ((EphysSessionProbe & key ).fetch ("port_id" ))
556-
557- # Figure out `Port ID` from the existing EphysSession
558568 if not (EphysSessionProbe & key ):
559569 raise ValueError (
560570 f"No EphysSessionProbe found for the { key } - cannot determine the port ID"
561571 )
562572
573+ port_id = set ((EphysSessionProbe & key ).fetch ("port_id" ))
574+
563575 # Check if there are multiple port IDs for the same experiment, if so, it needs to be fixed in the EphysSessionProbe table
564576 if len (port_id ) > 1 :
565577 raise ValueError (
@@ -1268,17 +1280,20 @@ class STTC(dj.Computed):
12681280 """
12691281
12701282 def make (self , key ):
1283+ import neo
1284+ import quantities as pq
1285+ from elephant .spike_train_correlation import spike_time_tiling_coefficient
12711286
12721287 # define parameters
1273- dt = 20 # ms
1288+ dt = 20 # ms
12741289
12751290 # fetch spike times for all units in the clustering
12761291 unit_ids , spike_times = (CuratedClustering .Unit & key ).fetch ('unit' , 'spike_times' , order_by = 'unit' )
12771292
12781293 num_units = len (unit_ids )
12791294 t_stop = (key ['end_time' ] - key ['start_time' ]) / timedelta (milliseconds = 1 ) # in ms
12801295
1281- # REMOVE LATER
1296+ # clip spike times to session duration to guard against edge-case spikes beyond end_time
12821297 spike_times = np .array ([st [st <= (key ['end_time' ] - key ['start_time' ]).total_seconds ()] for st in spike_times ], dtype = object )
12831298
12841299 # loop through unit pairs and calculate STTC
@@ -1309,4 +1324,4 @@ def make(self, key):
13091324 'sttc' : sttc ,
13101325 'spike_time_latencies' : spike_time_latencies ,
13111326 }
1312- )
1327+ )
0 commit comments