-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmvh_performance_plot_2.py
More file actions
90 lines (82 loc) · 3.24 KB
/
mvh_performance_plot_2.py
File metadata and controls
90 lines (82 loc) · 3.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import statistics
from scipy.interpolate import interp1d
from matplotlib import pyplot as plt
import csv
def _collect_accuracies(base_dir, model, n_steps, test_name):
    """Collect accuracy values for one test from every pruning-step CSV of a model.

    Reads ``<base_dir>/<model>_pruning_step_<k>.csv`` for ``k`` in
    ``range(n_steps)``, in step order. The first row of each file is treated
    as a header and skipped. A data row matches when its second and third
    columns, joined by a single space, equal ``test_name``. The matched row's
    last column is parsed as a float and appended.

    Args:
        base_dir: Directory containing the per-step CSV files.
        model: Model name used as the file-name prefix.
        n_steps: Number of pruning-step files to read.
        test_name: Full test label, e.g. ``'edge accuracy (top-1)'``.

    Returns:
        List of floats in file/row order (possibly empty if nothing matched).

    Raises:
        FileNotFoundError: if one of the expected CSV files does not exist.
        ValueError: if a matched row's last cell is not a valid float.
    """
    values = []
    for step in range(n_steps):
        path = os.path.join(base_dir, f"{model}_pruning_step_{step}.csv")
        # newline='' is what the csv module expects for files it reads.
        with open(path, mode='r', newline='', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # skip the header row
            for row in reader:
                # Blank lines come back as []; skip anything too short to index.
                if len(row) >= 3 and f"{row[1]} {row[2]}" == test_name:
                    values.append(float(row[-1]))
    return values


def main(args):
    """Plot each test's accuracy against pruning step for several models.

    For each test label, one figure is drawn with one line per model. The
    values come from that model's per-step CSV files under ``base_dir``.
    Each curve is padded to ``n_points`` entries with the mean of its
    collected values, so that all models span the same x-range. The figure
    is saved as ``<base_dir>/<test label>.png``.

    Args:
        args: Parsed command-line namespace. Not read by this function;
            accepted for interface compatibility with the ``__main__`` caller.

    Raises:
        FileNotFoundError: if an expected per-step CSV file is missing.
        ValueError: if a matched accuracy cell cannot be parsed as a float.
    """
    # Global font sizes (rcParams keys are 'xtick'/'ytick', not 'xticks').
    plt.rcParams.update({
        'font.size': 14,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 12,
    })
    base_dir = "/scratch/"
    # A tuple (not a set) so the figures are produced in a stable, readable order.
    tests = (
        'edge accuracy (top-1)',
        'silhouette accuracy (top-1)',
        'cue-conflict accuracy (top-1)',
        'colour accuracy (top-1)',
        'contrast accuracy (top-1)',
        'high-pass accuracy (top-1)',
        'low-pass accuracy (top-1)',
        'phase-scrambling accuracy (top-1)',
        'power-equalisation accuracy (top-1)',
        'false-colour accuracy (top-1)',
        'rotation accuracy (top-1)',
        'eidolonI accuracy (top-1)',
        'eidolonII accuracy (top-1)',
        'eidolonIII accuracy (top-1)',
        'uniform-noise accuracy (top-1)',
        'sketch accuracy (top-1)',
        'sketch accuracy (top-5)',
        'stylized accuracy (top-1)',
        'stylized accuracy (top-5)',
    )
    # Number of pruning-step CSV files to read per model.
    models_pruning = {
        'resnet18': 27,
        'resnet50': 26,
        'swin': 8,
        'vit_b_32': 14,
    }
    # Total curve length every model is padded to. The original padded from the
    # stale loop index (N-1) and so produced 29 points; 28 is the intended total.
    n_points = 28

    for test_name in tests:
        plt.figure(figsize=(10, 8))
        for model, n_steps in models_pruning.items():
            values = _collect_accuracies(base_dir, model, n_steps, test_name)
            if not values:
                # statistics.mean would raise on an empty list; skip this curve instead.
                print(f"warning: no rows for {test_name!r} in {model} CSVs; skipping")
                continue
            avg = statistics.mean(values)
            values.extend([avg] * max(0, n_points - len(values)))
            plt.plot(values, label=model, linewidth=2)
        plt.title(test_name, pad=20)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.xlabel('Pruning Step', fontsize=14)
        plt.ylabel('Accuracy', fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, test_name + '.png'), bbox_inches='tight', dpi=300)
        plt.close()
def get_args_parser(add_help=True):
    """Build the command-line parser for this plotting script.

    Args:
        add_help: Passed through to ``argparse.ArgumentParser``; when False
            the parser has no ``-h/--help`` option.

    Returns:
        An ``argparse.ArgumentParser`` with ``--model-name`` (str, default
        ``"resnet50"``) and ``--max-prune`` (int, default ``26``).
    """
    import argparse
    # The description and help texts were copy-paste residue ("PyTorch
    # Classification Training", "Chosen explainability method") and described
    # neither option.
    parser = argparse.ArgumentParser(
        description="Plot accuracy across pruning steps", add_help=add_help
    )
    # NOTE(review): main() does not currently read either option; confirm intent.
    parser.add_argument("--model-name", default="resnet50", type=str, help="Model name")
    parser.add_argument("--max-prune", default=26, type=int, help="Number of pruning steps")
    return parser
if __name__ == "__main__":
    # Script entry point: build the CLI parser, parse argv, then run main().
    parser = get_args_parser()
    args = parser.parse_args()
    main(args)