-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrainer.py
More file actions
57 lines (45 loc) · 1.88 KB
/
trainer.py
File metadata and controls
57 lines (45 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
from tqdm import tqdm
class Optimizer(object):
    """Per-sample training and evaluation driver for a duck-typed model.

    The model object must provide (these are the only members used here):

    * ``forward(x, y)`` -> loss for one sample (anything ``np.mean`` accepts)
    * ``backward(x, y)`` -> called after ``forward``; return value ignored
    * ``update()`` -> called after ``backward``; return value ignored
    * ``predict(x)`` -> prediction compared to ``y`` with ``np.equal``
    * a writable ``learning_rate`` attribute (overwritten in ``__init__``)

    ``data`` is a 6-tuple in the order
    ``(x_train, x_val, x_test, y_train, y_val, y_test)``. Each ``x_*`` /
    ``y_*`` pair is iterated in lockstep with ``zip``; when ``verbose`` is
    true the ``y_*`` sequences must also support ``len`` (used as the
    progress-bar total).
    """

    def __init__(self, model, data,
                 learning_rate=0.0001,
                 num_epochs=200,
                 batch_size=32,
                 val_period=20,
                 verbose=True):
        """Store the model, hyper-parameters and data splits.

        Args:
            model: Object implementing the interface described on the class.
            data: 6-tuple ``(x_train, x_val, x_test, y_train, y_val, y_test)``.
            learning_rate: Assigned to ``model.learning_rate``.
            num_epochs: Number of full passes over the training split.
            batch_size: Stored as ``self.batch_size`` for callers/models to
                read; this class itself does not batch — ``train`` updates
                the model after every single sample.
            val_period: ``validate`` runs after epochs where
                ``epoch % val_period == 0`` (epoch 0, val_period, ...).
                A falsy value (0 or None) disables periodic validation.
            verbose: When true, print progress messages and show tqdm
                progress bars; when false, run silently.
        """
        self.model = model
        self.model.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.val_period = val_period
        self.verbose = verbose
        (self.x_train, self.x_val, self.x_test,
         self.y_train, self.y_val, self.y_test) = data

    def _log(self, message):
        """Print ``message`` only when verbose output is enabled."""
        if self.verbose:
            print(message)

    def _progress(self, xs, ys):
        """Return an iterator of ``(x, y)`` pairs, with a tqdm bar if verbose."""
        pairs = zip(xs, ys)
        if self.verbose:
            return tqdm(pairs, total=len(ys))
        return pairs

    def _evaluate(self, xs, ys):
        """Return the element-wise accuracy of ``model.predict`` on a split.

        Every element of ``np.equal(prediction, target)`` counts once, so
        multi-element predictions are scored per element. Returns ``nan``
        when the split contributes no elements at all.
        """
        correct = 0
        total = 0
        for sample, target in self._progress(xs, ys):
            matches = np.equal(self.model.predict(sample), target)
            correct += int(np.count_nonzero(matches))
            total += int(np.size(matches))
        # Guard the empty split: there is nothing to average over.
        return correct / total if total else float("nan")

    def train(self):
        """Train for ``num_epochs`` epochs, updating after every sample.

        For each training sample this calls ``model.forward``,
        ``model.backward`` and ``model.update`` in that order.

        Returns:
            list of float: mean training loss of each epoch (``nan`` for an
            epoch over an empty training split).
        """
        epoch_losses = []
        for epoch in range(self.num_epochs):
            self._log("epoch %i starting..." % epoch)
            losses = []
            for sample, target in self._progress(self.x_train, self.y_train):
                losses.append(self.model.forward(sample, target))
                self.model.backward(sample, target)
                self.model.update()
            # np.mean([]) warns and yields nan; make the empty case explicit.
            mean_loss = float(np.mean(losses)) if losses else float("nan")
            epoch_losses.append(mean_loss)
            self._log("loss of epoch %i is: %f " % (epoch, mean_loss))
            # Get accuracy on the validation split every val_period epochs;
            # a falsy val_period disables this instead of dividing by zero.
            if self.val_period and epoch % self.val_period == 0:
                self.validate()
        return epoch_losses

    def validate(self):
        """Compute (and print, if verbose) accuracy on the validation split.

        Returns:
            float: fraction of prediction elements equal to the target, or
            ``nan`` if the validation split is empty.
        """
        accuracy = self._evaluate(self.x_val, self.y_val)
        self._log("validation accuracy: %f" % accuracy)
        return accuracy

    def test(self):
        """Compute (and print, if verbose) accuracy on the test split.

        Returns:
            float: fraction of prediction elements equal to the target, or
            ``nan`` if the test split is empty.
        """
        accuracy = self._evaluate(self.x_test, self.y_test)
        self._log("test accuracy: %f" % accuracy)
        return accuracy