-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodified_GAP.lua
More file actions
122 lines (110 loc) · 3.35 KB
/
modified_GAP.lua
File metadata and controls
122 lines (110 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
require 'nn';
require 'image';
mnist = require 'mnist';
require 'optim';
require 'gnuplot';
-- Load the trained network and split it into two sequential halves:
--   model  : the leading modules (everything before index 7)
--   model2 : modules 7..13, applied to model's output
-- so that the intermediate activations can be masked channel-by-channel
-- between the two halves.
-- NOTE(review): assumes the saved network has exactly 13 top-level modules;
-- with more, the trailing ones would stay in `model` — confirm.
model = torch.load('model_MNIST2.t7')
model2 = nn.Sequential()
local FIRST_HEAD_MODULE, LAST_MODULE = 7, 13
-- Walk backwards so each moved module is prepended to model2; model2 ends up
-- holding modules 7..13 in their original order, and removing from the end
-- of `model` keeps the remaining indices stable.
for i = LAST_MODULE, FIRST_HEAD_MODULE, -1 do
  model2:insert(model:get(i), 1)
  model:remove(i)
end
-- This script only runs forward passes to measure loss, so put both halves in
-- evaluation mode (deterministic Dropout / running BatchNorm statistics, if
-- such layers are present; a no-op otherwise).
model:evaluate()
model2:evaluate()
print(model)
print(model2)
-- Probe the first half with one non-batched 1x28x28 input; dimension 1 of its
-- output is the channel ("filter") count that the evaluation loop masks.
checkOut = model:forward(torch.rand(1,28,28))
numFilters = checkOut:size(1)
-- Fetch the MNIST training split once and reuse it for every field, rather
-- than calling mnist.traindataset() per field (each call may reload the data).
local trainset = mnist.traindataset()
trSize = trainset.size
-- Pixels scaled from 0..255 to [0,1] and reshaped to trSize x 1 x 28 x 28.
trainData = trainset.data:double():div(255):reshape(trSize,1,28,28)
-- Labels shifted by +1 so digit d becomes class index d+1 (Torch class
-- indices are 1-based).
trainlabels = trainset.label+1
-- `indices` lists sample positions ordered by label, so each class occupies
-- one contiguous run of `indices`; the evaluation loop below walks these runs.
sorted,indices = torch.sort(trainlabels)
classes = {'0', '1', '2','3', '4','5', '6','7', '8','9'}
-- classSize[k] = number of training samples whose class index is k.
classSize = torch.Tensor(#classes):zero()
for i = 1, trSize do
  local label = trainlabels[i]
  classSize[label] = classSize[label] + 1
end
print(classSize)
print(classSize:sum())
-- For every class c, sum the loss over all class-c training samples:
--   errorTensor[numFilters+1][c] : baseline loss (activations unmodified)
--   errorTensor[f][c]            : loss with channel f of model's output zeroed
criterion = nn.ClassNLLCriterion()
errorTensor = torch.Tensor(1+numFilters,#classes):zero()
batchSize = 200
indexDone = 0      -- number of entries of `indices` consumed by earlier classes
errorOccurred = 0  -- set to 1 if a sorted run contains a foreign label

for c = 1, #classes do
  print('Starting Class '..c)
  local count = classSize[c]
  -- Walk this class's contiguous run of `indices` in chunks of batchSize; the
  -- last chunk may be shorter, so no separate remainder pass is needed.
  for offset = 0, count - 1, batchSize do
    local n = math.min(batchSize, count - offset)
    local t = math.floor(offset / batchSize) + 1  -- 1-based batch number (for messages)
    local input = torch.Tensor(n,1,28,28)
    local target = torch.Tensor(n)
    for t1 = 1, n do
      local t2 = indexDone + offset + t1
      input[t1] = trainData[indices[t2]]
      target[t1] = trainlabels[indices[t2]]
      -- Sanity check: the sorted run for class c must contain only label c.
      if target[t1] ~= c then
        print('error at c='..c..' t='..t..' t1='..t1)
        errorOccurred = 1
        break
      end
    end
    if errorOccurred == 1 then
      break
    end

    local filters = model:forward(input)
    -- Presumably the criterion uses its default batch averaging; multiplying
    -- by n turns the per-batch mean into a per-batch sum.
    local baseline = criterion:forward(model2:forward(filters), target) * n
    errorTensor[numFilters+1][c] = errorTensor[numFilters+1][c] + baseline

    -- One reusable buffer per batch; it is fully refreshed from `filters`
    -- before each channel is zeroed, so every filter sees an otherwise
    -- untouched copy.
    local masked = filters:clone()
    for f = 1, numFilters do
      collectgarbage()
      masked:copy(filters)
      masked[{{},{f},{},{}}]:zero()
      local maskedErr = criterion:forward(model2:forward(masked), target) * n
      errorTensor[f][c] = errorTensor[f][c] + maskedErr
    end
  end
  if errorOccurred == 1 then
    break
  end
  indexDone = indexDone + count
end
torch.save('modifiedGAP_errTensor.t7',errorTensor)