-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprediction.py
More file actions
109 lines (77 loc) · 2.94 KB
/
prediction.py
File metadata and controls
109 lines (77 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Authors: klin
Email: l33klin@foxmail.com
Date: 2020/3/3
"""
import os
import time
import random
import getpass
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Name of the current login user (not used by the active code in this file).
USER = getpass.getuser()
# Directory the trained model was saved to after training; loaded below.
MODEL_PATH = "./model/2"
# Single sample images (not used by the active code in this file).
BITE_IMG = "bite.jpg"
NO_BITE_IMG = "no_bite.jpg"
# Every image is resized to IMG_HEIGHT x IMG_WIDTH pixels before prediction.
IMG_HEIGHT = 150
IMG_WIDTH = IMG_HEIGHT
def preprocess_image(image):
    """Turn raw JPEG bytes into a one-image batch ready for the model.

    The bytes are decoded as a 3-channel image, resized to
    IMG_HEIGHT x IMG_WIDTH, scaled from [0, 255] to [0, 1], and given a
    leading batch axis of size 1.

    Args:
        image: scalar string tensor holding JPEG-encoded bytes
            (as returned by ``tf.io.read_file``).

    Returns:
        numpy array of shape (1, IMG_HEIGHT, IMG_WIDTH, 3).
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    resized = tf.image.resize(decoded, [IMG_HEIGHT, IMG_WIDTH])
    scaled = resized / 255.0
    # np.expand_dims converts the eager tensor into a numpy array with a batch axis.
    return np.expand_dims(scaled, axis=0)
def load_and_preprocess_image(path):
    """Read the image file at *path* and return it as a one-image model batch.

    The file contents are handed unchanged to ``preprocess_image``.
    """
    return preprocess_image(tf.io.read_file(path))
# Recreate the exact same model from the directory it was saved to after training.
new_model = tf.keras.models.load_model(MODEL_PATH)

# Alternative dataset location, formatted with the current user name:
# DATASET_DIR = "/Users/{}/Nextcloud/Documents/MineCraft/2020-02-26/dataset".format(USER)
DATASET_DIR = "./dataset"
BITE_DIR = os.path.join(DATASET_DIR, "validation/bite")
NO_BITE_DIR = os.path.join(DATASET_DIR, "validation/no_bite")

# Number of images to spot-check from each validation class.
SAMPLE_SIZE = 5


def _sample_images(directory, k):
    """Return up to *k* randomly chosen, non-hidden file names from *directory*."""
    # Hidden entries (e.g. macOS ".DS_Store") are not JPEGs and would make
    # tf.image.decode_jpeg fail.
    names = [name for name in os.listdir(directory) if not name.startswith(".")]
    # random.sample raises ValueError when k exceeds the population size.
    return random.sample(names, min(k, len(names)))


BITE_IMG_5 = _sample_images(BITE_DIR, SAMPLE_SIZE)
NO_BITE_IMG_5 = _sample_images(NO_BITE_DIR, SAMPLE_SIZE)

# Time one prediction per sampled image, one image at a time.
start = time.time()
for img in BITE_IMG_5:
    # The preprocessed image already carries a batch axis, so pass it directly;
    # a list is only meant for multi-input models.
    predictions = new_model.predict(load_and_preprocess_image(os.path.join(BITE_DIR, img)))
    print("predictions: {}".format(predictions))
for img in NO_BITE_IMG_5:
    predictions = new_model.predict(load_and_preprocess_image(os.path.join(NO_BITE_DIR, img)))
    print("predictions: {}".format(predictions))
print("Cost time: {}".format(time.time() - start))
def get_file_name(item):
    """Return the base name of ``item[0]`` with its last extension removed.

    Used as a sort key for (path, is_bite) tuples. Everything from the last
    "." onward is dropped, so "a.tar.gz" -> "a.tar" and ".hidden" -> "".
    A name with no "." is returned unchanged.
    """
    base = os.path.basename(item[0])
    stem, dot, _ = base.rpartition(".")
    return stem if dot else base
# Predict every validation image, ordered by file name without extension, and
# scatter-plot predictions[0][0] against that order: red = bite, blue = no_bite.
start = time.time()
# Hidden entries (e.g. macOS ".DS_Store") are skipped: they are not JPEGs and
# would make tf.image.decode_jpeg fail.
IMG_ALL = [(os.path.join(BITE_DIR, x), True) for x in os.listdir(BITE_DIR) if not x.startswith(".")] \
    + [(os.path.join(NO_BITE_DIR, x), False) for x in os.listdir(NO_BITE_DIR) if not x.startswith(".")]
IMG_ALL = sorted(IMG_ALL, key=get_file_name)
bite_xValue = []
bite_yValue = []
no_bite_xValue = []
no_bite_yValue = []
for index, (path, is_bite) in enumerate(IMG_ALL):
    # The preprocessed image already has a batch axis; pass it directly.
    predictions = new_model.predict(load_and_preprocess_image(path))
    if is_bite:
        bite_xValue.append(index)
        bite_yValue.append(predictions[0][0])
    else:
        no_bite_xValue.append(index)
        no_bite_yValue.append(predictions[0][0])
print("predict all cost time: {}".format(time.time() - start))

plt.title('Bite scatter')
plt.scatter(bite_xValue, bite_yValue, s=20, c="r", marker='o', label="bite")
plt.scatter(no_bite_xValue, no_bite_yValue, s=20, c="b", marker='o', label="no_bite")
# legend() only collects artists that already exist and carry a label, so it
# must be called after the labelled scatter calls.
plt.legend()
plt.show()