import json
import os

import joblib
import numpy as np
import tensorflow as tf

from process import load_dataset
from models.TextCNN2 import TextCNN
from evaluation import f1_avg  # used by the commented-out evaluation in __main__

# Pin the process to a single GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Vocabulary mapping (token -> id) built during preprocessing.
word2id = joblib.load('./dictionary/word2idx_all.p')
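
# For reference, `parameters.json` is expected to carry at least the keys read
# in predict() below. The values shown here are illustrative placeholders, not
# the values used to train the released checkpoint:
#
# {
#     "max_seq_len": 100,
#     "batch_size": 64,
#     "num_classes": 2,
#     "filter_sizes": "3,4,5",
#     "num_filters": 128,
#     "embedding_dim": 300,
#     "use_embeddings": false
# }
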
def predict():
    """Restore the trained TextCNN checkpoint and return predicted class ids
    for the test set."""
    params = json.loads(open('parameters.json').read())
    max_sequence_length = params['max_seq_len']
    batch_size = params['batch_size']
    vocab_size = len(word2id)

    graph = tf.Graph()
    with graph.as_default():
        # Build the model inside the graph so tf.train.Saver can find its variables.
        textCNN = TextCNN(filter_sizes=list(map(int, params['filter_sizes'].split(","))),
                          num_filters=params['num_filters'],
                          num_classes=params['num_classes'],
                          learning_rate=1e-3, batch_size=batch_size, decay_steps=50,
                          decay_rate=1.0, sequence_length=max_sequence_length,
                          vocab_size=vocab_size,
                          embed_size=params['embedding_dim'],
                          usePretrainEmbeddings=params['use_embeddings'])

        x_test = load_dataset(filename='./data/test_data.p', loadflag=True, isTrain=False,
                              max_sequence_len=max_sequence_length)

        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        y_pred = []
        with sess.as_default():
            # Initialize first, then overwrite with the checkpointed weights.
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=2000)
            ckpt_dir = os.path.abspath(os.path.join(os.path.curdir, "trained_model"))
            print(ckpt_dir, os.path.exists(ckpt_dir))
            if os.path.exists(ckpt_dir):
                print("Restoring variables from checkpoint")
                # Restores one specific checkpoint by step number.
                ckpt_path = os.path.join(ckpt_dir, 'model.ckpt-49')
                saver.restore(sess, ckpt_path)
            else:
                print("Can't find the checkpoint; stopping.")
                return y_pred

            number_of_test_data = len(x_test)
            print("number_of_test_data:", number_of_test_data)
            # Iterate in batches; the end index is clamped so the final
            # (possibly partial) batch is not dropped.
            for start in range(0, number_of_test_data, batch_size):
                end = min(start + batch_size, number_of_test_data)
                logits = sess.run(textCNN.logits,
                                  feed_dict={textCNN.input_x: x_test[start:end],
                                             textCNN.dropout_keep_prob: 1.0,
                                             textCNN.is_training_flag: False})
                y_pred += [np.argmax(i) for i in logits]
    return y_pred

if __name__ == '__main__':
    y_pred = predict()
    print("y_pred", len(y_pred))
    # accuracy = f1_avg(y_pred, y_true)
    # print("evaluation accuracy", accuracy)