22import numpy as np
33import itertools
44from . import utils
5+ from .data_processing import WhitenTraj
56import multiprocessing as mp
67import os
78import functools
1112from scipy .stats import pearsonr
1213from sklearn .metrics import roc_auc_score , roc_curve
1314import matplotlib .pyplot as plt
14- import enspara
1515import enspara .cluster as cluster
1616import enspara .info_theory as infotheor
1717import enspara .msm as msm
18- import enspara .cluster as cluster
19- import enspara .info_theory as infotheor
2018import pickle
2119import scipy .sparse
22- import sys
23- from pylab import *
24- from torch .autograd import Variable
2520from collections import defaultdict
2621
2722class Analysis :
@@ -487,7 +482,7 @@ def calc_overlap(d1, d2, bins):
487482 ent2 [i ] = infotheor .shannon_entropy (p2 )
488483 return js , ent1 , ent2
489484
490- def project (enc , lab , vars , i1 , i2 , bins , my_title , cutoff = 0.8 ):
485+ def project (enc , lab , vars , i1 , i2 , n_bins , my_title , cutoff = 0.8 ):
491486 subsample = 100
492487
493488 all_act_inds = np .where (lab > cutoff )[0 ]
@@ -537,7 +532,7 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
537532
538533 height = 4
539534 width = height * n_vars
540- fig = figure (figsize = (width , height ))
535+ fig = plt . figure (figsize = (width , height ))
541536 fig .suptitle (my_title )
542537 bins = 20
543538 dot_size = 0.1
@@ -556,8 +551,8 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
556551 #imshow(h, interpolation='bilinear', aspect='auto', origin='low', extent=[x[0]+delta_x, x[-1]+delta_x, y[0]+delta_y, y[-1]+delta_y], vmin=cmin-cmax, vmax=0, cmap=get_cmap('Blues_r'))
557552 # transpose to put first dimension (i1) on x axis
558553 #imshow(h.T, interpolation='bilinear', aspect='auto', origin='low', extent=[y[0]+delta_y, y[-1]+delta_y, x[0]+delta_x, x[-1]+delta_x], vmin=cmin-cmax, vmax=0, cmap=get_cmap('Blues_r'))
559- imshow (h .T , interpolation = 'bilinear' , aspect = 'auto' , origin = 'low' , extent = [x [0 ]+ delta_x , x [- 1 ]+ delta_x , y [0 ]+ delta_y , y [- 1 ]+ delta_y ], vmin = cmin - cmax , vmax = 0 , cmap = get_cmap ('Blues_r' ))
560- colorbar ()
554+ plt . imshow (h .T , interpolation = 'bilinear' , aspect = 'auto' , origin = 'low' , extent = [x [0 ]+ delta_x , x [- 1 ]+ delta_x , y [0 ]+ delta_y , y [- 1 ]+ delta_y ], vmin = cmin - cmax , vmax = 0 , cmap = plt . get_cmap ('Blues_r' ))
555+ plt . colorbar ()
561556
562557 lines = []
563558 line_labels = []
@@ -568,7 +563,7 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
568563 i2_std = i2_dict [v2 ].std ()
569564 #print(v, "x", i1_mu, i1_std)
570565 #print(v, "y", i2_mu, i2_std)
571- line , _ , _ = errorbar ([i1_mu ], [i2_mu ], xerr = [i1_std ], yerr = [i2_std ], label = v2 )
566+ line , _ , _ = plt . errorbar ([i1_mu ], [i2_mu ], xerr = [i1_std ], yerr = [i2_std ], label = v2 )
572567 lines .append (line )
573568 line_labels .append (v2 )
574569
@@ -579,17 +574,17 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
579574 # if inds.shape[0] > 0:
580575 # scatter(i1_dict[v2][inds], i2_dict[v2][inds], s=dot_size, c='k')
581576
582- line , _ , _ = errorbar ([act_i1_mu ], [act_i2_mu ], xerr = [act_i1_std ], yerr = [act_i2_std ], label = 'act' , ecolor = 'k' , fmt = 'k' )
577+ line , _ , _ = plt . errorbar ([act_i1_mu ], [act_i2_mu ], xerr = [act_i1_std ], yerr = [act_i2_std ], label = 'act' , ecolor = 'k' , fmt = 'k' )
583578 lines .append (line )
584579 line_labels .append ('act' )
585580 #legend()
586581
587- title (v )
582+ plt . title (v )
588583 # scatter([0], [0], s=dot_size*10, c='k')
589584 # scatter([6], [0], s=dot_size*10, c='k')
590585 # scatter([6], [6], s=dot_size*10, c='k')
591586 fig .legend (lines , line_labels )
592- show ()
587+ plt . show ()
593588
594589def morph_conditional (nn_dir , data_dir , n_frames = 10 ):
595590 net = pickle .load (open ("%s/nn_best_polish.pkl" % nn_dir , 'rb' ))
@@ -601,7 +596,7 @@ def morph_conditional(nn_dir, data_dir, n_frames=10):
601596 uwm = np .load (uwm_fn )
602597 cm_fn = os .path .join (data_dir , "cm.npy" )
603598 cm = np .load (cm_fn )
604- enc = load_npy_dir (os .path .join (nn_dir , "encodings" ), "*npy" )
599+ enc = utils . load_npy_dir (os .path .join (nn_dir , "encodings" ), "*npy" )
605600 n_latent = int (enc .shape [1 ])
606601 morph_dir = os .path .join (nn_dir , "morph" )
607602 if not os .path .exists (morph_dir ):
@@ -633,7 +628,7 @@ def morph_conditional(nn_dir, data_dir, n_frames=10):
633628 print ("single" )
634629 outputs = net .decode (morph_enc )
635630 outputs = outputs .data .numpy ()
636- coords = whiten .apply_unwhitening (outputs , uwm , cm )
631+ coords = WhitenTraj .apply_unwhitening (outputs , uwm , cm )
637632 print ("shape" , coords .shape )
638633 recon_trj = md .Trajectory (coords .reshape ((n_frames , n_atoms , 3 )), ref_s .top )
639634 out_fn = os .path .join (morph_dir , "m%d.pdb" % i )
@@ -649,7 +644,7 @@ def morph_cond_mean(nn_dir,data_dir,n_frames=10):
649644 uwm = np .load (uwm_fn )
650645 cm_fn = os .path .join (data_dir , "cm.npy" )
651646 cm = np .load (cm_fn )
652- enc = load_npy_dir (os .path .join (nn_dir , "encodings" ), "*npy" )
647+ enc = utils . load_npy_dir (os .path .join (nn_dir , "encodings" ), "*npy" )
653648 n_latent = int (enc .shape [1 ])
654649 morph_dir = os .path .join (nn_dir , "morph_bin_mean" )
655650 if not os .path .exists (morph_dir ):
@@ -677,7 +672,7 @@ def morph_cond_mean(nn_dir,data_dir,n_frames=10):
677672 traj = utils .recon_traj (morph_enc ,net ,ref_s .top ,cm )
678673 rmsf = get_rmsf (traj )
679674
680- out_fn = os .path .join (outdir , "m%d.pdb" % i )
675+ out_fn = os .path .join (morph_dir , "m%d.pdb" % i )
681676 traj .save_pdb (out_fn , bfactors = rmsf )
682677
683678def morph_std (nn_dir , data_dir , enc ):
@@ -762,8 +757,8 @@ def get_act_inact(nn_dir, data_dir, enc, labels):
762757 np .save (out_fn , inact_rmsf )
763758
764759 #all_h, x = common_hist([act_rmsf, inact_rmsf], ['act', 'inact'], 20)
765- fig = figure (figsize = (4 , 8 ))
766- title
760+ fig = plt . figure (figsize = (4 , 8 ))
761+ plt . title
767762 #plot(x, all_h['act'], label='act')
768763 #plot(x, all_h['inact'], label='inact')
769764 res_nums = []
@@ -772,16 +767,16 @@ def get_act_inact(nn_dir, data_dir, enc, labels):
772767 res_nums .append (r .resSeq )
773768
774769 ax = fig .add_subplot (211 )
775- plot (res_nums , act_rmsf , label = 'act' )
776- plot (res_nums , inact_rmsf , label = 'inact' )
777- legend ()
770+ plt . plot (res_nums , act_rmsf , label = 'act' )
771+ plt . plot (res_nums , inact_rmsf , label = 'inact' )
772+ plt . legend ()
778773
779774 ax = fig .add_subplot (212 )
780775 d = act_rmsf - inact_rmsf
781- plot (res_nums , d , 'k' )
776+ plt . plot (res_nums , d , 'k' )
782777 out_fn = os .path .join (outdir , "act_minus_inact.npy" )
783778 np .save (out_fn , d )
784- show ()
779+ plt . show ()
785780
786781 out_fn = os .path .join (outdir , "act_minus_inact.pdb" )
787782 ref_s = ref_s .atom_slice (ca_inds )
@@ -801,30 +796,30 @@ def enc_corr(enc):
801796def project_act (lab_v , vars , my_title ):
802797 n_vars = len (vars )
803798 print (my_title )
804- fig = figure (figsize = (4 , 4 ))
799+ fig = plt . figure (figsize = (4 , 4 ))
805800 fig .suptitle (my_title )
806801 for i in range (n_vars ):
807802 v = vars [i ]
808803 n , x = np .histogram (lab_v [v ], range = (0 , 1 ), bins = 50 )
809- plot (x [:- 1 ], n , label = v )
804+ plt . plot (x [:- 1 ], n , label = v )
810805 print (v , lab_v [v ].mean ())
811- legend ()
812- show ()
806+ plt . legend ()
807+ plt . show ()
813808
814809
815810def check_loss (nn_dir ):
816811 i = 2
817812 fn = os .path .join (nn_dir , "test_loss_%d.npy" % i )
818813 while os .path .exists (fn ):
819814 d = np .load (fn )
820- plot (d , label = str (i ))
815+ plt . plot (d , label = str (i ))
821816 i += 1
822817 fn = os .path .join (nn_dir , "test_loss_%d.npy" % i )
823818 fn = os .path .join (nn_dir , "test_loss_polish.npy" )
824- d = load (fn )
825- plot (d , label = 'p' )
826- legend ()
827- show ()
819+ d = np . load (fn )
820+ plt . plot (d , label = 'p' )
821+ plt . legend ()
822+ plt . show ()
828823
829824def clust_encod (nn_dir , n_clusters , vars , lag_times ,n_traj_per_var ):
830825 msm_dir = os .path .join (nn_dir , "msm_%d" % n_clusters )
@@ -845,7 +840,7 @@ def clust_encod(nn_dir, n_clusters, vars, lag_times,n_traj_per_var):
845840
846841 height = 4
847842 width = height * n_vars
848- fig = figure (figsize = (width , height ))
843+ fig = plt . figure (figsize = (width , height ))
849844 fig .suptitle (nn_dir )
850845 for i in range (n_vars ):
851846 v = vars [i ]
@@ -859,8 +854,8 @@ def clust_encod(nn_dir, n_clusters, vars, lag_times,n_traj_per_var):
859854
860855 ax = fig .add_subplot (1 , n_vars , i + 1 , aspect = 'auto' )
861856 for i , t in enumerate (lag_times ):
862- scatter (t * np .ones (imp_times .shape [1 ]), imp_times [i ])
863- title (v )
857+ plt . scatter (t * np .ones (imp_times .shape [1 ]), imp_times [i ])
858+ plt . title (v )
864859 ax .set_yscale ('log' )
865860
866861 markov_lag = 10
@@ -875,7 +870,7 @@ def clust_encod(nn_dir, n_clusters, vars, lag_times,n_traj_per_var):
875870 C_fn = os .path .join (msm_dir , "%s_C_norm_lag%d.npy" % (v , markov_lag ))
876871 np .save (C_fn , C )
877872 out_fn = os .path .join (msm_dir , "imp_times.png" )
878- savefig (out_fn )
879- show ()
873+ plt . savefig (out_fn )
874+ plt . show ()
880875
881876
0 commit comments