-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
55 lines (40 loc) · 1.59 KB
/
main.py
File metadata and controls
55 lines (40 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Main script.
When launched, it guides the user from choosing a classification method and its parameters to plotting and saving the data.
"""
from modules.helpers import (
choose_classification_method,
choose_split_factor,
plot_and_save_fig,
)
from modules.knn import evaluate_knn, evaluate_knn_optimized
from modules.mlp import run_mlp_training
from modules.read_cifar import read_cifar, split_dataset
chosen_method = choose_classification_method()
split_factor = choose_split_factor()

# Largest k evaluated for KNN. Both KNN branches derive their k range from this
# single value, and it is also the first argument given to plot_and_save_fig,
# so the number of accuracies produced always matches what is plotted.
k_max = 20

# Unpack the loaded dataset directly (avoids binding it to a name that would
# shadow the builtin `dict`).
data, labels = read_cifar("data")
data_train, labels_train, data_test, labels_test = split_dataset(
    data, labels, split_factor
)

if chosen_method == 0:
    # KNN: call evaluate_knn once for each k in 1..k_max and collect the
    # accuracies in increasing-k order.
    accuracies = []
    for k in range(1, k_max + 1):
        print("Evaluating KNN for k = ", k)
        accuracies.append(
            evaluate_knn(data_train, labels_train, data_test, labels_test, k)
        )
    plot_and_save_fig(k_max, accuracies, split_factor, "knn")
elif chosen_method == 1:
    # Optimized KNN: a single call given k_max; presumably returns one accuracy
    # per k in 1..k_max (it is plotted the same way as the per-k list above).
    accuracy_for_all_k = evaluate_knn_optimized(
        data_train, labels_train, data_test, labels_test, k_max
    )
    plot_and_save_fig(k_max, accuracy_for_all_k, split_factor, "knn")
elif chosen_method == 2:
    num_epoch = 100
    # NOTE(review): 64 and 0.1 are presumably the hidden-layer size and the
    # learning rate expected by run_mlp_training — confirm in modules.mlp.
    train_accuracies, train_losses, final_accuracy = run_mlp_training(
        data_train, labels_train, data_test, labels_test, 64, 0.1, num_epoch
    )
    print("Final accuracy for test data is :", final_accuracy)
    # One figure for the per-epoch training accuracy, one for the loss.
    plot_and_save_fig(num_epoch, train_accuracies, split_factor, "mlp")
    plot_and_save_fig(num_epoch, train_losses, split_factor, "mlp loss")