
How to use PyTorch's built-in model pruning tool (torch.nn.utils.prune)


torch.nn.utils.prune can be used to prune a model. The official tutorial is here:

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

Let's go straight to the code.

First, build the model network:

import torch
import torch.nn as nn
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
        # 16 * 16 * 24 matches a 32x32 input (one 2x2 pool); with the
        # (3, 512, 512) input used below, the view(-1, ...) in forward()
        # silently folds the extra spatial elements into the batch dimension
        self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)
    def forward(self, input):
        output = self.conv1(input)
        output = nn.ReLU()(output)
        output = self.conv2(output)
        output = nn.ReLU()(output)
        output = self.pool(output)
        output = self.conv3(output)
        output = nn.ReLU()(output)
        output = self.conv4(output)
        output = nn.ReLU()(output)
        output = output.view(-1, 16 * 16 * 24)
        output = self.fc(output)
        return output
model = SimpleNet().to(device=device)

Take a look at the model summary:

summary(model, input_size=(3, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

Print the parameter names of each layer in the model:

print(model.state_dict().keys())

Result:

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias', 'fc.weight', 'fc.bias'])
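
As a side note before the global variant used in this post, prune can also be applied locally to a single module. A minimal sketch, run on a throwaway copy of the model so the global pruning below still starts from unpruned weights:

import torch.nn.utils.prune as prune

# local (per-module) L1 pruning on a throwaway copy of the model
tmp = SimpleNet().to(device)
prune.l1_unstructured(tmp.conv1, name='weight', amount=0.3)  # zero the 30% smallest weights
print(tmp.conv1.weight_mask.sum().item() / tmp.conv1.weight_mask.nelement())  # ~0.7 kept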

Next, prune it. Note that conv3 is not included in parameters_to_prune, so it will be left untouched:

import torch.nn.utils.prune as prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.conv4, 'weight'),
    (model.fc, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
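
As a quick sanity check (in the spirit of the official tutorial), we can count the zeroed entries; globally this should come out close to the requested 20%, though the per-layer split varies:

# global sparsity over all pruned tensors -- should be close to amount=0.2
zeros, elems = 0, 0
for module, name in parameters_to_prune:
    w = getattr(module, name)
    zeros += int(torch.sum(w == 0))
    elems += w.nelement()
print("global sparsity: {:.2f}%".format(100.0 * zeros / elems))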

After the call, print the keys again:

print(model.state_dict().keys())

Result:

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'conv3.weight', 'conv3.bias', 'conv4.bias', 'conv4.weight_orig', 'conv4.weight_mask', 'fc.bias', 'fc.weight_orig', 'fc.weight_mask'])

We can see that for each pruned module, conv*.weight has disappeared, replaced by two new tensors: weight_orig and weight_mask (conv3, which was not in parameters_to_prune, keeps its plain weight).

weight_orig is simply the weight as it was before pruning, while weight_mask contains only 0s and 1s, with 0 marking a pruned connection.
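
Under the hood, the effective weight is recomputed as weight_orig * weight_mask by a forward pre-hook that prune attaches to the module; a quick check:

# module.weight is now a plain attribute computed from the two new tensors
print(torch.equal(model.conv1.weight,
                  model.conv1.weight_orig * model.conv1.weight_mask))  # True
print([name for name, _ in model.conv1.named_buffers()])  # ['weight_mask']
print('weight' in dict(model.conv1.named_parameters()))   # False: weight_orig is the Parameter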

Print the mask:

print(model.state_dict()['conv1.weight_mask'])

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0')
To make the pruning permanent, i.e. bake the mask into the weight and drop the weight_orig/weight_mask reparametrization, call prune.remove on each pruned module:

for module, name in parameters_to_prune:
    prune.remove(module, name)

After this, the state_dict contains plain conv*.weight entries again, with the pruned positions fixed at zero.

Even so, the result is honestly a bit underwhelming: the pruned weights are merely set to zero, so the model size does not change at all; it feels a bit like dropout. Printing the summary again confirms this:
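
This is the same summary call as before:

summary(model, input_size=(3, 512, 512))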

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

It is exactly the same as before pruning. The zeros may still reduce computation in principle, but unstructured pruning by itself shrinks neither the parameter count nor the file size.
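
If you need real size or speed wins, structured pruning is the usual next step: it zeroes whole channels, which a follow-up surgery step can then physically remove. A minimal sketch with the same prune module (the surgery itself, i.e. rebuilding conv4 with fewer in_channels, is left out):

model2 = SimpleNet().to(device)
# prune ~30% of conv3's output filters by their L2 norm (dim=0 is out_channels)
prune.ln_structured(model2.conv3, name='weight', amount=0.3, n=2, dim=0)
prune.remove(model2.conv3, 'weight')
# count the filters that survived (pruned filters are now entirely zero)
alive = int((model2.conv3.weight.abs().sum(dim=(1, 2, 3)) != 0).sum())
print(alive)  # e.g. 17 of the 24 filters remain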

Original article: https://blog.csdn.net/zhou_438/article/details/109053992

Tags: pytorch, prune