It seems that the dimensions are mismatched in `class Mask()`:
`mask_weight` is computed from cls + bbox weights (w_det = cls+box), so its dim = (4+1)*class_num x 256
x = x.permute(0, 2, 3, 1) # dim = (B*rois.size()[0], 28, 28, 256)
op = torch.matmul(x.contiguous().view(-1,256), mask_weight.t())
# dim = (B*rois.size()[0]*28*28, 256) x (256, (4+1)*class_num) = (B*rois.size()[0]*28*28, 5*81)
op = op.view(x.shape[0], x.shape[1], x.shape[2], self.class_num)
# supposed to be dim = (B*rois.size()[0], 28, 28, class_num)
# however, op holds 5x as many elements, which only fits dim = (B*rois.size()[0]*5, 28, 28, class_num)
# the two mismatch when w_det = cls+box; the shapes only agree when w_det comes from cls alone
It seems that the dimensions are mismatched in `class Mask()`:
`mask_weight` is computed from cls + bbox weights (w_det = cls+box), so its dim = (4+1)*class_num x 256
x = x.permute(0, 2, 3, 1) # dim = (B*rois.size()[0], 28, 28, 256)
op = torch.matmul(x.contiguous().view(-1,256), mask_weight.t())
# dim = (B*rois.size()[0]*28*28, 256) x (256, (4+1)*class_num) = (B*rois.size()[0]*28*28, 5*81)
op = op.view(x.shape[0], x.shape[1], x.shape[2], self.class_num)
# supposed to be dim = (B*rois.size()[0], 28, 28, class_num)
# however, op holds 5x as many elements, which only fits dim = (B*rois.size()[0]*5, 28, 28, class_num)
# the two mismatch when w_det = cls+box; the shapes only agree when w_det comes from cls alone