Skip to content

Commit b745122

Browse files
author
jparkpjw
committed
cleaned up imports and updated deprecated pylab calls
1 parent 663e492 commit b745122

3 files changed

Lines changed: 35 additions & 43 deletions

File tree

diffnets/analysis.py

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import itertools
44
from . import utils
5+
from .data_processing import WhitenTraj
56
import multiprocessing as mp
67
import os
78
import functools
@@ -11,17 +12,11 @@
1112
from scipy.stats import pearsonr
1213
from sklearn.metrics import roc_auc_score, roc_curve
1314
import matplotlib.pyplot as plt
14-
import enspara
1515
import enspara.cluster as cluster
1616
import enspara.info_theory as infotheor
1717
import enspara.msm as msm
18-
import enspara.cluster as cluster
19-
import enspara.info_theory as infotheor
2018
import pickle
2119
import scipy.sparse
22-
import sys
23-
from pylab import *
24-
from torch.autograd import Variable
2520
from collections import defaultdict
2621

2722
class Analysis:
@@ -487,7 +482,7 @@ def calc_overlap(d1, d2, bins):
487482
ent2[i] = infotheor.shannon_entropy(p2)
488483
return js, ent1, ent2
489484

490-
def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
485+
def project(enc, lab, vars, i1, i2, n_bins, my_title, cutoff=0.8):
491486
subsample = 100
492487

493488
all_act_inds = np.where(lab>cutoff)[0]
@@ -537,7 +532,7 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
537532

538533
height = 4
539534
width = height*n_vars
540-
fig = figure(figsize=(width, height))
535+
fig = plt.figure(figsize=(width, height))
541536
fig.suptitle(my_title)
542537
bins = 20
543538
dot_size = 0.1
@@ -556,8 +551,8 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
556551
#imshow(h, interpolation='bilinear', aspect='auto', origin='low', extent=[x[0]+delta_x, x[-1]+delta_x, y[0]+delta_y, y[-1]+delta_y], vmin=cmin-cmax, vmax=0, cmap=get_cmap('Blues_r'))
557552
# transpose to put first dimension (i1) on x axis
558553
#imshow(h.T, interpolation='bilinear', aspect='auto', origin='low', extent=[y[0]+delta_y, y[-1]+delta_y, x[0]+delta_x, x[-1]+delta_x], vmin=cmin-cmax, vmax=0, cmap=get_cmap('Blues_r'))
559-
imshow(h.T, interpolation='bilinear', aspect='auto', origin='low', extent=[x[0]+delta_x, x[-1]+delta_x, y[0]+delta_y, y[-1]+delta_y], vmin=cmin-cmax, vmax=0, cmap=get_cmap('Blues_r'))
560-
colorbar()
554+
plt.imshow(h.T, interpolation='bilinear', aspect='auto', origin='low', extent=[x[0]+delta_x, x[-1]+delta_x, y[0]+delta_y, y[-1]+delta_y], vmin=cmin-cmax, vmax=0, cmap=plt.get_cmap('Blues_r'))
555+
plt.colorbar()
561556

562557
lines = []
563558
line_labels = []
@@ -568,7 +563,7 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
568563
i2_std = i2_dict[v2].std()
569564
#print(v, "x", i1_mu, i1_std)
570565
#print(v, "y", i2_mu, i2_std)
571-
line, _, _ = errorbar([i1_mu], [i2_mu], xerr=[i1_std], yerr=[i2_std], label=v2)
566+
line, _, _ = plt.errorbar([i1_mu], [i2_mu], xerr=[i1_std], yerr=[i2_std], label=v2)
572567
lines.append(line)
573568
line_labels.append(v2)
574569

@@ -579,17 +574,17 @@ def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
579574
# if inds.shape[0] > 0:
580575
# scatter(i1_dict[v2][inds], i2_dict[v2][inds], s=dot_size, c='k')
581576

