-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph.py
More file actions
47 lines (30 loc) · 1.36 KB
/
graph.py
File metadata and controls
47 lines (30 loc) · 1.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import matplotlib.pyplot as plt
import numpy as np
from data_preprocessing import create_dataframe
from variables import colors, learning_rate, stochastic
def draw_graphs(losses, classes, accuracies):
    """Plot every per-class loss series and the accuracies on a single axes.

    Args:
        losses: Mapping-like object whose ``items()`` yields
            ``(class_name, series)`` pairs and whose ``len()`` is the number
            of recorded steps. The script entry point passes the result of
            ``create_dataframe`` (presumably a pandas DataFrame -- confirm).
            Every ``class_name`` must be a key of ``colors``.
        classes: Legend labels, handed to ``axes.legend`` and therefore
            matched to the lines in the order they were plotted.
        accuracies: Accuracy values drawn against the same step axis as the
            losses, so they must have one entry per step.

    Returns:
        The matplotlib figure that holds the plot.
    """
    total = len(losses)
    # One integer x position per recorded step (0 .. total - 1), built once
    # and shared by every series. ``np.linspace(0, total, total)`` would
    # space the points total / (total - 1) apart and end at ``total``, so
    # the markers would not line up with the step indices on the x-axis.
    steps = np.arange(total)

    plt.style.use("gruvbox.mplstyle")
    figure, axes = plt.subplots(figsize=(12, 8), dpi=500)

    for class_name, content in losses.items():
        axes.plot(steps, content, ".", c=colors[class_name])
    axes.plot(steps, accuracies, ".", c="#ebdbb2")

    axes.set_xlabel("step")
    axes.set_ylabel("loss")
    axes.legend(classes)
    return figure
if __name__ == "__main__":
    # Local import: only the script entry point touches the filesystem layout.
    from pathlib import Path

    print("\n------------ Graph -----------")

    # Command line: two positional csv paths, losses first, accuracies second.
    parser = argparse.ArgumentParser(
        description="Plot the per-class losses and the accuracies stored in two csv files"
    )
    parser.add_argument("losses", help="the loss csv file")
    parser.add_argument("accuracies", help="the accuracy csv file")
    args = parser.parse_args()

    losses = create_dataframe(args.losses)
    accuracies = create_dataframe(args.accuracies)
    # The column names of the loss table double as the legend labels.
    classes = losses.columns.tolist()

    figure = draw_graphs(losses, classes, accuracies)
    figure.suptitle(f"learning rate: {learning_rate} | Stochastic: {stochastic}")

    # savefig does not create missing directories, so make sure the target
    # folder exists before writing the image.
    output_path = Path(f"static/Image/loss/lr_{learning_rate}_st_{stochastic}.png")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path, dpi=200)

    plt.show()
    plt.close()