-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathtrain.py
More file actions
104 lines (80 loc) · 3.41 KB
/
train.py
File metadata and controls
104 lines (80 loc) · 3.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import numpy as np
from utils import *
def train(train_loader, model, criterion, optimizer, center):
    """Run one training epoch and return ``(train_acc, train_loss)``.

    For every batch of ``(img, positive_img, label)`` the model is run on
    four inputs:
    - the image,
    - an attention-erased copy of the image,
    - the positive image,
    - an attention-erased copy of the positive image.

    Two more predictions come from bilinear-pooling the image features and
    the erased-image features with co-attention maps fused from both images'
    attention maps. Each pooled result is passed through
    ``model.module.classifier``.

    The training loss is the mean of the six ``criterion`` losses plus the
    center loss on the bilinear features (scaled by 1/100).

    Args:
        train_loader: iterable yielding ``(img, positive_img, label)`` batches.
        model: network whose forward returns a 4-tuple
            ``(out, attention_maps, bilinear_features, logits)``. It must
            expose ``model.module.classifier``; presumably it is wrapped in
            ``DataParallel``/DDP (confirm against caller).
        criterion: loss callable invoked as ``criterion(logits, label)``.
        optimizer: optimizer that is zeroed, back-propagated and stepped
            once per batch.
        center: tensor indexed by class label. It is passed to
            ``Center_Loss`` and updated in place with the returned
            ``center_diff``.

    Returns:
        Tuple ``(train_acc, train_loss)``:
        - ``train_acc``: accuracy in percent of the arg-max of the logits for
          the un-erased image.
        - ``train_loss``: mean per-batch loss.
        Both are ``0.0`` when the loader yields no samples.
    """
    model.train()
    # Bilinear_Pooling() takes no constructor arguments, so one instance is
    # reused for the whole epoch instead of being rebuilt every step.
    # NOTE(review): assumes the pooling module is stateless; confirm in utils.
    bilinear_pooling = Bilinear_Pooling()

    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for i, (img, positive_img, label) in enumerate(train_loader):
        num_batches = i + 1
        inputs = img.cuda()
        positive_inputs = positive_img.cuda()
        # as_tensor accepts lists, NumPy arrays and tensors of any dtype.
        # The legacy torch.Tensor(x) constructor raises a TypeError when x is
        # already a non-float tensor (e.g. the LongTensor a default collate
        # produces).
        label = torch.as_tensor(label, dtype=torch.int64).cuda()

        # Forward passes on the image, the positive image, and their
        # attention-erased versions.
        out1, attention_maps1, bilinear_features, output1 = model(inputs)
        erase_img = attention_erase(attention_maps1, inputs)
        out2, _, _, output2 = model(erase_img)
        _, attention_maps3, _, output3 = model(positive_inputs)
        erase_img_positive = attention_erase(attention_maps3, positive_inputs)
        _, _, _, output4 = model(erase_img_positive)

        # Co-attention: fuse each image's attention maps with the other's,
        # then pool and classify the original and erased feature maps with
        # them.
        fuse_map1 = co_att(attention_maps1, attention_maps3)
        fuse_map2 = co_att(attention_maps3, attention_maps1)
        pooling1 = torch.flatten(bilinear_pooling(out1, fuse_map1), 1)
        pooling2 = torch.flatten(bilinear_pooling(out2, fuse_map2), 1)
        output5 = model.module.classifier(pooling1)
        output6 = model.module.classifier(pooling2)

        logits = (output1, output2, output3, output4, output5, output6)
        cls_loss = sum(criterion(logit, label) for logit in logits) / len(logits)

        features = bilinear_features.reshape(bilinear_features.shape[0], -1) / 100
        center_loss, center_diff = Center_Loss(features, center, label)
        # Update the running class centers outside autograd so the buffer
        # never becomes part of (and keeps alive) the computation graph.
        # NOTE(review): with repeated labels in a batch, advanced-index "+="
        # keeps only one update per class; index_add_ would accumulate them.
        with torch.no_grad():
            center[label] += center_diff
        loss = cls_loss + center_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        predicted = output1.detach().argmax(dim=1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()
        if i % 50 == 0:
            print('Step: %d | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                i, train_loss / (i + 1), 100. * float(correct) / total, correct, total))

    train_acc = 100. * correct / total if total else 0.0
    train_loss = train_loss / num_batches if num_batches else 0.0
    return train_acc, train_loss
def test(test_loader, model, criterion, center):
    """Evaluate ``model`` on ``test_loader`` and return ``(test_acc, test_loss)``.

    For every ``(img, label)`` batch the model is run on the image and on an
    attention-erased copy of it. The loss is the mean of the two
    ``criterion`` losses plus the center loss on the bilinear features
    (scaled by 1/100). The center-update term returned by ``Center_Loss`` is
    discarded here.

    Args:
        test_loader: iterable yielding ``(img, label)`` batches. ``label``
            must already be a tensor (it is moved with ``.cuda()``).
        model: network whose forward returns a 4-tuple
            ``(out, attention_maps, bilinear_features, logits)``.
        criterion: loss callable invoked as ``criterion(logits, label)``.
        center: class-center tensor passed through to ``Center_Loss``.

    Returns:
        Tuple ``(test_acc, test_loss)``:
        - ``test_acc``: accuracy in percent of the arg-max of the logits for
          the un-erased image.
        - ``test_loss``: mean per-batch loss.
        Both are ``0.0`` when the loader yields no samples.
    """
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    model.eval()
    # Evaluation needs no gradients. Without no_grad every batch would build
    # a full autograd graph that is never back-propagated.
    with torch.no_grad():
        for i, (img, label) in enumerate(test_loader):
            num_batches = i + 1
            inputs = img.cuda()
            label = label.cuda()
            _, attention_maps, bilinear_features, output1 = model(inputs)
            erase_img = attention_erase(attention_maps, inputs)
            _, _, _, output2 = model(erase_img)
            loss1 = criterion(output1, label)
            loss2 = criterion(output2, label)
            features = bilinear_features.reshape(bilinear_features.shape[0], -1) / 100
            center_loss, _ = Center_Loss(features, center, label)
            loss = (loss1 + loss2) / 2 + center_loss

            test_loss += loss.item()
            predicted = output1.argmax(dim=1)
            total += label.size(0)
            correct += predicted.eq(label).sum().item()
            if i % 50 == 0:
                print('Step: %d | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                    i, test_loss / (i + 1), 100. * float(correct) / total, correct, total))

    test_acc = 100. * correct / total if total else 0.0
    test_loss = test_loss / num_batches if num_batches else 0.0
    return test_acc, test_loss