582-
line, _, _ = errorbar([act_i1_mu], [act_i2_mu], xerr=[act_i1_std], yerr=[act_i2_std], label='act', ecolor='k', fmt='k')
577+
line, _, _ = plt.errorbar([act_i1_mu], [act_i2_mu], xerr=[act_i1_std], yerr=[act_i2_std], label='act', ecolor='k', fmt='k')
583578
lines.append(line)
584579
line_labels.append('act')
585580
#legend()
586581

587-
title(v)
582+
plt.title(v)
588583
# scatter([0], [0], s=dot_size*10, c='k')
589584
# scatter([6], [0], s=dot_size*10, c='k')
590585
# scatter([6], [6], s=dot_size*10, c='k')
591586
fig.legend(lines, line_labels)
592-
show()
587+
plt.show()
593588

594589
def morph_conditional(nn_dir, data_dir, n_frames=10):
595590
net = pickle.load(open("%s/nn_best_polish.pkl" % nn_dir, 'rb'))
@@ -601,7 +596,7 @@ def morph_conditional(nn_dir, data_dir, n_frames=10):
601596
uwm = np.load(uwm_fn)
602597
cm_fn = os.path.join(data_dir, "cm.npy")
603598
cm = np.load(cm_fn)
604-
enc = load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
599+
enc = utils.load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
605600
n_latent = int(enc.shape[1])
606601
morph_dir = os.path.join(nn_dir, "morph")
607602
if not os.path.exists(morph_dir):
@@ -633,7 +628,7 @@ def morph_conditional(nn_dir, data_dir, n_frames=10):
633628
print("single")
634629
outputs = net.decode(morph_enc)
635630
outputs = outputs.data.numpy()
636-
coords = whiten.apply_unwhitening(outputs, uwm, cm)
631+
coords = WhitenTraj.apply_unwhitening(outputs, uwm, cm)
637632
print("shape", coords.shape)
638633
recon_trj = md.Trajectory(coords.reshape((n_frames, n_atoms, 3)), ref_s.top)
639634
out_fn = os.path.join(morph_dir, "m%d.pdb" % i)
@@ -649,7 +644,7 @@ def morph_cond_mean(nn_dir,data_dir,n_frames=10):
649644
uwm = np.load(uwm_fn)
650645
cm_fn = os.path.join(data_dir, "cm.npy")
651646
cm = np.load(cm_fn)
652-
enc = load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
647+
enc = utils.load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
653648
n_latent = int(enc.shape[1])
654649
morph_dir = os.path.join(nn_dir, "morph_bin_mean")
655650
if not os.path.exists(morph_dir):
@@ -677,7 +672,7 @@ def morph_cond_mean(nn_dir,data_dir,n_frames=10):
677672
traj = utils.recon_traj(morph_enc,net,ref_s.top,cm)
678673
rmsf = get_rmsf(traj)
679674

680-
out_fn = os.path.join(outdir, "m%d.pdb" % i)
675+
out_fn = os.path.join(morph_dir, "m%d.pdb" % i)
681676
traj.save_pdb(out_fn, bfactors=rmsf)
682677

683678
def morph_std(nn_dir, data_dir, enc):
@@ -762,8 +757,8 @@ def get_act_inact(nn_dir, data_dir, enc, labels):
762757
np.save(out_fn, inact_rmsf)
763758

764759
#all_h, x = common_hist([act_rmsf, inact_rmsf], ['act', 'inact'], 20)
765-
fig = figure(figsize=(4, 8))
766-
title
760+
fig = plt.figure(figsize=(4, 8))
761+
plt.title
767762
#plot(x, all_h['act'], label='act')
768763
#plot(x, all_h['inact'], label='inact')
769764
res_nums = []
@@ -772,16 +767,16 @@ def get_act_inact(nn_dir, data_dir, enc, labels):
772767
res_nums.append(r.resSeq)
773768

774769
ax = fig.add_subplot(211)
775-
plot(res_nums, act_rmsf, label='act')
776-
plot(res_nums, inact_rmsf, label='inact')
777-
legend()
770+
plt.plot(res_nums, act_rmsf, label='act')
771+
plt.plot(res_nums, inact_rmsf, label='inact')
772+
plt.legend()
778773

