Skip to content
This repository was archived by the owner on Jan 21, 2026. It is now read-only.

Commit a03fb2c

Browse files
authored
Merge pull request #4 from ttngu207/master
shift the configuration of packages paths to `createInputJson`
2 parents dc40b11 + 7bae350 commit a03fb2c

2 files changed

Lines changed: 58 additions & 75 deletions

File tree

ecephys_spike_sorting/scripts/create_input_json.py

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os, io, json, sys
2+
from dotenv import load_dotenv
23

34
if sys.platform == 'linux':
45
import pwd
@@ -26,8 +27,8 @@ def createInputJson(
2627
catGTPath=None,
2728
tPrime_path=None,
2829
cWaves_path=None,
29-
kilosort_output_tmp=None,
30-
npx_directory=None,
30+
kilosort_output_tmp=None,
31+
npx_directory=None,
3132
continuous_file = None,
3233
spikeGLX_data=True,
3334
input_meta_path=None,
@@ -54,12 +55,12 @@ def createInputJson(
5455
tPrime_3A = False,
5556
toStream_path_3A = None,
5657
fromStream_list_3A = None,
57-
ks_remDup = 0,
58+
ks_remDup = 0,
5859
ks_finalSplits = 1,
5960
ks_labelGood = 1,
6061
ks_saveRez = 1,
6162
ks_copy_fproc = 0,
62-
ks_minfr_goodchannels = 0.1,
63+
ks_minfr_goodchannels = 0.1,
6364
ks_whiteningRadius_um = 163,
6465
ks_Th = '[10,4]',
6566
ks_CSBseed = 1,
@@ -81,22 +82,35 @@ def createInputJson(
8182

8283
# KS 3.0 does not yet output pcs.
8384
include_pcs = KS2ver != '3.0'
84-
85+
8586
#npy_matlab_repository = r'C:\Users\labadmin\Documents\jic\npy-matlab-master'
8687
#catGTPath = r'C:\Users\labadmin\Documents\jic\CatGT-win'
8788
#tPrime_path=r'C:\Users\labadmin\Documents\jic\TPrime-win'
8889
#cWaves_path=r'C:\Users\labadmin\Documents\jic\C_Waves-win'
8990
# these are passed through arguements now
9091

91-
92+
9293
# for config files and kilosort working space
93-
# kilosort_output_tmp = r'D:\kilosort_datatemp'
94+
# kilosort_output_tmp = r'D:\kilosort_datatemp'
9495
# this is passed through arguements now
95-
96+
97+
dot_env_path = "config/sglx_process_probe.json"
98+
if os.path.exists(dot_env_path):
99+
load_dotenv(dot_env_path)
100+
101+
ecephys_directory = ecephys_directory or os.getenv('ecephys_directory')
102+
kilosort_repository = kilosort_repository or os.getenv('kilosort_repository')
103+
KS2ver = KS2ver or os.getenv('KS2ver')
104+
npy_matlab_repository = npy_matlab_repository or os.getenv('npy_matlab_repository')
105+
catGTPath = catGTPath or os.getenv('catGTPath')
106+
tPrime_path = tPrime_path or os.getenv('tPrime_path')
107+
cWaves_path = cWaves_path or os.getenv('cWaves_path')
108+
kilosort_output_tmp = kilosort_output_tmp or os.getenv('kilosort_output_tmp')
109+
96110
# derived directory names
97-
111+
98112
modules_directory = os.path.join(ecephys_directory,'modules')
99-
113+
100114
if kilosort_output_directory is None \
101115
and extracted_data_directory is None \
102116
and npx_directory is None:
@@ -105,49 +119,49 @@ def createInputJson(
105119

106120
#default ephys params. For spikeGLX, these get replaced by values read from metadata
107121
sample_rate = 30000
108-
num_channels = 385
122+
num_channels = 385
109123
reference_channels = [191]
110124
uVPerBit = 2.34375
111125
acq_system = 'PXI'
112-
113-
126+
127+
114128
if spikeGLX_data:
115129
# location of the raw data is the continuous file passed from script
116130
# metadata file should be located in same directory
117-
#
131+
#
118132
# kilosort output will be put in the same directory as the input raw data,
119133
# set in kilosort_output_directory passed from script
120134
# kilososrt postprocessing (duplicate removal) and identification of noise
121135
# clusters will act on phy output in the kilosort output directory
122136
#
123-
#
137+
#
124138
if input_meta_path is not None:
125-
probe_type, sample_rate, num_channels, uVPerBit = SpikeGLX_utils.EphysParams(input_meta_path)
139+
probe_type, sample_rate, num_channels, uVPerBit = SpikeGLX_utils.EphysParams(input_meta_path)
126140
print('SpikeGLX params read from meta')
127141
print('probe type: {:s}, sample_rate: {:.5f}, num_channels: {:d}, uVPerBit: {:.4f}'.format\
128142
(probe_type, sample_rate, num_channels, uVPerBit))
129143
#print('kilosort output directory: ' + kilosort_output_directory )
130144

131-
145+
132146
else:
133147
print('currently only supporting spikeGLX data')
134-
135148

136-
149+
150+
137151

138152
# geometry params by probe type. expand the dictoionaries to add types
139153
# vertical probe pitch vs probe type
140-
vpitch = {'3A': 20, 'NP1': 20, 'NP21': 15, 'NP24': 15, 'NP1100': 6}
141-
hpitch = {'3A': 32, 'NP1': 32, 'NP21': 32, 'NP24': 32, 'NP1100': 6}
142-
nColumn = {'3A': 2, 'NP1': 2, 'NP21': 2, 'NP24': 2, 'NP1100': 8}
143-
144-
154+
vpitch = {'3A': 20, 'NP1': 20, 'NP21': 15, 'NP24': 15, 'NP1100': 6}
155+
hpitch = {'3A': 32, 'NP1': 32, 'NP21': 32, 'NP24': 32, 'NP1100': 6}
156+
nColumn = {'3A': 2, 'NP1': 2, 'NP21': 2, 'NP24': 2, 'NP1100': 8}
157+
158+
145159
# CatGT needs the inner and outer redii for local common average referencing
146160
# specified in sites
147161
catGT_loccar_min_sites = int(round(catGT_loccar_min_um/vpitch.get(probe_type)))
148162
catGT_loccar_max_sites = int(round(catGT_loccar_max_um/vpitch.get(probe_type)))
149163
# print('loccar min: ' + repr(catGT_loccar_min_sites))
150-
164+
151165
# whiteningRange is the number of sites used for whitening in KIlosort
152166
# preprocessing. Calculate the number of sites within the user-specified
153167
# whitening radius for this probe geometery
@@ -156,24 +170,24 @@ def createInputJson(
156170
ks_whiteningRange = int(round(2*nrows*nColumn.get(probe_type)))
157171
if ks_whiteningRange > 384:
158172
ks_whiteningRange = 384
159-
173+
160174
# nNeighbors is the number of sites kilosort includes in a template.
161175
# Calculate the number of sites within that radisu.
162176
nrows = np.sqrt((np.square(ks_templateRadius_um) - np.square(hpitch.get(probe_type))))/vpitch.get(probe_type)
163177
ks_nNeighbors = int(round(2*nrows*nColumn.get(probe_type)))
164178
if ks_nNeighbors > 64:
165179
ks_nNeighbors = 64 #max allowed in CUDA
166180
# print('ks_nNeighbors: ' + repr(ks_nNeighbors))
167-
181+
168182
c_waves_radius_sites = int(round(c_Waves_snr_um/vpitch.get(probe_type)))
169183

170184
# Create string designating temporary output file for KS2 (gets inserted into KS2 config.m file)
171185
fproc = os.path.join(kilosort_output_tmp,'temp_wh.dat') # full path for temp whitened data file
172186
fproc_forward_slash = fproc.replace('\\','/')
173187
fproc_str = "'" + fproc_forward_slash + "'"
174-
175188

176-
189+
190+
177191
dictionary = \
178192
{
179193

@@ -193,7 +207,7 @@ def createInputJson(
193207
"waveform_metrics" : {
194208
"waveform_metrics_file" : os.path.join(kilosort_output_directory, 'waveform_metrics.csv')
195209
},
196-
210+
197211
"cluster_metrics" : {
198212
"cluster_metrics_file" : os.path.join(kilosort_output_directory, 'metrics.csv')
199213
},
@@ -210,7 +224,7 @@ def createInputJson(
210224
"lfp_band_file" : os.path.join(extracted_data_directory, 'continuous', 'Neuropix-' + acq_system + '-100.1', 'continuous.dat'),
211225
"reorder_lfp_channels" : True,
212226
"cluster_group_file_name" : 'cluster_group.tsv'
213-
},
227+
},
214228

215229
"extract_from_npx_params" : {
216230
"npx_directory": npx_directory,
@@ -235,7 +249,7 @@ def createInputJson(
235249
"time_interval" : 5,
236250
"skip_s_per_pass" : 10,
237251
"start_time" : 10
238-
},
252+
},
239253

240254
"median_subtraction_params" : {
241255
"median_subtraction_executable": "C:\\Users\\svc_neuropix\\Documents\\GitHub\\spikebandmediansubtraction\\Builds\\VisualStudio2013\\Release\\SpikeBandMedianSubtraction.exe",
@@ -289,19 +303,19 @@ def createInputJson(
289303
},
290304

291305
"mean_waveform_params" : {
292-
306+
293307
"mean_waveforms_file" : os.path.join(kilosort_output_directory, 'mean_waveforms.npy'),
294308
"samples_per_spike" : 82,
295309
"pre_samples" : 20,
296310
"num_epochs" : 1, #epochs not implemented for c_waves
297311
"spikes_per_epoch" : 1000,
298312
"spread_threshold" : 0.12,
299-
"site_range" : 16,
313+
"site_range" : 16,
300314
"cWaves_path" : cWaves_path,
301315
"use_C_Waves" : True,
302-
"snr_radius" : c_waves_radius_sites
316+
"snr_radius" : c_waves_radius_sites
303317
},
304-
318+
305319

306320
"noise_waveform_params" : {
307321
"classifier_path" : os.path.join(modules_directory, 'noise_templates', 'rf_classifier.pkl'),
@@ -321,7 +335,7 @@ def createInputJson(
321335
"drift_metrics_min_spikes_per_interval" : 10,
322336
"include_pcs" : include_pcs
323337
},
324-
338+
325339
"catGT_helper_params" : {
326340
"run_name" : catGT_run_name,
327341
"gate_string" : gate_string,
@@ -345,12 +359,12 @@ def createInputJson(
345359
"tPrime_3A" : tPrime_3A,
346360
"toStream_path_3A" : toStream_path_3A,
347361
"fromStream_list_3A" : fromStream_list_3A
348-
},
349-
362+
},
363+
350364
"psth_events": {
351365
"event_ex_param_str": event_ex_param_str
352366
}
353-
367+
354368
}
355369

356370
with io.open(output_file, 'w', encoding='utf-8') as f:

ecephys_spike_sorting/scripts/sglx_process_probe.py

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import sys
33
import subprocess
44
import json
5-
from dotenv import load_dotenv
65

76
sys.path.append(os.path.dirname(__file__))
87
from helpers import SpikeGLX_utils
@@ -24,14 +23,7 @@ def run_probe(prb, json_directory, npx_directory,
2423
catGT_loccar_min_um, catGT_loccar_max_um,
2524
catGT_cmd_string,
2625
ks_Th, refPerMS,
27-
ecephys_directory=None,
28-
kilosort_repository=None,
2926
KS2ver=None,
30-
npy_matlab_repository=None,
31-
catGTPath=None,
32-
tPrime_path=None,
33-
cWaves_path=None,
34-
kilosort_output_tmp=None,
3527
ks_remDup=0,
3628
ks_saveRez=1,
3729
ks_copy_fproc=0,
@@ -41,18 +33,6 @@ def run_probe(prb, json_directory, npx_directory,
4133
c_Waves_snr_um=160,
4234
ni_present=True,
4335
ni_extract_string=None):
44-
# load external tool path from .env if not given from .json
45-
dot_env_path = "config/sglx_process_probe.json"
46-
if os.path.exists(dot_env_path):
47-
load_dotenv(dot_env_path)
48-
ecephys_directory=ecephys_directory or os.getenv('ecephys_directory')
49-
kilosort_repository=kilosort_repository or os.getenv('kilosort_repository')
50-
KS2ver=KS2ver or os.getenv('KS2ver')
51-
npy_matlab_repository=npy_matlab_repository or os.getenv('npy_matlab_repository')
52-
catGTPath=catGTPath or os.getenv('catGTPath')
53-
tPrime_path=tPrime_path or os.getenv('tPrime_path')
54-
cWaves_path=cWaves_path or os.getenv('cWaves_path')
55-
kilosort_output_tmp=kilosort_output_tmp or os.getenv('kilosort_output_tmp')
5636

5737
# build path to the first probe folder; look into that folder
5838
# to determine the range of trials if the user specified t limits as
@@ -92,14 +72,7 @@ def run_probe(prb, json_directory, npx_directory,
9272
input_meta_fullpath = os.path.join(input_data_directory, metaName)
9373
print(input_meta_fullpath)
9474
createInputJson(catGT_input_json,
95-
ecephys_directory=ecephys_directory,
96-
kilosort_repository=kilosort_repository,
9775
KS2ver=KS2ver,
98-
npy_matlab_repository=npy_matlab_repository,
99-
catGTPath=catGTPath,
100-
tPrime_path=tPrime_path,
101-
cWaves_path=cWaves_path,
102-
kilosort_output_tmp=kilosort_output_tmp,
10376
npx_directory=npx_directory,
10477
continuous_file=continuous_file,
10578
kilosort_output_directory=catGT_dest,
@@ -136,14 +109,7 @@ def run_probe(prb, json_directory, npx_directory,
136109
print(continuous_file)
137110
print('ks_Th: ' + repr(ks_Th) + ' ,refPerMS: ' + repr(refPerMS))
138111
createInputJson(module_input_json,
139-
ecephys_directory=ecephys_directory,
140-
kilosort_repository=kilosort_repository,
141112
KS2ver=KS2ver,
142-
npy_matlab_repository=npy_matlab_repository,
143-
catGTPath=catGTPath,
144-
tPrime_path=tPrime_path,
145-
cWaves_path=cWaves_path,
146-
kilosort_output_tmp=kilosort_output_tmp,
147113
npx_directory=npx_directory,
148114
continuous_file=continuous_file,
149115
spikeGLX_data=True,
@@ -188,11 +154,14 @@ def run_probe(prb, json_directory, npx_directory,
188154
logFullPath)
189155

190156

191-
#if __name__ == '__main__':
192157
def main():
193158
json_fp = sys.argv[1]
194159

195160
with open(json_fp) as f:
196161
kwargs = json.load(f)
197162

198163
run_probe(**kwargs)
164+
165+
166+
if __name__ == '__main__':
167+
main()

0 commit comments

Comments
 (0)