Skip to content

Commit ee28a4e

Browse files
committed
transforming AE
0 parents  commit ee28a4e

12 files changed

Lines changed: 1049 additions & 0 deletions

File tree

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
images/Input/*
2+
images/output/*
3+
images/Target/*
4+
__pycache__
5+
util.py

CapLayer.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
Created on Tue May 15 02:19:37 2018
3+
4+
@author: rinzler
5+
"""
6+
7+
from Capsule import Capsule
8+
import torch.nn.functional as F
9+
import torch.nn as nn
10+
11+
class CapLayer(nn.Module):
12+
def __init__(self, num_caps, in_dim, cap_dim, gen_dim):
13+
super(CapLayer, self).__init__()
14+
self.caps = nn.ModuleList([
15+
Capsule(in_dim, cap_dim, gen_dim)
16+
for _ in range(num_caps)])
17+
# print(len(self.caps))
18+
def forward(self, X, delxy, sep = False):
19+
caps_out = [cap(X, delxy) for cap in self.caps]
20+
R = []
21+
for cap in self.caps:
22+
if not sep:
23+
x = cap(X,delxy)
24+
caps_out.append(x)
25+
else:
26+
x, y = cap(X, delxy, sep)
27+
caps_out.append(x)
28+
R.append(y)
29+
t = caps_out[0]
30+
if sep:
31+
r = R[0]
32+
for i in range(1, len(self.caps)):
33+
r = (R[i] + r )/2
34+
# print(t.size())
35+
for i in range(1, len(self.caps)):
36+
37+
t += caps_out[i]
38+
39+
if not sep:
40+
return F.sigmoid(t)
41+
return F.sigmoid(t), r

Capsule.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
import torch
4+
'Defines one capsule'
5+
6+
class Capsule(nn.Module):
7+
def __init__(self, input_dim, cap_dim, gen_dim, xy_dim = 2):
8+
super(Capsule, self).__init__()
9+
self.indim = input_dim
10+
self.cpdim = cap_dim
11+
self.gndim = gen_dim
12+
self.xytrn = xy_dim
13+
self.cp = nn.Linear(self.indim, self.cpdim) #Recognizer units
14+
self.xy = nn.Linear(self.cpdim, self.xytrn) #estimates of the X and Y
15+
self.pr = nn.Linear(self.cpdim, 1) #prob of feature
16+
self.gn = nn.Linear(self.xytrn, self.gndim) #The generator
17+
self.rc = nn.Linear(self.gndim, self.indim) #The reconstructed image
18+
19+
def forward(self, X, delxy, sp = False):
20+
X = X.view(-1, 28*28)
21+
# print(X.size(), delxy.size())
22+
cap = F.sigmoid(self.cp(X))
23+
# print('cap', cap.size())
24+
x_y = self.xy(cap)
25+
# print('x_y', x_y.size())
26+
prb = self.pr(cap)
27+
# print('prb', prb.size())
28+
# print('x_y + del', (x_y + delxy).size())
29+
gen = self.gn(x_y + delxy)
30+
# print('gen', gen.size())
31+
rec = self.rc(gen)
32+
# print('rec',rec.size())
33+
# rec = rec.view(64, 1, 28, 28)
34+
# torch.matmul(rec,prb)
35+
if sp:
36+
return torch.mul(rec,prb), x_y
37+
else:
38+
return torch.mul(rec,prb)
39+

0 commit comments

Comments
 (0)