779774
ax = fig.add_subplot(212)
780775
d = act_rmsf-inact_rmsf
781-
plot(res_nums, d, 'k')
776+
plt.plot(res_nums, d, 'k')
782777
out_fn = os.path.join(outdir, "act_minus_inact.npy")
783778
np.save(out_fn, d)
784-
show()
779+
plt.show()
785780

786781
out_fn = os.path.join(outdir, "act_minus_inact.pdb")
787782
ref_s = ref_s.atom_slice(ca_inds)
@@ -801,30 +796,30 @@ def enc_corr(enc):
801796
def project_act(lab_v, vars, my_title):
802797
n_vars = len(vars)
803798
print(my_title)
804-
fig = figure(figsize=(4, 4))
799+
fig = plt.figure(figsize=(4, 4))
805800
fig.suptitle(my_title)
806801
for i in range(n_vars):
807802
v = vars[i]
808803
n, x = np.histogram(lab_v[v], range=(0, 1), bins=50)
809-
plot(x[:-1], n, label=v)
804+
plt.plot(x[:-1], n, label=v)
810805
print(v, lab_v[v].mean())
811-
legend()
812-
show()
806+
plt.legend()
807+
plt.show()
813808

814809

815810
def check_loss(nn_dir):
816811
i = 2
817812
fn = os.path.join(nn_dir, "test_loss_%d.npy" % i)
818813
while os.path.exists(fn):
819814
d = np.load(fn)
820-
plot(d, label=str(i))
815+
plt.plot(d, label=str(i))
821816
i += 1
822817
fn = os.path.join(nn_dir, "test_loss_%d.npy" % i)
823818
fn = os.path.join(nn_dir, "test_loss_polish.npy")
824-
d = load(fn)
825-
plot(d, label='p')
826-
legend()
827-
show()
819+
d = np.load(fn)
820+
plt.plot(d, label='p')
821+
plt.legend()
822+
plt.show()
828823

829824
def clust_encod(nn_dir, n_clusters, vars, lag_times,n_traj_per_var):
830825
msm_dir = os.path.join(nn_dir, "msm_%d" % n_clusters)
@@ -845,7 +840,7 @@ def clust_encod(nn_dir, n_clusters, vars, lag_times,n_traj_per_var):
845840

846841
height = 4
847842
width = height*n_vars
848-
fig = figure(figsize=(width, height))
843+
fig = plt.figure(figsize=(width, height))
849844
fig.suptitle(nn_dir)
850845
for i in range(n_vars):
851846
v = vars[i]
@@ -859,8 +854,8 @@ def clust_encod(nn_dir, n_clusters, vars, lag_times,n_traj_per_var):
859854

860855
ax = fig.add_subplot(1, n_vars, i+1, aspect='auto')
861856
for i, t in enumerate(lag_times):
862-
scatter(t*np.ones(imp_times.shape[1]), imp_times[i])
863-
title(v)
857+
plt.scatter(t*np.ones(imp_times.shape[1]), imp_times[i])
858+
plt.title(v)
864859
ax.set_yscale('log')
865860

866861
markov_lag = 10
@@ -875,7 +870,7 @@ def clust_encod(nn_dir, n_clusters, vars, lag_times,n_traj_per_var):
875870
C_fn = os.path.join(msm_dir, "%s_C_norm_lag%d.npy" % (v, markov_lag))
876871
np.save(C_fn, C)
877872
out_fn = os.path.join(msm_dir, "imp_times.png")
878-
savefig(out_fn)
879-
show()
873+
plt.savefig(out_fn)
874+
plt.show()
880875

881876

diffnets/data_processing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from .utils import *
66
import pickle
77
from collections import defaultdict
8-
98
import numpy as np
109
import mdtraj as md
1110
from scipy.linalg import inv, sqrtm

diffnets/training.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import os
22
import pickle
3-
import sys
43
import multiprocessing as mp
54
import mdtraj as md
65
import numpy as np
7-
from . import exmax, nnutils, utils, data_processing
6+
from . import exmax, nnutils, utils
87
import copy
98
import pickle
10-
119
import torch
1210
import torch.nn as nn
1311
import torch.optim as optim

0 commit comments

Comments
 (0)