Skip to content

Commit 74ea46a

Browse files
committed
add MSN-PCN modules
1 parent cfaf61a commit 74ea46a

1 file changed

Lines changed: 84 additions & 0 deletions

File tree

model/model_blocks.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,87 @@ def forward(self, x, latent):
103103
for i in range(self.opt.num_layers):
104104
x = self.activation(self.bn_list[i](self.conv_list[i](x)))
105105
return self.last_conv(x)
106+
107+
108+
# Modules from MSN: https://github.com/Colin97/MSN-Point-Cloud-Completion
109+
class PointNetfeat(nn.Module):
110+
def __init__(self, num_points = 8192, global_feat = True):
111+
super(PointNetfeat, self).__init__()
112+
self.conv1 = torch.nn.Conv1d(3, 64, 1)
113+
self.conv2 = torch.nn.Conv1d(64, 128, 1)
114+
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
115+
116+
self.bn1 = torch.nn.BatchNorm1d(64)
117+
self.bn2 = torch.nn.BatchNorm1d(128)
118+
self.bn3 = torch.nn.BatchNorm1d(1024)
119+
120+
self.num_points = num_points
121+
self.global_feat = global_feat
122+
def forward(self, x):
123+
batchsize = x.size()[0]
124+
x = F.relu(self.bn1(self.conv1(x)))
125+
x = F.relu(self.bn2(self.conv2(x)))
126+
x = self.bn3(self.conv3(x))
127+
x,_ = torch.max(x, 2)
128+
x = x.view(-1, 1024)
129+
return x
130+
131+
class PointGenCon(nn.Module):
132+
def __init__(self, bottleneck_size = 8192):
133+
self.bottleneck_size = bottleneck_size
134+
super(PointGenCon, self).__init__()
135+
self.conv1 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size, 1)
136+
self.conv2 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size//2, 1)
137+
self.conv3 = torch.nn.Conv1d(self.bottleneck_size//2, self.bottleneck_size//4, 1)
138+
self.conv4 = torch.nn.Conv1d(self.bottleneck_size//4, 3, 1)
139+
140+
self.th = nn.Tanh()
141+
self.bn1 = torch.nn.BatchNorm1d(self.bottleneck_size)
142+
self.bn2 = torch.nn.BatchNorm1d(self.bottleneck_size//2)
143+
self.bn3 = torch.nn.BatchNorm1d(self.bottleneck_size//4)
144+
145+
def forward(self, x):
146+
batchsize = x.size()[0]
147+
x = F.relu(self.bn1(self.conv1(x)))
148+
x = F.relu(self.bn2(self.conv2(x)))
149+
x = F.relu(self.bn3(self.conv3(x)))
150+
x = self.th(self.conv4(x))
151+
return x
152+
153+
class PointNetRes(nn.Module):
154+
def __init__(self):
155+
super(PointNetRes, self).__init__()
156+
self.conv1 = torch.nn.Conv1d(4, 64, 1)
157+
self.conv2 = torch.nn.Conv1d(64, 128, 1)
158+
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
159+
self.conv4 = torch.nn.Conv1d(1088, 512, 1)
160+
self.conv5 = torch.nn.Conv1d(512, 256, 1)
161+
self.conv6 = torch.nn.Conv1d(256, 128, 1)
162+
self.conv7 = torch.nn.Conv1d(128, 3, 1)
163+
164+
165+
self.bn1 = torch.nn.BatchNorm1d(64)
166+
self.bn2 = torch.nn.BatchNorm1d(128)
167+
self.bn3 = torch.nn.BatchNorm1d(1024)
168+
self.bn4 = torch.nn.BatchNorm1d(512)
169+
self.bn5 = torch.nn.BatchNorm1d(256)
170+
self.bn6 = torch.nn.BatchNorm1d(128)
171+
self.bn7 = torch.nn.BatchNorm1d(3)
172+
self.th = nn.Tanh()
173+
174+
def forward(self, x):
175+
batchsize = x.size()[0]
176+
npoints = x.size()[2]
177+
x = F.relu(self.bn1(self.conv1(x)))
178+
pointfeat = x
179+
x = F.relu(self.bn2(self.conv2(x)))
180+
x = self.bn3(self.conv3(x))
181+
x,_ = torch.max(x, 2)
182+
x = x.view(-1, 1024)
183+
x = x.view(-1, 1024, 1).repeat(1, 1, npoints)
184+
x = torch.cat([x, pointfeat], 1)
185+
x = F.relu(self.bn4(self.conv4(x)))
186+
x = F.relu(self.bn5(self.conv5(x)))
187+
x = F.relu(self.bn6(self.conv6(x)))
188+
x = self.th(self.conv7(x))
189+
return x

0 commit comments

Comments
 (0)