-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelpers.py
More file actions
158 lines (124 loc) · 4.88 KB
/
helpers.py
File metadata and controls
158 lines (124 loc) · 4.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""This module groups the helpers function, which means, the functions used to ask questions to the user and manage its answers."""
from datetime import datetime
from matplotlib import pyplot as plt
def choose_classification_method():
    """Ask the user which classification method to evaluate.

    The prompt is repeated until the answer parses as an integer between
    0 and 2 (inclusive). Each rejected answer is reported on stdout.
    An ``EOFError`` raised by ``input()`` (closed stdin) is not caught and
    propagates to the caller.

    Returns:
        int: The chosen method:
            - 0 for the unoptimized KNN method
            - 1 for the optimized KNN method
            - 2 for the Neural Network (MLP)
    """
    while True:
        answer = input(
            "\nMETHOD :\nChoose the method you want to use to classify CIFAR-10 images : 0 - KNN (unoptimized), 1 - KNN (optimized), 2 - NN MLP (Neural Network)\nYou entered : "
        )
        try:
            # Only the conversion can fail, and only with ValueError.
            method_int = int(answer)
        except ValueError as err:
            print(err)
            continue
        if 0 <= method_int <= 2:
            return method_int
        print(
            "You must choose a method among : 0 - KNN (unoptimized), 1 - KNN (optimized), 2 - NN MLP (Neural Network)\n",
            method_int,
            "is invalid",
        )
def choose_split_factor():
    """Ask the user for the training/test split factor.

    The prompt is repeated until the answer parses as a float strictly
    between 0 and 1 (both bounds excluded). Each rejected answer is
    reported on stdout. An ``EOFError`` raised by ``input()`` (closed
    stdin) is not caught and propagates to the caller.

    Returns:
        float: The split factor, with 0 < split_factor < 1. The closer it is
            to 1, the smaller the test set will be.
            Example: with split_factor = 0.8, 80% of the data will be
            training data and 20% will be test data.
    """
    while True:
        answer = input(
            "\nSPLIT FACTOR :\nEnter a float between 0 and 1 which determines the split factor between training and test sets.\nYou entered : "
        )
        try:
            # Only the conversion can fail, and only with ValueError.
            split_factor = float(answer)
        except ValueError as err:
            print(err)
            continue
        # The chained comparison is False for NaN, so "nan" is rejected too.
        if 0.0 < split_factor < 1.0:
            return split_factor
        print(
            "You must enter a float between 0 and 1 :",
            split_factor,
            "is invalid",
        )
def choose_to_save():
    """Ask the user whether the plotted figure should be saved.

    The prompt is repeated until the answer is one of y / yes / n / no.
    Surrounding whitespace and letter case are ignored, so "Y", "YES" or
    " no " are accepted as well. Each rejected answer is reported on stdout.
    An ``EOFError`` raised by ``input()`` (closed stdin) is not caught and
    propagates to the caller.

    Returns:
        int: 0 if the user chose not to save the figure, 1 if they chose to
            save it.
    """
    while True:
        save = input(
            "\nSAVE PLOT :\nDo you want to save the current plot : type [y/yes] for yes and [n/no] for no.\nYou entered : "
        )
        # Normalize only for matching; the raw answer is echoed on error.
        choice = save.strip().lower()
        if choice in ("y", "yes"):
            return 1
        if choice in ("n", "no"):
            return 0
        print("You must enter [y/yes] or [n/no] to choose :", save, "is invalid")
def plot_and_save_fig(x_max, accuracies, split_factor, name):
    """Plot the accuracy of a classification method and optionally save it.

    The accuracies are plotted against the integers 1..x_max. When ``name``
    is "knn" the x axis is labelled as the number of neighbors k and gets one
    tick per value; for any other name it is labelled as the epoch number.
    The figure is shown first, then the user is asked whether to save it.
    If so, it is written to ``results/<name>_<timestamp>.png``.

    Args:
        x_max (int): The largest x value plotted (maximum k for KNN, number
            of epochs otherwise).
        accuracies (Sequence[float]): The accuracy for each x value from 1 to
            x_max; expected to contain x_max entries.
        split_factor (float): The split factor used to split CIFAR-10 data in
            training and test data; only shown in the legend.
        name (str): The name of the method, used in the title and in the
            saved file name. "knn" selects the KNN axis labelling.
    """
    # Local import so this function does not rely on a module-level `os`.
    import os

    print("Plotting figure")
    x_values = list(range(1, x_max + 1))
    fig = plt.figure()
    plt.plot(
        x_values,
        accuracies,
        marker="o",
        linestyle="--",
        color="b",
        label="split_factor = " + str(split_factor),
    )
    plt.title(name + " method accuracy using CIFAR-10 data")
    plt.grid(True, which="both")
    plt.ylabel("Accuracy")
    plt.legend()
    if name == "knn":
        # One tick per evaluated k so every data point is labelled.
        plt.xticks(x_values, x_values)
        plt.xlabel("k number of neighbors")
    else:
        plt.xlabel("Epoch number")
    plt.show()
    if choose_to_save() == 1:
        # Create the output directory on demand; savefig does not do it.
        os.makedirs("results", exist_ok=True)
        # Avoid spaces and colons in the file name (colons are invalid on
        # Windows file systems).
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        fig.savefig(os.path.join("results", name + "_" + timestamp + ".png"))