11import os , io , json , sys
2+ from dotenv import load_dotenv
23
34if sys .platform == 'linux' :
45 import pwd
@@ -26,8 +27,8 @@ def createInputJson(
2627 catGTPath = None ,
2728 tPrime_path = None ,
2829 cWaves_path = None ,
29- kilosort_output_tmp = None ,
30- npx_directory = None ,
30+ kilosort_output_tmp = None ,
31+ npx_directory = None ,
3132 continuous_file = None ,
3233 spikeGLX_data = True ,
3334 input_meta_path = None ,
@@ -54,12 +55,12 @@ def createInputJson(
5455 tPrime_3A = False ,
5556 toStream_path_3A = None ,
5657 fromStream_list_3A = None ,
57- ks_remDup = 0 ,
58+ ks_remDup = 0 ,
5859 ks_finalSplits = 1 ,
5960 ks_labelGood = 1 ,
6061 ks_saveRez = 1 ,
6162 ks_copy_fproc = 0 ,
62- ks_minfr_goodchannels = 0.1 ,
63+ ks_minfr_goodchannels = 0.1 ,
6364 ks_whiteningRadius_um = 163 ,
6465 ks_Th = '[10,4]' ,
6566 ks_CSBseed = 1 ,
@@ -81,22 +82,35 @@ def createInputJson(
8182
8283 # KS 3.0 does not yet output pcs.
8384 include_pcs = KS2ver != '3.0'
84-
85+
8586 #npy_matlab_repository = r'C:\Users\labadmin\Documents\jic\npy-matlab-master'
8687 #catGTPath = r'C:\Users\labadmin\Documents\jic\CatGT-win'
8788 #tPrime_path=r'C:\Users\labadmin\Documents\jic\TPrime-win'
8889 #cWaves_path=r'C:\Users\labadmin\Documents\jic\C_Waves-win'
8990 # these are passed through arguements now
9091
91-
92+
9293 # for config files and kilosort working space
93- # kilosort_output_tmp = r'D:\kilosort_datatemp'
94+ # kilosort_output_tmp = r'D:\kilosort_datatemp'
9495 # this is passed through arguements now
95-
96+
97+ dot_env_path = "config/sglx_process_probe.json"
98+ if os .path .exists (dot_env_path ):
99+ load_dotenv (dot_env_path )
100+
101+ ecephys_directory = ecephys_directory or os .getenv ('ecephys_directory' )
102+ kilosort_repository = kilosort_repository or os .getenv ('kilosort_repository' )
103+ KS2ver = KS2ver or os .getenv ('KS2ver' )
104+ npy_matlab_repository = npy_matlab_repository or os .getenv ('npy_matlab_repository' )
105+ catGTPath = catGTPath or os .getenv ('catGTPath' )
106+ tPrime_path = tPrime_path or os .getenv ('tPrime_path' )
107+ cWaves_path = cWaves_path or os .getenv ('cWaves_path' )
108+ kilosort_output_tmp = kilosort_output_tmp or os .getenv ('kilosort_output_tmp' )
109+
96110 # derived directory names
97-
111+
98112 modules_directory = os .path .join (ecephys_directory ,'modules' )
99-
113+
100114 if kilosort_output_directory is None \
101115 and extracted_data_directory is None \
102116 and npx_directory is None :
@@ -105,49 +119,49 @@ def createInputJson(
105119
106120 #default ephys params. For spikeGLX, these get replaced by values read from metadata
107121 sample_rate = 30000
108- num_channels = 385
122+ num_channels = 385
109123 reference_channels = [191 ]
110124 uVPerBit = 2.34375
111125 acq_system = 'PXI'
112-
113-
126+
127+
114128 if spikeGLX_data :
115129 # location of the raw data is the continuous file passed from script
116130 # metadata file should be located in same directory
117- #
131+ #
118132 # kilosort output will be put in the same directory as the input raw data,
119133 # set in kilosort_output_directory passed from script
120134 # kilososrt postprocessing (duplicate removal) and identification of noise
121135 # clusters will act on phy output in the kilosort output directory
122136 #
123- #
137+ #
124138 if input_meta_path is not None :
125- probe_type , sample_rate , num_channels , uVPerBit = SpikeGLX_utils .EphysParams (input_meta_path )
139+ probe_type , sample_rate , num_channels , uVPerBit = SpikeGLX_utils .EphysParams (input_meta_path )
126140 print ('SpikeGLX params read from meta' )
127141 print ('probe type: {:s}, sample_rate: {:.5f}, num_channels: {:d}, uVPerBit: {:.4f}' .format \
128142 (probe_type , sample_rate , num_channels , uVPerBit ))
129143 #print('kilosort output directory: ' + kilosort_output_directory )
130144
131-
145+
132146 else :
133147 print ('currently only supporting spikeGLX data' )
134-
135148
136-
149+
150+
137151
138152 # geometry params by probe type. expand the dictoionaries to add types
139153 # vertical probe pitch vs probe type
140- vpitch = {'3A' : 20 , 'NP1' : 20 , 'NP21' : 15 , 'NP24' : 15 , 'NP1100' : 6 }
141- hpitch = {'3A' : 32 , 'NP1' : 32 , 'NP21' : 32 , 'NP24' : 32 , 'NP1100' : 6 }
142- nColumn = {'3A' : 2 , 'NP1' : 2 , 'NP21' : 2 , 'NP24' : 2 , 'NP1100' : 8 }
143-
144-
154+ vpitch = {'3A' : 20 , 'NP1' : 20 , 'NP21' : 15 , 'NP24' : 15 , 'NP1100' : 6 }
155+ hpitch = {'3A' : 32 , 'NP1' : 32 , 'NP21' : 32 , 'NP24' : 32 , 'NP1100' : 6 }
156+ nColumn = {'3A' : 2 , 'NP1' : 2 , 'NP21' : 2 , 'NP24' : 2 , 'NP1100' : 8 }
157+
158+
145159 # CatGT needs the inner and outer redii for local common average referencing
146160 # specified in sites
147161 catGT_loccar_min_sites = int (round (catGT_loccar_min_um / vpitch .get (probe_type )))
148162 catGT_loccar_max_sites = int (round (catGT_loccar_max_um / vpitch .get (probe_type )))
149163 # print('loccar min: ' + repr(catGT_loccar_min_sites))
150-
164+
151165 # whiteningRange is the number of sites used for whitening in KIlosort
152166 # preprocessing. Calculate the number of sites within the user-specified
153167 # whitening radius for this probe geometery
@@ -156,24 +170,24 @@ def createInputJson(
156170 ks_whiteningRange = int (round (2 * nrows * nColumn .get (probe_type )))
157171 if ks_whiteningRange > 384 :
158172 ks_whiteningRange = 384
159-
173+
160174 # nNeighbors is the number of sites kilosort includes in a template.
161175 # Calculate the number of sites within that radisu.
162176 nrows = np .sqrt ((np .square (ks_templateRadius_um ) - np .square (hpitch .get (probe_type ))))/ vpitch .get (probe_type )
163177 ks_nNeighbors = int (round (2 * nrows * nColumn .get (probe_type )))
164178 if ks_nNeighbors > 64 :
165179 ks_nNeighbors = 64 #max allowed in CUDA
166180 # print('ks_nNeighbors: ' + repr(ks_nNeighbors))
167-
181+
168182 c_waves_radius_sites = int (round (c_Waves_snr_um / vpitch .get (probe_type )))
169183
170184 # Create string designating temporary output file for KS2 (gets inserted into KS2 config.m file)
171185 fproc = os .path .join (kilosort_output_tmp ,'temp_wh.dat' ) # full path for temp whitened data file
172186 fproc_forward_slash = fproc .replace ('\\ ' ,'/' )
173187 fproc_str = "'" + fproc_forward_slash + "'"
174-
175188
176-
189+
190+
177191 dictionary = \
178192 {
179193
@@ -193,7 +207,7 @@ def createInputJson(
193207 "waveform_metrics" : {
194208 "waveform_metrics_file" : os .path .join (kilosort_output_directory , 'waveform_metrics.csv' )
195209 },
196-
210+
197211 "cluster_metrics" : {
198212 "cluster_metrics_file" : os .path .join (kilosort_output_directory , 'metrics.csv' )
199213 },
@@ -210,7 +224,7 @@ def createInputJson(
210224 "lfp_band_file" : os .path .join (extracted_data_directory , 'continuous' , 'Neuropix-' + acq_system + '-100.1' , 'continuous.dat' ),
211225 "reorder_lfp_channels" : True ,
212226 "cluster_group_file_name" : 'cluster_group.tsv'
213- },
227+ },
214228
215229 "extract_from_npx_params" : {
216230 "npx_directory" : npx_directory ,
@@ -235,7 +249,7 @@ def createInputJson(
235249 "time_interval" : 5 ,
236250 "skip_s_per_pass" : 10 ,
237251 "start_time" : 10
238- },
252+ },
239253
240254 "median_subtraction_params" : {
241255 "median_subtraction_executable" : "C:\\ Users\\ svc_neuropix\\ Documents\\ GitHub\\ spikebandmediansubtraction\\ Builds\\ VisualStudio2013\\ Release\\ SpikeBandMedianSubtraction.exe" ,
@@ -289,19 +303,19 @@ def createInputJson(
289303 },
290304
291305 "mean_waveform_params" : {
292-
306+
293307 "mean_waveforms_file" : os .path .join (kilosort_output_directory , 'mean_waveforms.npy' ),
294308 "samples_per_spike" : 82 ,
295309 "pre_samples" : 20 ,
296310 "num_epochs" : 1 , #epochs not implemented for c_waves
297311 "spikes_per_epoch" : 1000 ,
298312 "spread_threshold" : 0.12 ,
299- "site_range" : 16 ,
313+ "site_range" : 16 ,
300314 "cWaves_path" : cWaves_path ,
301315 "use_C_Waves" : True ,
302- "snr_radius" : c_waves_radius_sites
316+ "snr_radius" : c_waves_radius_sites
303317 },
304-
318+
305319
306320 "noise_waveform_params" : {
307321 "classifier_path" : os .path .join (modules_directory , 'noise_templates' , 'rf_classifier.pkl' ),
@@ -321,7 +335,7 @@ def createInputJson(
321335 "drift_metrics_min_spikes_per_interval" : 10 ,
322336 "include_pcs" : include_pcs
323337 },
324-
338+
325339 "catGT_helper_params" : {
326340 "run_name" : catGT_run_name ,
327341 "gate_string" : gate_string ,
@@ -345,12 +359,12 @@ def createInputJson(
345359 "tPrime_3A" : tPrime_3A ,
346360 "toStream_path_3A" : toStream_path_3A ,
347361 "fromStream_list_3A" : fromStream_list_3A
348- },
349-
362+ },
363+
350364 "psth_events" : {
351365 "event_ex_param_str" : event_ex_param_str
352366 }
353-
367+
354368 }
355369
356370 with io .open (output_file , 'w' , encoding = 'utf-8' ) as f :
0 commit comments