-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathclassifier.py
More file actions
111 lines (83 loc) · 3.52 KB
/
classifier.py
File metadata and controls
111 lines (83 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import optimizers
from tensorflow.keras.models import Model, save_model
from tensorflow.keras.callbacks import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
import logging
import sys
sys.path.append("../utils")
sys.path.append("classifier_utilities")
from utils import *
from interface_utils import *
from classifier_utils import *
from classifier_interface_utils import *
def _validate_paths(args):
    """Abort the program if any path given on the command line is unusable.

    The dataset, label-set and encoder-model paths are rejected when
    ``filepath_is_not_valid`` reports them invalid; the output path is rejected
    when ``filepath_can_be_reached`` returns a falsy value. On the first failure
    an error is logged and the process exits with status 1, so shell scripts and
    other callers can detect the failure.

    Args:
        args: parsed command-line namespace with ``data``, ``datalabels``,
            ``output_path`` and ``model_path`` attributes.
    """
    if filepath_is_not_valid(args.data):
        logging.error("The path {} is not a file. Aborting..".format(args.data))
        sys.exit(1)
    if filepath_is_not_valid(args.datalabels):
        logging.error("The path {} is not a file. Aborting..".format(args.datalabels))
        sys.exit(1)
    if not filepath_can_be_reached(args.output_path):
        logging.error("The path {} cannot be reached. Aborting..".format(args.output_path))
        sys.exit(1)
    if filepath_is_not_valid(args.model_path):
        logging.error("The path {} is not a file. Aborting..".format(args.model_path))
        sys.exit(1)


def _compile_classifier(classifier, learning_rate=1e-3):
    """Compile *classifier* with Adam and categorical cross-entropy.

    Keras records which weights are trainable when ``compile()`` is called, so
    this must be called again after toggling ``trainable`` on any sub-model
    for the change to take effect.

    Args:
        classifier: the Keras model to compile.
        learning_rate: initial learning rate for the Adam optimizer.
    """
    classifier.compile(optimizer=optimizers.Adam(learning_rate),
                       loss="categorical_crossentropy",
                       metrics=["categorical_crossentropy", "accuracy"])


def main(args):
    """Train a classifier on top of a pretrained encoder and write predicted labels.

    Steps:
    1. Validate the paths in *args*.
    2. Load the images and labels.
    3. Train the classifier with the encoder frozen.
    4. Fine-tune the whole model with the encoder unfrozen.
    5. Predict a class for every image and write the resulting clusters to
       ``args.output_path``.

    Args:
        args: parsed command-line namespace with ``data``, ``datalabels``,
            ``output_path`` and ``model_path`` attributes.
    """
    _validate_paths(args)

    # parse the data from the dataset
    X = parse_dataset(args.data)
    Y = parse_labelset(args.datalabels)
    # X is indexed as (number_of_images, rows, columns, ...) below
    rows = X.shape[1]
    columns = X.shape[2]

    # one-hot encode the labels; column i corresponds to lb.classes_[i]
    lb = LabelBinarizer()
    Y = lb.fit_transform(Y)
    num_classes = len(lb.classes_)

    # reshape so that the shapes are (number_of_images, rows, columns, 1)
    X = X.reshape(-1, rows, columns, 1)
    # normalize (assumes 8-bit pixel values in [0, 255] — confirm against parse_dataset)
    X = X / 255.

    # hold out 15% for validation; fixed seed keeps the split reproducible
    X_train, X_val, Y_train, Y_val = train_test_split(
        X, Y, test_size=0.15, random_state=13, shuffle=True)
    units, epochs, batch_size = 128, 1, 64

    # load the encoder and "freeze" its weights for the first phase
    encoder = load_keras_model(args.model_path)
    encoder.trainable = False

    # create the classifier using the encoder
    classifier = create_classifier(rows, columns, encoder, units)
    print()
    classifier.summary()

    # halve the learning rate when the validation loss stops improving
    lr_callback = ReduceLROnPlateau(monitor="val_loss", factor=1.0 / 2, patience=4,
                                    min_delta=0.005, cooldown=0, min_lr=1e-8, verbose=1)

    def train_one_phase():
        # one training phase with the shared data/hyper-parameters
        classifier.fit(X_train, Y_train, batch_size=batch_size, epochs=epochs,
                       shuffle=True, validation_data=(X_val, Y_val),
                       callbacks=[lr_callback])

    # phase 1: train only the layers added on top of the frozen encoder
    _compile_classifier(classifier)
    train_one_phase()

    # phase 2: fine-tune the whole model. Changing `trainable` only takes effect
    # on the next compile(); without recompiling, the encoder would stay frozen.
    encoder.trainable = True
    _compile_classifier(classifier)
    train_one_phase()

    print("\nProducing output file...")
    Y_prob = classifier.predict(X)
    # pick the most probable column directly; rounding the probabilities first
    # would turn every prediction whose top probability is <= 0.5 into class 0
    Y_unbin = Y_prob.argmax(axis=1)
    clusters = separate_to_clusters(Y_unbin, num_classes)
    produce_label_file(clusters, args.output_path)
if __name__ == "__main__":
    # Script entry point: configure logging, read the CLI arguments and run main().
    print()

    # Only errors are shown, each followed by an extra blank line.
    logging.basicConfig(
        format="%(levelname)s: %(message)s\n",
        level=logging.ERROR,
    )

    args = parse_input()
    main(args)

    print("\n")