1111
1212
1313def find_range (x ,a ,b ,option = 'within' ):
14-
14+
1515 """
1616 Find indices of data within or outside range [a,b]
1717
@@ -53,7 +53,7 @@ def rms(data):
5353 Output:
5454 ------
5555 rms_value - float
56-
56+
5757 """
5858
5959 return np .power (np .mean (np .power (data .astype ('float32' ),2 )),0.5 )
@@ -91,21 +91,21 @@ def write_probe_json(output_file, channels, offset, scaling, mask, surface_chann
9191 """
9292
9393 with open (output_file , 'w' ) as outfile :
94- json .dump (
95- {
96- 'channel' : channels .tolist (),
97- 'offset' : offset .tolist (),
98- 'scaling' : scaling .tolist (),
99- 'mask' : mask .tolist (),
100- 'surface_channel' : surface_channel ,
94+ json .dump (
95+ {
96+ 'channel' : channels .tolist (),
97+ 'offset' : offset .tolist (),
98+ 'scaling' : scaling .tolist (),
99+ 'mask' : mask .tolist (),
100+ 'surface_channel' : surface_channel ,
101101 'air_channel' : air_channel ,
102102 'vertical_pos' : vertical_pos .tolist (),
103103 'horizontal_pos' : horizontal_pos .tolist ()
104104 },
105-
106- outfile ,
107- indent = 4 , separators = (',' , ': ' )
108- )
105+
106+ outfile ,
107+ indent = 4 , separators = (',' , ': ' )
108+ )
109109
110110def read_probe_json (input_file ):
111111
@@ -131,10 +131,10 @@ def read_probe_json(input_file):
131131 Index of channel at interface between saline/agar and air
132132
133133 """
134-
134+
135135 with open (input_file ) as data_file :
136136 data = json .load (data_file )
137-
137+
138138 scaling = np .array (data ['scaling' ])
139139 mask = np .array (data ['mask' ])
140140 offset = np .array (data ['offset' ])
@@ -163,11 +163,11 @@ def write_cluster_group_tsv(IDs, quality, output_directory, filename = 'cluster_
163163 cluster_group.tsv (written to disk)
164164
165165 """
166-
166+
167167 df = pd .DataFrame (data = {'cluster_id' : IDs , 'group' : quality })
168-
168+
169169 print ('Saving data...' )
170-
170+
171171 df .to_csv (os .path .join (output_directory , filename ), sep = '\t ' , index = False )
172172
173173
@@ -197,7 +197,7 @@ def read_cluster_group_tsv(filename):
197197 return cluster_ids , cluster_quality
198198
199199def read_cluster_amplitude_tsv (filename ):
200-
200+
201201 """
202202 Reads a tab-separated cluster_Amplitude.tsv file from disk
203203
@@ -213,7 +213,7 @@ def read_cluster_amplitude_tsv(filename):
213213
214214 """
215215 info = np .genfromtxt (filename , dtype = 'str' )
216- # don't return cluster_ids because those are already read in or
216+ # don't return cluster_ids because those are already read in or
217217 # derived from the spike_clusters.npy file
218218 # cluster_ids = info[1:,0].astype('int')
219219 cluster_amplitude = info [1 :,1 ].astype ('float' )
@@ -243,10 +243,10 @@ def load(folder, filename):
243243 return np .load (os .path .join (folder , filename ))
244244
245245
246- def load_kilosort_data (folder ,
247- sample_rate = None ,
248- convert_to_seconds = True ,
249- use_master_clock = False ,
246+ def load_kilosort_data (folder ,
247+ sample_rate = None ,
248+ convert_to_seconds = True ,
249+ use_master_clock = False ,
250250 include_pcs = False ,
251251 template_zero_padding = 21 ):
252252
@@ -278,7 +278,7 @@ def load_kilosort_data(folder,
278278 Template IDs for N spikes
279279 amplitudes : numpy.ndarray (N x 0)
280280 Amplitudes for N spikes
281- unwhitened_temps : numpy.ndarray (M x samples x channels)
281+ unwhitened_temps : numpy.ndarray (M x samples x channels)
282282 Templates for M units
283283 channel_map : numpy.ndarray
284284 Channels from original data file used for sorting
@@ -303,7 +303,7 @@ def load_kilosort_data(folder,
303303 spike_times = load (folder ,'spike_times_master_clock.npy' )
304304 else :
305305 spike_times = load (folder ,'spike_times.npy' )
306-
306+
307307 spike_clusters = load (folder ,'spike_clusters.npy' )
308308 spike_templates = load (folder , 'spike_templates.npy' )
309309 amplitudes = load (folder ,'amplitudes.npy' )
@@ -312,35 +312,38 @@ def load_kilosort_data(folder,
312312 channel_map = load (folder , 'channel_map.npy' )
313313 channel_pos = load (folder , 'channel_positions.npy' )
314314
315+ # handles channel_map being read as 2-dimensional
316+ channel_map = np .squeeze (channel_map ).astype (int )
317+
315318 if include_pcs :
316319 pc_features = load (folder , 'pc_features.npy' )
317320 pc_feature_ind = load (folder , 'pc_feature_ind.npy' )
318- template_features = load (folder , 'template_features.npy' )
321+ template_features = load (folder , 'template_features.npy' )
322+
319323
320-
321324 templates = templates [:,template_zero_padding :,:] # remove zeros
322325 spike_clusters = np .squeeze (spike_clusters ) # fix dimensions
323326 spike_times = np .squeeze (spike_times )# fix dimensions
324327
325328 if convert_to_seconds and sample_rate is not None :
326- spike_times = spike_times / sample_rate
327-
329+ spike_times = spike_times / sample_rate
330+
328331 unwhitened_temps = np .zeros ((templates .shape ))
329-
332+
330333 for temp_idx in range (templates .shape [0 ]):
331-
334+
332335 unwhitened_temps [temp_idx ,:,:] = np .dot (np .ascontiguousarray (templates [temp_idx ,:,:]),np .ascontiguousarray (unwhitening_mat ))
333-
336+
334337 try :
335338 cluster_ids , cluster_quality = read_cluster_group_tsv (os .path .join (folder , 'cluster_group.tsv' ))
336339 except OSError :
337340 cluster_ids = np .unique (spike_clusters )
338341 cluster_quality = ['unsorted' ] * cluster_ids .size
339-
342+
340343 cluster_amplitude = read_cluster_amplitude_tsv (os .path .join (folder , 'cluster_Amplitude.tsv' ))
341-
342-
343-
344+
345+
346+
344347
345348 if not include_pcs :
346349 return spike_times , spike_clusters , spike_templates , amplitudes , unwhitened_temps , \
@@ -357,12 +360,12 @@ def get_spike_depths(spike_clusters, unit_template_ids, pc_features, pc_feature_
357360 Calculates the distance (in microns) of individual spikes from the probe tip
358361
359362 This implementation is based on Matlab code from github.com/cortex-lab/spikes
360-
361- Needs to be called for a subset of spikes extracted with the majority template
363+
364+ Needs to be called for a subset of spikes extracted with the majority template
362365 This is true for all spikes in data which has not been curated.
363366 Manual merges create clusters that derive from multiple templats, but this
364367 algorthim examines features from a single template -- so we select spikes
365- for each cluster that were extracted with the majority template before calling
368+ for each cluster that were extracted with the majority template before calling
366369 in metrics.py
367370
368371 Input:
@@ -390,7 +393,7 @@ def get_spike_depths(spike_clusters, unit_template_ids, pc_features, pc_feature_
390393 pc_features_copy = np .squeeze (pc_features_copy [:,0 ,:])
391394 pc_features_copy [pc_features_copy < 0 ] = 0
392395 pc_power = pow (pc_features_copy , 2 )
393-
396+
394397 spike_feat_ind = pc_feature_ind [unit_template_ids [spike_clusters ], :]
395398 spike_feat_ycoord = channel_pos [spike_feat_ind , 1 ]
396399 spike_depths = np .sum (spike_feat_ycoord * pc_power , 1 ) / np .sum (pc_power ,1 )
@@ -410,7 +413,7 @@ def get_spike_amplitudes(spike_templates, templates, amplitudes):
410413 -------
411414 spike_templates : numpy.ndarray (N x 0)
412415 Template IDs for N spikes
413- templates : numpy.ndarray (M x samples x channels)
416+ templates : numpy.ndarray (M x samples x channels)
414417 Unwhitened templates for M units
415418 amplitudes : numpy.ndarray (N x 0)
416419 Amplitudes for N spikes
@@ -465,7 +468,7 @@ def get_repo_commit_date_and_hash(repo_location):
465468
466469
467470def printProgressBar (iteration , total , prefix = '' , suffix = '' , decimals = 0 , length = 40 , fill = '▒' ):
468-
471+
469472 """
470473 Call in a loop to create terminal progress bar
471474
@@ -491,18 +494,18 @@ def printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 0, l
491494 Outputs:
492495 --------
493496 None
494-
497+
495498 """
496-
499+
497500 percent = ("{0:." + str (decimals ) + "f}" ).format (100 * (iteration / float (total )))
498501 filledLength = int (length * iteration // total )
499502 bar = fill * filledLength + '░' * (length - filledLength )
500503 sys .stdout .write ('\r %s %s %s%% %s' % (prefix , bar , percent , suffix ))
501504 sys .stdout .flush ()
502505
503- if iteration == total :
506+ if iteration == total :
504507 print ()
505-
508+
506509def catGT_ex_params_from_str (ex_str ):
507510 # starting from the comma delimeted CatGT string, return extraction
508511 # parameters.
@@ -511,13 +514,13 @@ def catGT_ex_params_from_str(ex_str):
511514 # <run name>_g<gate index>_tcat.nidq.<ex_name_str>.txt
512515 # for imec SY channels, the file of of extracted edges will be named:
513516 # <run name>_g<gate index>_tcat.imec<probe index>.txt
514-
515- # CatGT does not allow any spaces wihtin options, but there can be
517+
518+ # CatGT does not allow any spaces wihtin options, but there can be
516519 # spaces between options in the command string, and these are
517- # appended to the comma delimited string parsed here.
520+ # appended to the comma delimited string parsed here.
518521 # Remove spaces before parsing
519522 ex_str = ex_str .replace (' ' ,'' ) #replace any spare spaces with commas
520-
523+
521524 eq_pos = ex_str .find ('=' )
522525 ex_type = ex_str [0 :eq_pos ] # stream type (SY, iSY, XD, iXD, i)
523526 ex_parts = ex_str [eq_pos + 1 :].split (',' )
@@ -567,19 +570,19 @@ def getSortResults(output_dir, clu_version):
567570 templates = np .load (os .path .join (output_dir , 'templates.npy' ))
568571 channel_map = np .load (os .path .join (output_dir , 'channel_map.npy' ))
569572 channel_map = np .squeeze (channel_map )
570-
573+
571574 # read in inverse of whitening matrix
572575 w_inv = np .load ((os .path .join (output_dir , 'whitening_mat_inv.npy' )))
573576 nTemplate = templates .shape [0 ]
574-
577+
575578 # initialize peak_channels array
576579 peak_channels = np .zeros ([nLabel ,],'uint32' )
577-
578-
580+
581+
579582 # After manual splits or merges, some labels will have spikes found with
580583 # different templats.
581584 # for each label in the list unqLabel, get the most common template
582- # For that template (nt x nchan), multiply the the transpose (nchan x nt) by inverse of
585+ # For that template (nt x nchan), multiply the the transpose (nchan x nt) by inverse of
583586 # the whitening matrix (nchan x nchan); get max and min along tthe time axis (1)
584587 # to find the peak channel
585588 for i in np .arange (0 ,nLabel ):
@@ -599,18 +602,18 @@ def getSortResults(output_dir, clu_version):
599602 else :
600603 clu_Name = 'clus_Table_' + repr (clu_version ) + '.npy'
601604 np .save (os .path .join (output_dir , clu_Name ), clus_Table )
602-
605+
603606 return nTemplate , nTot
604607
605608def getFileVersion (input_filePath ):
606-
609+
607610 # arting from the base path name givin in the parameters
608611 # also return name for next file in series = next_file
609612 # If no file exists yet, return curr_file = 'none', new_file = input
610-
613+
611614 next_version = 0 ;
612615 next_file = input_filePath
613-
616+
614617 if os .path .exists (next_file ):
615618 # loop over up to 20 versions with an added _1, _2 ...etc
616619 outPath = pathlib .Path (input_filePath ).parent
@@ -620,9 +623,9 @@ def getFileVersion(input_filePath):
620623 nextName = outName + '_' + repr (version_idx ) + outExt
621624 next_file = os .path .join (outPath , nextName )
622625 if os .path .exists (next_file ) is False :
623- #break out of loop
626+ #break out of loop
624627 next_version = version_idx
625628 break
626-
627629
628- return next_file , next_version
630+
631+ return next_file , next_version
0 commit comments