pytorch如何使用自带的模型剪枝工具prune
torch.nn.utils.prune可以对模型进行剪枝,官方指导如下:
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
直接上代码
首先建立模型网络:
import torch
import torch.nn as nn
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SimpleNet(nn.Module):
def __init__(self, num_classes=10):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)
def forward(self, input):
output = self.conv1(input)
output = nn.ReLU()(output)
output = self.conv2(output)
output = nn.ReLU()(output)
output = self.pool(output)
output = self.conv3(output)
output = nn.ReLU()(output)
output = self.conv4(output)
output = nn.ReLU()(output)
output = output.view(-1, 16 * 16 * 24)
output = self.fc(output)
return output
model = SimpleNet().to(device=device)
看一下模型的 summary
summary(model, input_size=(3, 512, 512))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 12, 512, 512] 336
Conv2d-2 [-1, 12, 512, 512] 1,308
MaxPool2d-3 [-1, 12, 256, 256] 0
Conv2d-4 [-1, 24, 256, 256] 2,616
Conv2d-5 [-1, 24, 256, 256] 5,208
Linear-6 [-1, 10] 61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------
打印一下模型结构各层的名称:
print(model.state_dict().keys())
结果:
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias', 'fc.weight', 'fc.bias'])
接下来 对其进行剪枝操作:
import torch.nn.utils.prune as prune
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.conv4, 'weight'),
(model.fc, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
执行结束后,再打印一下:
print(model.state_dict().keys())
结果:
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'conv3.weight', 'conv3.bias', 'conv4.bias', 'conv4.weight_orig', 'conv4.weight_mask', 'fc.bias', 'fc.weight_orig', 'fc.weight_mask'])
我们发现剪枝结束后conv*.weight已经 消失了,出现了两个weight:weight_orig和weight_mask
其实weight_orig就是剪枝以前的weight,而weight_mask里面 只是0和1,0代表的是被剪枝的
打印一下:
print(model.state_dict()['conv1.weight_orig'])
tensor([[[[1., 1., 1.],
[1., 1., 1.],
[0., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[0., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 0.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 0.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 0.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 0.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]],
[[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]]], device='cuda:0')
prune.remove(module,
剪枝后,其实还是比较鸡肋的,因为只是剪之后的神经元相当于置零了,模型大小不会变,下面打印一下,有点dropout的意思了
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 12, 512, 512] 336
Conv2d-2 [-1, 12, 512, 512] 1,308
MaxPool2d-3 [-1, 12, 256, 256] 0
Conv2d-4 [-1, 24, 256, 256] 2,616
Conv2d-5 [-1, 24, 256, 256] 5,208
Linear-6 [-1, 10] 61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------
是不是和剪枝之前实际上是一样的,可能会减少运算
本文地址:https://blog.csdn.net/zhou_438/article/details/109053992
上一篇: CHERRY KC 6000 SLIM键盘亮相:300元
下一篇: js中的四种继承方式
推荐阅读
-
如何使用谷歌浏览器自带的调试工具?chrome自带调试工具使用方法实例
-
mac自带的画图工具在哪?如何使用苹果电脑自带的预览工具进行画图操作
-
win10如何清理C盘垃圾文件 系统自带磁盘清理工具的使用教程
-
pytorch如何使用自带的模型剪枝工具prune
-
【NLP】torch hub工具的使用:torch.hub.load、pytorch预训练模型加载、
-
win10如何清理C盘垃圾文件 系统自带磁盘清理工具的使用教程
-
如何使用谷歌浏览器自带的调试工具?chrome自带调试工具使用方法实例
-
mac自带的画图工具在哪?如何使用苹果电脑自带的预览工具进行画图操作
-
pytorch如何使用自带的模型剪枝工具prune