@@ -58,8 +58,10 @@ def format_axis(ax):
5858fsets = ['vel' , 'acc' , 'labels' , 'ee' , 'eepolar' ,]
5959#decoding_fsets = []
6060#decoding_fsets = ['ee', 'eepolar', 'vel', 'acc', 'labels']
61- #decoding_fsets = ['ee', 'eepolar', 'vel', 'acc']
62- decoding_fsets = ['ee' ]
61+ #decoding_fsets = ['eepolar', 'vel', 'acc', 'labels']
62+ #decoding_fsets = ['eepolar', 'vel', 'acc']
63+ decoding_fsets = ['ee' , 'eepolar' , 'vel' , 'acc' ]
64+ #decoding_fsets = ['ee']
6365#decoding_fsets = ['labels']
6466orientations = ['hor' , 'vert' ]
6567uniquezs = list (np .array ([- 45. , - 42. , - 39. , - 36. , - 33. , - 30. , - 27. , - 24. , - 21. , - 18. , - 15. ,
@@ -242,13 +244,14 @@ def regressiontaskfolder(self, model, analysis = None):
242244
243245runinfo = RunInfo ({'expid' : 402 , #internal experiment id
244246 #'datafraction': 0.05,
245- 'datafraction' : 0.1 ,
247+ 'datafraction' : 'auto' ,
246248 #'datafraction': 0.5, #fraction (0,1] or 'auto'
247249 'randomseed' : 2000 ,
248250 'randomseed_traintest' : 42 ,
249251 'dirr2threshold' : 0.2 ,
250252 'verbose' : 2 , #0 (least), 1, 2 (most)
251- 'model_experiment_id' : 22 , #as per Pranav's model generation, int or 'auto'
253+ #'model_experiment_id': 22, #as per Pranav's model generation, int or 'auto'
254+ 'model_experiment_id' : 'auto' ,
252255 'basefolder' : basefolder ,
253256 'batchsize' : 100 , #for layer representation generation
254257 'default_run' : True , #only variable that is 'trial'-dependent,
@@ -269,6 +272,9 @@ def regressiontaskfolder(self, model, analysis = None):
269272 318 : {'datafraction' : 0.5 , 'model_experiment_id' : 22 }, #ST Decoding
270273 319 : {'datafraction' : 0.1 , 'model_experiment_id' : 32 }, #LSTM Decoding
271274 320 : {'datafraction' : 0.05 , 'model_experiment_id' : 32 }, #LSTM CKA only (no tuning curves)
275+ 402 : {'datafraction' : 0.1 , 'model_experiment_id' : 22 }, # S
276+ 403 : {'datafraction' : 0.1 , 'model_experiment_id' : 22 }, # ST
277+ 404 : {'datafraction' : 0.1 , 'model_experiment_id' : 32 }, # ST
272278}
273279
274280# %% SAVE OUTPUTS AND RUN ANALYSIS
@@ -370,7 +376,7 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
370376 print ("beginning body..." )
371377
372378 startmodel = 0
373- startrun = 1
379+ startrun = 1 ### 0 for full run
374380 startcontrol = False
375381 startior = 0
376382 startheight = 'all'
@@ -380,6 +386,7 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
380386 #endheight = 'after_all' #['after_all', None]
381387 endheight = None
382388
389+ ## Utility functions for late start
383390 runmodels = False
384391 runruns = False
385392 runtype = False
@@ -482,7 +489,7 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
482489
483490 #if(not os.path.exists(runinfo.resultsfolder(model_to_analyse, fset))):
484491 #if(default_run):
485- if (False ):
492+ if (True ):
486493 #if(False):
487494
488495 print ('running %s analysis (fitting tuning curves) for model %s plane %s...' % (fset , modelname , runinfo .planestring ()))
@@ -540,7 +547,8 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
540547
541548 #if(not os.path.exists(runinfo.analysisfolder(model_to_analyse, 'polar_tcs'))):
542549 #if(True):
543- if (default_run ):
550+ #if(default_run):
551+ if (runinfo .planestring () == 'horall' ):
544552 polar_tcs_main (model_to_analyse , runinfo_to_analyse )
545553 else :
546554 print ('polar tc plots already exist' )
@@ -558,7 +566,8 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
558566 print ('generating preferred direction histograms for model %s plane %s...' % (modelname , runinfo .planestring ()))
559567 #if(not os.path.exists(runinfo.analysisfolder(model_to_analyse, 'prefdir'))):
560568 #if(True):
561- if (default_run ):
569+ if (runinfo .planestring () == 'horall' ):
570+
562571 prefdir_main (model_to_analyse , runinfo_to_analyse )
563572 else :
564573 print ('pref dir plots already exist' )
@@ -586,9 +595,9 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
586595
587596 if (control ):
588597 #if(not os.path.exists(runinfo.analysisfolder(trainedmodel, 'comp_violin'))):
589- if (default_run ):
598+ # if(default_run):
590599 #if(True):
591- # if(runinfo.planestring() == 'horall'):
600+ if (runinfo .planestring () == 'horall' ):
592601 print ('saving comparison violin plot for model %s plane %s...' % (modelname , runinfo .planestring ()))
593602 comp_violin_main (trainedmodel , model_to_analyse , runinfo )
594603 else :
@@ -625,9 +634,9 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
625634 print ("doing trreg cka for model %s plane %s ... " % (modelname , runinfo .planestring ()))
626635 rsa_main (model_to_analyse , regressionmodel , runinfo , trreg = True )
627636
628- if (True ):
629- print ("making new violin plots" )
630- comp_tr_reg_violin_main_newplots (trainedmodel , regressionmodel , runinfo )
637+ if (runinfo . planestring () == 'horall' ):
638+ print ("making new violin plots" )
639+ comp_tr_reg_violin_main_newplots (trainedmodel , regressionmodel , runinfo )
631640
632641 if (i == 5 ):
633642 if (control ):
@@ -637,12 +646,14 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
637646 else :
638647 print ('skipping comparisons' )
639648
640- # if(runinfo.planestring() == 'horall'):
641- if (True ):
649+ if (runinfo .planestring () == 'horall' ):
650+ # if(True):
642651 print ('combining rsa results for all models' )
643652 #if(not os.path.exists(runinfo.sharedanalysisfolder(trainedmodel, 'rsa'))):
644653 if (True ):
645654 #if(default_run):
655+ #if(runinfo.planestring() == 'horall'):
656+
646657 rsa_models_comp (model , runinfo )
647658 else :
648659 print ('rsa models comp already completed' )
@@ -651,12 +662,16 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
651662 if (True ):
652663 #if(False):
653664 print ('starting comparisons_tr_reg_main' )
654- comparisons_tr_reg_main (model , regressionmodel , runinfo )
665+ #comparisons_tr_reg_main(model, regressionmodel, runinfo)
666+ comparisons_tr_reg_main (trainedmodel , regressionmodel , runinfo )
655667
668+ '''
656669 #if(False):
657670 if(True):
658671 print("combining trreg RSA results for all models")
659- rsa_models_comp (model , runinfo , trreg = True )
672+ #rsa_models_comp(model, runinfo, trreg=True)
673+ comparisons_tr_reg_main(trainedmodel, regressionmodel, runinfo)
674+ '''
660675
661676 else :
662677 runheight = True
@@ -768,4 +783,6 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
768783
769784 print ("Working on the following tasks: " , tasks )
770785
786+ print (args .data , args .results , args .analysis )
787+
771788 main (args .data , args .results , args .analysis , args .regression_task , include , tasks = tasks , expid = args .expid )
0 commit comments