import torch
import torch.nn as nn
import torch_pruning as tp
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class CMUNeXtBlock(nn.Module):
    def __init__(self, ch_in, ch_out, depth=1, k=3):
        super(CMUNeXtBlock, self).__init__()
        self.block = nn.Sequential(
            *[nn.Sequential(
                Residual(nn.Sequential(
                    # depthwise convolution
                    nn.Conv2d(ch_in, ch_in, kernel_size=(k, k), groups=ch_in, padding=(k // 2, k // 2)),
                    nn.ReLU(),
                    nn.BatchNorm2d(ch_in)
                )),
                nn.Conv2d(ch_in, ch_in * 4, kernel_size=(1, 1)),
                nn.ReLU(),
                nn.BatchNorm2d(ch_in * 4),
                nn.Conv2d(ch_in * 4, ch_in, kernel_size=(1, 1)),
                nn.ReLU(),
                nn.BatchNorm2d(ch_in)
            ) for i in range(depth)]
        )
        self.up = conv_block(ch_in, ch_out)

    def forward(self, x):
        x = self.block(x)
        x = self.up(x)
        return x

class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

class fusion_conv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(fusion_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_in, kernel_size=3, stride=1, padding=1, groups=2, bias=True),
            nn.ReLU(),
            nn.BatchNorm2d(ch_in),
            nn.Conv2d(ch_in, ch_out * 4, kernel_size=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(ch_out * 4),
            nn.Conv2d(ch_out * 4, ch_out, kernel_size=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(ch_out)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class CMUNeXt(nn.Module):
    def __init__(self, input_channel=3, num_classes=1, dims=[16, 32, 128, 160, 256], depths=[1, 1, 1, 3, 1], kernels=[3, 3, 7, 7, 7]):
        """
        Args:
            input_channel: number of input channels.
            num_classes: number of output channels.
            dims: channel width of each encoder stage.
            depths: number of CMUNeXt blocks in each encoder stage.
            kernels: depthwise kernel size of the CMUNeXt blocks in each stage.
        """
        super(CMUNeXt, self).__init__()
        # Encoder
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.stem = conv_block(ch_in=input_channel, ch_out=dims[0])
        self.encoder1 = CMUNeXtBlock(ch_in=dims[0], ch_out=dims[0], depth=depths[0], k=kernels[0])
        self.encoder2 = CMUNeXtBlock(ch_in=dims[0], ch_out=dims[1], depth=depths[1], k=kernels[1])
        self.encoder3 = CMUNeXtBlock(ch_in=dims[1], ch_out=dims[2], depth=depths[2], k=kernels[2])
        self.encoder4 = CMUNeXtBlock(ch_in=dims[2], ch_out=dims[3], depth=depths[3], k=kernels[3])
        self.encoder5 = CMUNeXtBlock(ch_in=dims[3], ch_out=dims[4], depth=depths[4], k=kernels[4])
        # Decoder
        self.Up5 = up_conv(ch_in=dims[4], ch_out=dims[3])
        self.Up_conv5 = fusion_conv(ch_in=dims[3] * 2, ch_out=dims[3])
        self.Up4 = up_conv(ch_in=dims[3], ch_out=dims[2])
        self.Up_conv4 = fusion_conv(ch_in=dims[2] * 2, ch_out=dims[2])
        self.Up3 = up_conv(ch_in=dims[2], ch_out=dims[1])
        self.Up_conv3 = fusion_conv(ch_in=dims[1] * 2, ch_out=dims[1])
        self.Up2 = up_conv(ch_in=dims[1], ch_out=dims[0])
        self.Up_conv2 = fusion_conv(ch_in=dims[0] * 2, ch_out=dims[0])
        self.Conv_1x1 = nn.Conv2d(dims[0], num_classes, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        # Encoder path
        x1 = self.stem(x)
        x1 = self.encoder1(x1)
        x2 = self.Maxpool(x1)
        x2 = self.encoder2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.encoder3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.encoder4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.encoder5(x5)
        # Decoder path with skip connections
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d1 = self.Conv_1x1(d2)
        return d1

def cmunext(dims=[16, 32, 128, 160, 256], depths=[1, 1, 1, 3, 1], kernels=[3, 3, 7, 7, 7]):
    return CMUNeXt(dims=dims, depths=depths, kernels=kernels)

def cmunext_s(dims=[8, 16, 32, 64, 128], depths=[1, 1, 1, 1, 1], kernels=[3, 3, 7, 7, 9]):
    return CMUNeXt(dims=dims, depths=depths, kernels=kernels)

def cmunext_l(dims=[32, 64, 128, 256, 512], depths=[1, 1, 1, 6, 3], kernels=[3, 3, 7, 7, 7]):
    return CMUNeXt(dims=dims, depths=depths, kernels=kernels)

if __name__ == '__main__':
    model = cmunext()
    example_inputs = torch.rand(1, 3, 512, 512)
    out = model(example_inputs)
    print(out.size())

    imp = tp.importance.GroupMagnitudeImportance(p=2)
    ignored_layers = []
    pruner = tp.pruner.BasePruner(  # BasePruner is sufficient when sparse training is not required.
        model,
        example_inputs,
        importance=imp,
        pruning_ratio=0.5,  # remove 50% of the channels in each prunable layer
        # pruning_ratio_dict={model.conv1: 0.2, model.layer2: 0.8},  # customized pruning ratios for layers or blocks
        ignored_layers=ignored_layers,
        round_to=8,  # Rounding channels to multiples of 4 or 8 is recommended for acceleration; see https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
    )

    # Prune the model
    base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    tp.utils.print_tool.before_pruning(model)  # or print(model)
    pruner.step()
    tp.utils.print_tool.after_pruning(model)  # or print(model); this util shows the difference before and after pruning
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"MACs: {base_macs / 1e9} G -> {macs / 1e9} G, #Params: {base_nparams / 1e6} M -> {nparams / 1e6} M")
The above code reproduces the issue.
My environment: