Skip to content

Commit 67e3dd3

Browse files
committed
updated automatic parameter setettings lookup
1 parent 7850f67 commit 67e3dd3

4 files changed

Lines changed: 116 additions & 78 deletions

File tree

single_cell/combined_violinquantiles_controls.py

Lines changed: 74 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -201,34 +201,40 @@ def plot_compvp(trainedmodevals, controlmodevals, trainedmodel, regcomp = False,
201201
##exclude r2 == 1 scores
202202
#mod = [x[x != 1] for x in mod]
203203
mod = [x[(x != 1) & (x > -0.1)] for x in mod]
204-
205-
vp = ax1.violinplot(mod,
206-
positions=[ilayer*lspace+space*i+1 for ilayer in range(nlayers)],
207-
showextrema = False,
208-
showmedians = False,
209-
showmeans = False,
210-
widths=width)
211-
212-
for part in vp['bodies']:
213-
if lr == 'l':
214-
part.set_facecolor(cmap(cidx[i]))
215-
part.set_edgecolor(cmap(cidx[i]))
216-
else:
217-
if not regcomp:
218-
part.set_facecolor(cmap(cidx[ccolorindex]))
219-
part.set_edgecolor(cmap(cidx[ccolorindex]))
220-
else:
204+
try:
205+
206+
vp = ax1.violinplot(mod,
207+
positions=[ilayer*lspace+space*i+1 for ilayer in range(nlayers)],
208+
showextrema = False,
209+
showmedians = False,
210+
showmeans = False,
211+
widths=width)
212+
213+
for part in vp['bodies']:
214+
if lr == 'l':
221215
part.set_facecolor(cmap(cidx[i]))
222216
part.set_edgecolor(cmap(cidx[i]))
223-
part.set_alpha(alpha[1])
224-
part.set_zorder(zorder)
225-
226-
try:
227-
clip(vp, lr)
228-
except IndexError as e:
229-
print(e)
230-
print("not enough samples for ", lr)
231-
print(vp)
217+
else:
218+
if not regcomp:
219+
part.set_facecolor(cmap(cidx[ccolorindex]))
220+
part.set_edgecolor(cmap(cidx[ccolorindex]))
221+
else:
222+
part.set_facecolor(cmap(cidx[i]))
223+
part.set_edgecolor(cmap(cidx[i]))
224+
part.set_alpha(alpha[1])
225+
part.set_zorder(zorder)
226+
227+
try:
228+
clip(vp, lr)
229+
except IndexError as e:
230+
print(e)
231+
print("not enough samples for ", lr)
232+
print(vp)
233+
234+
except ValueError as e:
235+
print("empty array, can't do violin plot", e)
236+
vp = None
237+
232238
patches.append(mpatches.Patch(color=cmap(cidx[i]), alpha=0.7))
233239

234240
vps.append(vp)
@@ -245,7 +251,10 @@ def plot_compvp(trainedmodevals, controlmodevals, trainedmodel, regcomp = False,
245251
q90s = np.zeros((nlayers,))
246252

247253
for ilayer, layer in enumerate(mod):
248-
q90s[ilayer] = np.quantile(layer, q)
254+
try:
255+
q90s[ilayer] = np.quantile(layer, q)
256+
except e:
257+
print(e)
249258

250259
ax1.plot([ilayer*lspace+space*i+1 for ilayer in range(nlayers)],
251260
q90s, color=cmap(cidx[i]), marker=marker, alpha=alpha[0])
@@ -368,8 +377,10 @@ def plot_compvp_v3(trainedmodevals, controlmodevals, trainedmodel, regcomp = Fal
368377
q90s = np.zeros((nlayers,))
369378

370379
for ilayer, layer in enumerate(mod):
371-
q90s[ilayer] = np.quantile(layer, q)
372-
380+
try:
381+
q90s[ilayer] = np.quantile(layer, q)
382+
except e:
383+
print(e)
373384
ax1.plot([ilayer*lspace+space*i+1 for ilayer in range(nlayers)],
374385
q90s, color=cmap(cidx[i]), marker=marker, alpha=alpha[0])
375386

@@ -489,33 +500,39 @@ def plot_compvp_ee(trainedmodevals, controlmodevals, trainedmodel, regcomp = Fal
489500
##exclude r2 == 1 scores
490501
mod = [x[x != 1] for x in mod]
491502

492-
vp = ax1.violinplot(mod,
493-
positions=[ilayer*lspace+space*i+1 for ilayer in range(nlayers)],
494-
showextrema = False,
495-
showmedians = False,
496-
showmeans = False,
497-
widths=width)
498-
499-
for part in vp['bodies']:
500-
if lr == 'l':
501-
part.set_facecolor(cmap(cidx[i]))
502-
part.set_edgecolor(cmap(cidx[i]))
503-
else:
504-
if not regcomp:
505-
part.set_facecolor(cmap(cidx[ccolorindex]))
506-
part.set_edgecolor(cmap(cidx[ccolorindex]))
507-
else:
503+
try:
504+
vp = ax1.violinplot(mod,
505+
positions=[ilayer*lspace+space*i+1 for ilayer in range(nlayers)],
506+
showextrema = False,
507+
showmedians = False,
508+
showmeans = False,
509+
widths=width)
510+
511+
for part in vp['bodies']:
512+
if lr == 'l':
508513
part.set_facecolor(cmap(cidx[i]))
509514
part.set_edgecolor(cmap(cidx[i]))
510-
part.set_alpha(alpha[1])
511-
part.set_zorder(zorder)
515+
else:
516+
if not regcomp:
517+
part.set_facecolor(cmap(cidx[ccolorindex]))
518+
part.set_edgecolor(cmap(cidx[ccolorindex]))
519+
else:
520+
part.set_facecolor(cmap(cidx[i]))
521+
part.set_edgecolor(cmap(cidx[i]))
522+
part.set_alpha(alpha[1])
523+
part.set_zorder(zorder)
524+
525+
try:
526+
clip(vp, lr)
527+
except IndexError as e:
528+
print(e)
529+
print("not enough samples for ", lr)
530+
print(vp)
531+
532+
except ValueError as e:
533+
print("empty array, can't do violin plot", e)
534+
vp = None
512535

513-
try:
514-
clip(vp, lr)
515-
except IndexError as e:
516-
print(e)
517-
print("not enough samples for ", lr)
518-
print(vp)
519536

520537
patches.append(mpatches.Patch(color=cmap(cidx[i]), alpha=0.7))
521538

@@ -533,7 +550,10 @@ def plot_compvp_ee(trainedmodevals, controlmodevals, trainedmodel, regcomp = Fal
533550
q90s = np.zeros((nlayers,))
534551

535552
for ilayer, layer in enumerate(mod):
536-
q90s[ilayer] = np.quantile(layer, q)
553+
try:
554+
q90s[ilayer] = np.quantile(layer, q)
555+
except e:
556+
print(e)
537557

538558
ax1.plot([ilayer*lspace+space*i+1 for ilayer in range(nlayers)],
539559
q90s, color=cmap(cidx[i]), marker=marker, alpha=alpha[0])

single_cell/control_comparisons.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@
6565
compors = ['hors vs. hors', 'verts vs. verts', 'hor vs. verts', 'vert vs. hors']
6666

6767
#alphas = [0, 0.001, 0.01, 0.1, 1.0, 5.0]
68-
alphas = [0, 0.001, 0.01, 0.1, 1.0, 5.0, 10, 100, 1000, 10000, 100000, 1000000]
69-
#alphas = [0]
68+
#alphas = [0, 0.001, 0.01, 0.1, 1.0, 5.0, 10, 100, 1000, 10000, 100000, 1000000]
69+
alphas = [1]
7070

7171
def format_axis(ax):
7272
ax.spines['top'].set_visible(False)
@@ -2599,8 +2599,8 @@ def comparisons_tr_reg_main(taskmodel, regressionmodel, runinfo, alpha=None):
25992599
print('kinetic and label embeddings already analyzed')
26002600

26012601
#if(runinfo.default_run):
2602-
#if(runinfo['height'] == 'all'):
2603-
if(False):
2602+
if(runinfo['height'] == 'all'):
2603+
#if(False):
26042604
for alpha in alphas:
26052605
print('compiling dataframe for decoding comparions trained & reg...')
26062606
decoding_df = compile_decoding_comparisons_tr_reg_df(taskmodel, regressionmodel, runinfo, alpha)

single_cell/controls_main.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ def format_axis(ax):
5858
fsets = ['vel', 'acc', 'labels', 'ee', 'eepolar',]
5959
#decoding_fsets = []
6060
#decoding_fsets = ['ee', 'eepolar', 'vel', 'acc', 'labels']
61-
#decoding_fsets = ['ee', 'eepolar', 'vel', 'acc']
62-
decoding_fsets = ['ee']
61+
#decoding_fsets = ['eepolar', 'vel', 'acc', 'labels']
62+
#decoding_fsets = ['eepolar', 'vel', 'acc']
63+
decoding_fsets = ['ee', 'eepolar', 'vel', 'acc']
64+
#decoding_fsets = ['ee']
6365
#decoding_fsets = ['labels']
6466
orientations = ['hor', 'vert']
6567
uniquezs = list(np.array([-45., -42., -39., -36., -33., -30., -27., -24., -21., -18., -15.,
@@ -242,13 +244,14 @@ def regressiontaskfolder(self, model, analysis = None):
242244

243245
runinfo = RunInfo({'expid': 402, #internal experiment id
244246
#'datafraction': 0.05,
245-
'datafraction': 0.1,
247+
'datafraction': 'auto',
246248
#'datafraction': 0.5, #fraction (0,1] or 'auto'
247249
'randomseed': 2000,
248250
'randomseed_traintest': 42,
249251
'dirr2threshold': 0.2,
250252
'verbose': 2, #0 (least), 1, 2 (most)
251-
'model_experiment_id': 22, #as per Pranav's model generation, int or 'auto'
253+
#'model_experiment_id': 22, #as per Pranav's model generation, int or 'auto'
254+
'model_experiment_id': 'auto',
252255
'basefolder': basefolder,
253256
'batchsize': 100, #for layer representation generation
254257
'default_run': True, #only variable that is 'trial'-dependent,
@@ -269,6 +272,9 @@ def regressiontaskfolder(self, model, analysis = None):
269272
318: {'datafraction': 0.5, 'model_experiment_id' : 22}, #ST Decoding
270273
319: {'datafraction': 0.1, 'model_experiment_id' : 32}, #LSTM Decoding
271274
320: {'datafraction': 0.05, 'model_experiment_id' : 32}, #LSTM CKA only (no tuning curves)
275+
402: {'datafraction': 0.1, 'model_experiment_id': 22 }, # S
276+
403: {'datafraction': 0.1, 'model_experiment_id': 22 }, # ST
277+
404: {'datafraction': 0.1, 'model_experiment_id': 32 }, # ST
272278
}
273279

274280
# %% SAVE OUTPUTS AND RUN ANALYSIS
@@ -370,7 +376,7 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
370376
print("beginning body...")
371377

372378
startmodel = 0
373-
startrun = 1
379+
startrun = 1 ### 0 for full run
374380
startcontrol = False
375381
startior = 0
376382
startheight = 'all'
@@ -380,6 +386,7 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
380386
#endheight = 'after_all' #['after_all', None]
381387
endheight = None
382388

389+
## Utility functions for late start
383390
runmodels = False
384391
runruns = False
385392
runtype = False
@@ -482,7 +489,7 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
482489

483490
#if(not os.path.exists(runinfo.resultsfolder(model_to_analyse, fset))):
484491
#if(default_run):
485-
if(False):
492+
if(True):
486493
#if(False):
487494

488495
print('running %s analysis (fitting tuning curves) for model %s plane %s...' %(fset, modelname, runinfo.planestring()))
@@ -540,7 +547,8 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
540547

541548
#if(not os.path.exists(runinfo.analysisfolder(model_to_analyse, 'polar_tcs'))):
542549
#if(True):
543-
if(default_run):
550+
#if(default_run):
551+
if(runinfo.planestring() == 'horall'):
544552
polar_tcs_main(model_to_analyse, runinfo_to_analyse)
545553
else:
546554
print('polar tc plots already exist')
@@ -558,7 +566,8 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
558566
print('generating preferred direction histograms for model %s plane %s...' %(modelname, runinfo.planestring()))
559567
#if(not os.path.exists(runinfo.analysisfolder(model_to_analyse, 'prefdir'))):
560568
#if(True):
561-
if(default_run):
569+
if(runinfo.planestring() == 'horall'):
570+
562571
prefdir_main(model_to_analyse, runinfo_to_analyse)
563572
else:
564573
print('pref dir plots already exist')
@@ -586,9 +595,9 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
586595

587596
if(control):
588597
#if(not os.path.exists(runinfo.analysisfolder(trainedmodel, 'comp_violin'))):
589-
if(default_run):
598+
#if(default_run):
590599
#if(True):
591-
#if(runinfo.planestring() == 'horall'):
600+
if(runinfo.planestring() == 'horall'):
592601
print('saving comparison violin plot for model %s plane %s...' %(modelname, runinfo.planestring()))
593602
comp_violin_main(trainedmodel, model_to_analyse, runinfo)
594603
else:
@@ -625,9 +634,9 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
625634
print("doing trreg cka for model %s plane %s ... " %(modelname, runinfo.planestring()))
626635
rsa_main(model_to_analyse, regressionmodel, runinfo, trreg=True)
627636

628-
if(True):
629-
print("making new violin plots")
630-
comp_tr_reg_violin_main_newplots(trainedmodel, regressionmodel, runinfo)
637+
if(runinfo.planestring() == 'horall'):
638+
print("making new violin plots")
639+
comp_tr_reg_violin_main_newplots(trainedmodel, regressionmodel, runinfo)
631640

632641
if (i==5):
633642
if(control):
@@ -637,12 +646,14 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
637646
else:
638647
print('skipping comparisons')
639648

640-
#if(runinfo.planestring() == 'horall'):
641-
if(True):
649+
if(runinfo.planestring() == 'horall'):
650+
#if(True):
642651
print('combining rsa results for all models')
643652
#if(not os.path.exists(runinfo.sharedanalysisfolder(trainedmodel, 'rsa'))):
644653
if(True):
645654
#if(default_run):
655+
#if(runinfo.planestring() == 'horall'):
656+
646657
rsa_models_comp(model, runinfo)
647658
else:
648659
print('rsa models comp already completed')
@@ -651,12 +662,16 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
651662
if(True):
652663
#if(False):
653664
print('starting comparisons_tr_reg_main')
654-
comparisons_tr_reg_main(model, regressionmodel, runinfo)
665+
#comparisons_tr_reg_main(model, regressionmodel, runinfo)
666+
comparisons_tr_reg_main(trainedmodel, regressionmodel, runinfo)
655667

668+
'''
656669
#if(False):
657670
if(True):
658671
print("combining trreg RSA results for all models")
659-
rsa_models_comp(model, runinfo, trreg=True)
672+
#rsa_models_comp(model, runinfo, trreg=True)
673+
comparisons_tr_reg_main(trainedmodel, regressionmodel, runinfo)
674+
'''
660675

661676
else:
662677
runheight = True
@@ -768,4 +783,6 @@ def main(do_data=False, do_results=False, do_analysis=False, do_regression_task
768783

769784
print("Working on the following tasks: ", tasks)
770785

786+
print(args.data, args.results, args.analysis)
787+
771788
main(args.data, args.results, args.analysis, args.regression_task, include, tasks= tasks, expid= args.expid)

0 commit comments

Comments
 (0)