-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
84 lines (67 loc) · 2.16 KB
/
predict.py
File metadata and controls
84 lines (67 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import tensorflow as tf
import numpy as np
import json, os
# Keras models used by this module, loaded once when the module is imported.
# Each entry pairs the short key used throughout this file with the saved
# model file that is loaded for it.
models = {
    task: tf.keras.models.load_model(weights_path)
    for task, weights_path in (
        ("id", "models/id/idk-id.h5"),
        ("gen", "models/gen/15-altb-gen.h5"),
        ("age", "models/age/idk-long-age.h5"),
    )
}
def get_indices(name):
    """Load the index-to-label mapping stored next to a model.

    Reads ``models/<name>/<name>-indices.json`` (relative to the current
    working directory) and returns the parsed JSON.

    Args:
        name: Model key. Only "id", "gen" and "age" are recognised.

    Returns:
        The decoded JSON content, or ``None`` when ``name`` is not one of the
        recognised keys. Callers in this module index the result with a
        stringified class index such as ``"0"``.

    Raises:
        OSError: If the indices file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    if name not in ("id", "gen", "age"):
        # Unrecognised keys have always produced None rather than an error;
        # that behaviour is kept so existing callers are unaffected.
        return None
    # JSON text is UTF-8; passing the encoding explicitly keeps non-ASCII
    # labels intact even when the platform's default locale encoding differs.
    with open(f"models/{name}/{name}-indices.json", encoding="utf-8") as file:
        return json.load(file)
def get_label(name, model, img):
    """Run one model on a single image and translate its output to a label.

    Args:
        name: Model key: "id", "gen" or "age". It selects how the raw
            prediction is decoded and which indices file is consulted.
        model: Object exposing ``predict_on_batch``; callers in this file
            pass entries of the module-level ``models`` dict.
        img: One un-batched image. A leading axis of size one is added
            before it is handed to the model.

    Returns:
        A ``(label, confidence)`` tuple.

        * "id" / "age": the label is looked up from the position of the
          largest output value; the confidence is the constant ``0``.
        * "gen": the first output value is rounded to pick the label, and
          the confidence is that value for label index 1, or one minus that
          value for label index 0.

    Raises:
        RuntimeError: If ``name`` is not one of the recognised keys.
    """
    # predict_on_batch hands back a NumPy array on some TensorFlow releases
    # and an eager tensor on others; np.asarray accepts both, so no
    # tensor-only ``.numpy()`` call is needed afterwards.
    pred = np.asarray(model.predict_on_batch(tf.expand_dims(img, 0)))

    if name in ("id", "age"):
        # NOTE(review): confidence is a hard-coded 0 for these models, so
        # callers always report "0%". Reporting the winning output value
        # would only be meaningful if the models emit probabilities - confirm.
        return get_indices(name)[str(np.argmax(pred))], 0

    if name == "gen":
        # Assumes the "gen" model ends in a single sigmoid-style unit, i.e.
        # the score lies in [0, 1] and is the probability of label index 1
        # - TODO confirm against the training code.
        score = pred[0][0]
        label_idx = int(round(score))
        # Confidence of the *chosen* class on a 0-1 scale. The previous
        # expression returned the raw score for index 0 (the probability of
        # the other class) and ``100 - score`` for index 1, which callers
        # then multiplied by 100 again.
        confidence = score if label_idx else 1 - score
        return get_indices("gen")[str(label_idx)], confidence

    raise RuntimeError(f"unknown model name: {name!r}")
def predictFromPath(img_path="dbs/test/Ariel_Sharon_0006.jpg"):
    """Classify the image file at ``img_path`` with every loaded model.

    The file is read and decoded through TensorFlow, resized to 80x80, and
    passed to ``get_label`` once for each entry of the module-level
    ``models`` dict.

    Args:
        img_path: Path of the image file to classify.

    Returns:
        A ``(results, img_path)`` tuple. ``results`` has the keys "id",
        "gen" and "age", each mapping to a dict with a "class" entry (the
        label from ``get_label``) and a "confidence" entry (the second value
        from ``get_label`` multiplied by 100 and rendered as a string ending
        in "%").
    """
    encoded = tf.io.read_file(tf.constant(img_path))
    pixels = tf.image.resize(tf.io.decode_image(encoded), (80, 80))

    # Start from an empty slot per task so the result always has all three
    # keys, then fill in whatever the loaded models produce.
    results = {
        task: {"class": None, "confidence": None}
        for task in ("id", "gen", "age")
    }
    for task, model in models.items():
        label, score = get_label(task, model, pixels)
        formatted = f"{score * 100}%"
        results[task]["class"] = label
        results[task]["confidence"] = formatted
    return results, img_path
def predictFromTensor(tensor):
    """Classify an in-memory image with every loaded model.

    The input is converted to a float32 tensor, resized to 80x80, and passed
    to ``get_label`` once for each entry of the module-level ``models`` dict.

    Args:
        tensor: Image data accepted by ``tf.convert_to_tensor`` with
            ``dtype=tf.float32``.

    Returns:
        A dict with the keys "id", "gen" and "age", each mapping to a dict
        with a "class" entry (the label from ``get_label``) and a
        "confidence" entry (the second value from ``get_label`` multiplied
        by 100 and rendered as a string ending in "%").
    """
    pixels = tf.image.resize(
        tf.convert_to_tensor(tensor, dtype=tf.float32),
        (80, 80),
    )

    # Start from an empty slot per task so the result always has all three
    # keys, then fill in whatever the loaded models produce.
    results = {
        task: {"class": None, "confidence": None}
        for task in ("id", "gen", "age")
    }
    for task, model in models.items():
        label, score = get_label(task, model, pixels)
        formatted = f"{score * 100}%"
        results[task]["class"] = label
        results[task]["confidence"] = formatted
    return results
if __name__ == "__main__":
    # Script entry point: run the path-based prediction on its default image.
    # The returned value is neither printed nor stored.
    predictFromPath()