-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
43 lines (40 loc) · 2.32 KB
/
train.py
File metadata and controls
43 lines (40 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import cupy as np
from load_train_data import load_train_data
from load_test_data import load_test_data
from load_or_train_model import load_or_train_model
# Load the training data, then obtain a model (load_or_train_model is given the
# training data) and load the test set.
# NOTE(review): the second filename passed to each loader looks like a pickle
# cache for the parsed CSV — confirm against load_train_data / load_test_data.
x, y = load_train_data('train.csv', 'train.pickle')
nn = load_or_train_model(x, y)
test_x = load_test_data('test.csv', 'test.pickle')

# Emit predictions as CSV on stdout: a header row, then one
# "<ImageId>,<Label>" row per test sample, with ImageId counted from 1.
print("ImageId,Label")
for image_id, sample in enumerate(test_x, start=1):
    # The model is fed one sample at a time as a 784x1 column vector
    # (presumably a flattened 28x28 image — TODO confirm).
    predicted = nn.predict(sample.reshape((784, 1)))
    # The label is the index of the largest model output. With cupy, argmax
    # returns an array object rather than a Python int; int() converts it so
    # each row prints as a plain integer.
    print(image_id, int(np.argmax(predicted)), sep=",")
# print("Generating images...")
# x = np.zeros((784, 1))
# # generated = nn.generate(x, np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]).reshape((10, 1)), 0, 255, 1, 1)
# # generated = nn.generate(x, np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]).reshape((10, 1)), 0, 255, 10, 0.1)
# # generated = nn.generate(x, np.array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]).reshape((10, 1)), 0, 255, 30, 0.05)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 1, 500)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 10, 100)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 50, 80)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 100, 50)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 200, 30)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 200, 10)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 200, 8)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 200, 5)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 200, 2)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 200, 1)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 300, 0.8)
# generated = nn.generate(x, np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape((10, 1)), 0, 255, 300, 0.1)
# loss = nn.loss_function.f(np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]).reshape(10,1), nn.predict(generated))
# print("Loss:", loss)
#
# generated = np.round(generated)
#
# import matplotlib.pyplot as plt
# plt.imshow(generated.reshape((28, 28)).tolist(),cmap='gray')
# plt.clim(0, 255)
# plt.show()
#