-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathtrainer.py
More file actions
66 lines (50 loc) · 1.96 KB
/
trainer.py
File metadata and controls
66 lines (50 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import model
import torch
from torch import nn
from torch import optim
# Device that the network and every batch are moved to.
# Hard-coded to the CPU; switch the string to "cuda" to train on a GPU.
device = torch.device("cpu")
class trainer:
    """Train a ``model.ConvNet`` with Adam and cross-entropy loss.

    After every epoch the network is evaluated on the test loader via the
    module-level ``compute_accuracy`` helper and the accuracy is printed.
    """

    def __init__(self, lr=0.01, n_epochs=4):
        """Build the network, loss function and optimizer.

        Args:
            lr: Adam learning rate (default 0.01, the previously fixed value).
            n_epochs: number of passes over the training loader (default 4).
        """
        self.net = model.ConvNet()
        self.net.to(device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)
        self.n_epochs = n_epochs

    def train(self, trainloader, testloader, print_every=200):
        """Run the training loop and return the last measured test accuracy.

        Args:
            trainloader: iterable of ``(inputs, labels)`` mini-batches.
            testloader: iterable of ``(images, labels)`` batches handed to
                ``compute_accuracy`` after every epoch.
            print_every: number of mini-batches between running-loss reports
                (default 200, the previously fixed value).

        Returns:
            The value returned by ``compute_accuracy`` after the final epoch,
            or 0 when ``n_epochs`` is 0.
        """
        accuracy = 0
        for epoch in range(self.n_epochs):
            # compute_accuracy() puts the net into eval mode at the end of each
            # epoch, so training mode must be re-enabled every epoch; calling
            # train() only once before the loop left epochs 2..n in eval mode.
            self.net.train()
            running_loss = 0.0
            for i, (inputs, labels) in enumerate(trainloader):
                # Move the batch to the same device as the network.
                inputs, labels = inputs.to(device), labels.to(device)

                self.optimizer.zero_grad()
                outputs = self.net(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                # Report the mean loss over the last `print_every` batches.
                running_loss += loss.item()
                if i % print_every == print_every - 1:
                    print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / print_every))
                    running_loss = 0.0

            # Evaluate after every epoch.
            accuracy = compute_accuracy(self.net, testloader)
            print('Accuracy of the network on the 10000 test images: %d %%' % (100 * accuracy))
        print('Finished Training')
        return accuracy
def compute_accuracy(net, testloader):
    """Return the fraction of samples in ``testloader`` that ``net`` classifies correctly.

    Side effect: switches ``net`` to eval mode and leaves it there. Callers that
    keep training afterwards must call ``net.train()`` again.

    Args:
        net: classifier whose per-class scores lie along dim 1 of its output
            (presumably (batch, n_classes) logits; TODO confirm against model.ConvNet).
        testloader: iterable of ``(images, labels)`` batches.

    Returns:
        ``correct / total`` as a float in [0, 1].

    Raises:
        ValueError: if ``testloader`` yields no samples (previously this surfaced
            as an uninformative ZeroDivisionError).
    """
    net.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # Index of the highest score per sample. This gives the same
            # indices as torch.max(outputs, 1), without the `.data` access.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError('compute_accuracy: testloader yielded no samples')
    return correct / total