Skip to content
This repository was archived by the owner on Jan 21, 2026. It is now read-only.

Commit 154a9ef

Browse files
authored
Merge pull request #6 from ttngu207/master
handling OpenEphys data format - allow for using pregenerated chanMap.mat file
2 parents 5be9ddd + 14e9b03 commit 154a9ef

7 files changed

Lines changed: 321 additions & 283 deletions

File tree

ecephys_spike_sorting/common/utils.py

Lines changed: 66 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
def find_range(x,a,b,option='within'):
14-
14+
1515
"""
1616
Find indices of data within or outside range [a,b]
1717
@@ -53,7 +53,7 @@ def rms(data):
5353
Output:
5454
------
5555
rms_value - float
56-
56+
5757
"""
5858

5959
return np.power(np.mean(np.power(data.astype('float32'),2)),0.5)
@@ -91,21 +91,21 @@ def write_probe_json(output_file, channels, offset, scaling, mask, surface_chann
9191
"""
9292

9393
with open(output_file, 'w') as outfile:
94-
json.dump(
95-
{
96-
'channel' : channels.tolist(),
97-
'offset' : offset.tolist(),
98-
'scaling' : scaling.tolist(),
99-
'mask' : mask.tolist(),
100-
'surface_channel' : surface_channel,
94+
json.dump(
95+
{
96+
'channel' : channels.tolist(),
97+
'offset' : offset.tolist(),
98+
'scaling' : scaling.tolist(),
99+
'mask' : mask.tolist(),
100+
'surface_channel' : surface_channel,
101101
'air_channel' : air_channel,
102102
'vertical_pos' : vertical_pos.tolist(),
103103
'horizontal_pos' : horizontal_pos.tolist()
104104
},
105-
106-
outfile,
107-
indent = 4, separators = (',', ': ')
108-
)
105+
106+
outfile,
107+
indent = 4, separators = (',', ': ')
108+
)
109109

110110
def read_probe_json(input_file):
111111

@@ -131,10 +131,10 @@ def read_probe_json(input_file):
131131
Index of channel at interface between saline/agar and air
132132
133133
"""
134-
134+
135135
with open(input_file) as data_file:
136136
data = json.load(data_file)
137-
137+
138138
scaling = np.array(data['scaling'])
139139
mask = np.array(data['mask'])
140140
offset = np.array(data['offset'])
@@ -163,11 +163,11 @@ def write_cluster_group_tsv(IDs, quality, output_directory, filename = 'cluster_
163163
cluster_group.tsv (written to disk)
164164
165165
"""
166-
166+
167167
df = pd.DataFrame(data={'cluster_id' : IDs, 'group': quality})
168-
168+
169169
print('Saving data...')
170-
170+
171171
df.to_csv(os.path.join(output_directory, filename), sep='\t', index=False)
172172

173173

@@ -197,7 +197,7 @@ def read_cluster_group_tsv(filename):
197197
return cluster_ids, cluster_quality
198198

199199
def read_cluster_amplitude_tsv(filename):
200-
200+
201201
"""
202202
Reads a tab-separated cluster_Amplitude.tsv file from disk
203203
@@ -213,7 +213,7 @@ def read_cluster_amplitude_tsv(filename):
213213
214214
"""
215215
info = np.genfromtxt(filename, dtype='str')
216-
# don't return cluster_ids because those are already read in or
216+
# don't return cluster_ids because those are already read in or
217217
# derived from the spike_clusters.npy file
218218
# cluster_ids = info[1:,0].astype('int')
219219
cluster_amplitude = info[1:,1].astype('float')
@@ -243,10 +243,10 @@ def load(folder, filename):
243243
return np.load(os.path.join(folder, filename))
244244

245245

246-
def load_kilosort_data(folder,
247-
sample_rate = None,
248-
convert_to_seconds = True,
249-
use_master_clock = False,
246+
def load_kilosort_data(folder,
247+
sample_rate = None,
248+
convert_to_seconds = True,
249+
use_master_clock = False,
250250
include_pcs = False,
251251
template_zero_padding= 21):
252252

@@ -278,7 +278,7 @@ def load_kilosort_data(folder,
278278
Template IDs for N spikes
279279
amplitudes : numpy.ndarray (N x 0)
280280
Amplitudes for N spikes
281-
unwhitened_temps : numpy.ndarray (M x samples x channels)
281+
unwhitened_temps : numpy.ndarray (M x samples x channels)
282282
Templates for M units
283283
channel_map : numpy.ndarray
284284
Channels from original data file used for sorting
@@ -303,7 +303,7 @@ def load_kilosort_data(folder,
303303
spike_times = load(folder,'spike_times_master_clock.npy')
304304
else:
305305
spike_times = load(folder,'spike_times.npy')
306-
306+
307307
spike_clusters = load(folder,'spike_clusters.npy')
308308
spike_templates = load(folder, 'spike_templates.npy')
309309
amplitudes = load(folder,'amplitudes.npy')
@@ -312,35 +312,38 @@ def load_kilosort_data(folder,
312312
channel_map = load(folder, 'channel_map.npy')
313313
channel_pos = load(folder, 'channel_positions.npy')
314314

315+
# handles channel_map being read as 2-dimensional
316+
channel_map = np.squeeze(channel_map).astype(int)
317+
315318
if include_pcs:
316319
pc_features = load(folder, 'pc_features.npy')
317320
pc_feature_ind = load(folder, 'pc_feature_ind.npy')
318-
template_features = load(folder, 'template_features.npy')
321+
template_features = load(folder, 'template_features.npy')
322+
319323

320-
321324
templates = templates[:,template_zero_padding:,:] # remove zeros
322325
spike_clusters = np.squeeze(spike_clusters) # fix dimensions
323326
spike_times = np.squeeze(spike_times)# fix dimensions
324327

325328
if convert_to_seconds and sample_rate is not None:
326-
spike_times = spike_times / sample_rate
327-
329+
spike_times = spike_times / sample_rate
330+
328331
unwhitened_temps = np.zeros((templates.shape))
329-
332+
330333
for temp_idx in range(templates.shape[0]):
331-
334+
332335
unwhitened_temps[temp_idx,:,:] = np.dot(np.ascontiguousarray(templates[temp_idx,:,:]),np.ascontiguousarray(unwhitening_mat))
333-
336+
334337
try:
335338
cluster_ids, cluster_quality = read_cluster_group_tsv(os.path.join(folder, 'cluster_group.tsv'))
336339
except OSError:
337340
cluster_ids = np.unique(spike_clusters)
338341
cluster_quality = ['unsorted'] * cluster_ids.size
339-
342+
340343
cluster_amplitude = read_cluster_amplitude_tsv(os.path.join(folder, 'cluster_Amplitude.tsv'))
341-
342-
343-
344+
345+
346+
344347

345348
if not include_pcs:
346349
return spike_times, spike_clusters, spike_templates, amplitudes, unwhitened_temps, \
@@ -357,12 +360,12 @@ def get_spike_depths(spike_clusters, unit_template_ids, pc_features, pc_feature_
357360
Calculates the distance (in microns) of individual spikes from the probe tip
358361
359362
This implementation is based on Matlab code from github.com/cortex-lab/spikes
360-
361-
Needs to be called for a subset of spikes extracted with the majority template
363+
364+
Needs to be called for a subset of spikes extracted with the majority template
362365
This is true for all spikes in data which has not been curated.
363366
Manual merges create clusters that derive from multiple templates, but this
364367
algorithm examines features from a single template -- so we select spikes
365-
for each cluster that were extracted with the majority template before calling
368+
for each cluster that were extracted with the majority template before calling
366369
in metrics.py
367370
368371
Input:
@@ -390,7 +393,7 @@ def get_spike_depths(spike_clusters, unit_template_ids, pc_features, pc_feature_
390393
pc_features_copy = np.squeeze(pc_features_copy[:,0,:])
391394
pc_features_copy[pc_features_copy < 0] = 0
392395
pc_power = pow(pc_features_copy, 2)
393-
396+
394397
spike_feat_ind = pc_feature_ind[unit_template_ids[spike_clusters], :]
395398
spike_feat_ycoord = channel_pos[spike_feat_ind, 1]
396399
spike_depths = np.sum(spike_feat_ycoord * pc_power, 1) / np.sum(pc_power,1)
@@ -410,7 +413,7 @@ def get_spike_amplitudes(spike_templates, templates, amplitudes):
410413
-------
411414
spike_templates : numpy.ndarray (N x 0)
412415
Template IDs for N spikes
413-
templates : numpy.ndarray (M x samples x channels)
416+
templates : numpy.ndarray (M x samples x channels)
414417
Unwhitened templates for M units
415418
amplitudes : numpy.ndarray (N x 0)
416419
Amplitudes for N spikes
@@ -465,7 +468,7 @@ def get_repo_commit_date_and_hash(repo_location):
465468

466469

467470
def printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 0, length = 40, fill = '▒'):
468-
471+
469472
"""
470473
Call in a loop to create terminal progress bar
471474
@@ -491,18 +494,18 @@ def printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 0, l
491494
Outputs:
492495
--------
493496
None
494-
497+
495498
"""
496-
499+
497500
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
498501
filledLength = int(length * iteration // total)
499502
bar = fill * filledLength + '░' * (length - filledLength)
500503
sys.stdout.write('\r%s %s %s%% %s' % (prefix, bar, percent, suffix))
501504
sys.stdout.flush()
502505

503-
if iteration == total:
506+
if iteration == total:
504507
print()
505-
508+
506509
def catGT_ex_params_from_str(ex_str):
507510
# starting from the comma delimited CatGT string, return extraction
508511
# parameters.
@@ -511,13 +514,13 @@ def catGT_ex_params_from_str(ex_str):
511514
# <run name>_g<gate index>_tcat.nidq.<ex_name_str>.txt
512515
# for imec SY channels, the file of extracted edges will be named:
513516
# <run name>_g<gate index>_tcat.imec<probe index>.txt
514-
515-
# CatGT does not allow any spaces within options, but there can be
517+
518+
# CatGT does not allow any spaces within options, but there can be
516519
# spaces between options in the command string, and these are
517-
# appended to the comma delimited string parsed here.
520+
# appended to the comma delimited string parsed here.
518521
# Remove spaces before parsing
519522
ex_str = ex_str.replace(' ','') # remove any stray spaces
520-
523+
521524
eq_pos = ex_str.find('=')
522525
ex_type = ex_str[0:eq_pos] # stream type (SY, iSY, XD, iXD, i)
523526
ex_parts = ex_str[eq_pos+1:].split(',')
@@ -567,19 +570,19 @@ def getSortResults(output_dir, clu_version):
567570
templates = np.load(os.path.join(output_dir, 'templates.npy'))
568571
channel_map = np.load(os.path.join(output_dir, 'channel_map.npy'))
569572
channel_map = np.squeeze(channel_map)
570-
573+
571574
# read in inverse of whitening matrix
572575
w_inv = np.load((os.path.join(output_dir, 'whitening_mat_inv.npy')))
573576
nTemplate = templates.shape[0]
574-
577+
575578
# initialize peak_channels array
576579
peak_channels = np.zeros([nLabel,],'uint32')
577-
578-
580+
581+
579582
# After manual splits or merges, some labels will have spikes found with
580583
# different templates.
581584
# for each label in the list unqLabel, get the most common template
582-
# For that template (nt x nchan), multiply the transpose (nchan x nt) by inverse of
585+
# For that template (nt x nchan), multiply the transpose (nchan x nt) by inverse of
583586
# the whitening matrix (nchan x nchan); get max and min along the time axis (1)
584587
# to find the peak channel
585588
for i in np.arange(0,nLabel):
@@ -599,18 +602,18 @@ def getSortResults(output_dir, clu_version):
599602
else:
600603
clu_Name = 'clus_Table_' + repr(clu_version) + '.npy'
601604
np.save(os.path.join(output_dir, clu_Name), clus_Table)
602-
605+
603606
return nTemplate, nTot
604607

605608
def getFileVersion(input_filePath):
606-
609+
607610
# starting from the base path name given in the parameters
608611
# also return name for next file in series = next_file
609612
# If no file exists yet, return curr_file = 'none', new_file = input
610-
613+
611614
next_version = 0;
612615
next_file = input_filePath
613-
616+
614617
if os.path.exists(next_file):
615618
# loop over up to 20 versions with an added _1, _2 ...etc
616619
outPath = pathlib.Path(input_filePath).parent
@@ -620,9 +623,9 @@ def getFileVersion(input_filePath):
620623
nextName = outName + '_' + repr(version_idx) + outExt
621624
next_file = os.path.join(outPath, nextName)
622625
if os.path.exists(next_file) is False:
623-
#break out of loop
626+
#break out of loop
624627
next_version = version_idx
625628
break
626-
627629

628-
return next_file, next_version
630+
631+
return next_file, next_version

0 commit comments

Comments
 (0)