@@ -103,3 +103,87 @@ def forward(self, x, latent):
103103 for i in range (self .opt .num_layers ):
104104 x = self .activation (self .bn_list [i ](self .conv_list [i ](x )))
105105 return self .last_conv (x )
106+
107+
108+ # Modules from MSN: https://github.com/Colin97/MSN-Point-Cloud-Completion
109+ class PointNetfeat (nn .Module ):
110+ def __init__ (self , num_points = 8192 , global_feat = True ):
111+ super (PointNetfeat , self ).__init__ ()
112+ self .conv1 = torch .nn .Conv1d (3 , 64 , 1 )
113+ self .conv2 = torch .nn .Conv1d (64 , 128 , 1 )
114+ self .conv3 = torch .nn .Conv1d (128 , 1024 , 1 )
115+
116+ self .bn1 = torch .nn .BatchNorm1d (64 )
117+ self .bn2 = torch .nn .BatchNorm1d (128 )
118+ self .bn3 = torch .nn .BatchNorm1d (1024 )
119+
120+ self .num_points = num_points
121+ self .global_feat = global_feat
122+ def forward (self , x ):
123+ batchsize = x .size ()[0 ]
124+ x = F .relu (self .bn1 (self .conv1 (x )))
125+ x = F .relu (self .bn2 (self .conv2 (x )))
126+ x = self .bn3 (self .conv3 (x ))
127+ x ,_ = torch .max (x , 2 )
128+ x = x .view (- 1 , 1024 )
129+ return x
130+
131+ class PointGenCon (nn .Module ):
132+ def __init__ (self , bottleneck_size = 8192 ):
133+ self .bottleneck_size = bottleneck_size
134+ super (PointGenCon , self ).__init__ ()
135+ self .conv1 = torch .nn .Conv1d (self .bottleneck_size , self .bottleneck_size , 1 )
136+ self .conv2 = torch .nn .Conv1d (self .bottleneck_size , self .bottleneck_size // 2 , 1 )
137+ self .conv3 = torch .nn .Conv1d (self .bottleneck_size // 2 , self .bottleneck_size // 4 , 1 )
138+ self .conv4 = torch .nn .Conv1d (self .bottleneck_size // 4 , 3 , 1 )
139+
140+ self .th = nn .Tanh ()
141+ self .bn1 = torch .nn .BatchNorm1d (self .bottleneck_size )
142+ self .bn2 = torch .nn .BatchNorm1d (self .bottleneck_size // 2 )
143+ self .bn3 = torch .nn .BatchNorm1d (self .bottleneck_size // 4 )
144+
145+ def forward (self , x ):
146+ batchsize = x .size ()[0 ]
147+ x = F .relu (self .bn1 (self .conv1 (x )))
148+ x = F .relu (self .bn2 (self .conv2 (x )))
149+ x = F .relu (self .bn3 (self .conv3 (x )))
150+ x = self .th (self .conv4 (x ))
151+ return x
152+
153+ class PointNetRes (nn .Module ):
154+ def __init__ (self ):
155+ super (PointNetRes , self ).__init__ ()
156+ self .conv1 = torch .nn .Conv1d (4 , 64 , 1 )
157+ self .conv2 = torch .nn .Conv1d (64 , 128 , 1 )
158+ self .conv3 = torch .nn .Conv1d (128 , 1024 , 1 )
159+ self .conv4 = torch .nn .Conv1d (1088 , 512 , 1 )
160+ self .conv5 = torch .nn .Conv1d (512 , 256 , 1 )
161+ self .conv6 = torch .nn .Conv1d (256 , 128 , 1 )
162+ self .conv7 = torch .nn .Conv1d (128 , 3 , 1 )
163+
164+
165+ self .bn1 = torch .nn .BatchNorm1d (64 )
166+ self .bn2 = torch .nn .BatchNorm1d (128 )
167+ self .bn3 = torch .nn .BatchNorm1d (1024 )
168+ self .bn4 = torch .nn .BatchNorm1d (512 )
169+ self .bn5 = torch .nn .BatchNorm1d (256 )
170+ self .bn6 = torch .nn .BatchNorm1d (128 )
171+ self .bn7 = torch .nn .BatchNorm1d (3 )
172+ self .th = nn .Tanh ()
173+
174+ def forward (self , x ):
175+ batchsize = x .size ()[0 ]
176+ npoints = x .size ()[2 ]
177+ x = F .relu (self .bn1 (self .conv1 (x )))
178+ pointfeat = x
179+ x = F .relu (self .bn2 (self .conv2 (x )))
180+ x = self .bn3 (self .conv3 (x ))
181+ x ,_ = torch .max (x , 2 )
182+ x = x .view (- 1 , 1024 )
183+ x = x .view (- 1 , 1024 , 1 ).repeat (1 , 1 , npoints )
184+ x = torch .cat ([x , pointfeat ], 1 )
185+ x = F .relu (self .bn4 (self .conv4 (x )))
186+ x = F .relu (self .bn5 (self .conv5 (x )))
187+ x = F .relu (self .bn6 (self .conv6 (x )))
188+ x = self .th (self .conv7 (x ))
189+ return x
0 commit comments