PyTorch：计算模型的参数量
程序员文章站
2022-03-20 12:11:36
...
参考于:https://blog.csdn.net/jdzwanghao/article/details/84196239
def model_structure(model, type_size=1):
    """Print a per-parameter table for a module and summarize its size.

    One row is printed for every ``(name, tensor)`` pair yielded by
    ``model.named_parameters()``: the parameter name, its shape and its
    element count. After the table, the total element count is printed,
    followed by that total multiplied by ``type_size`` and divided by 10**6.

    Args:
        model: The module to inspect. It must provide ``named_parameters()``
            and ``_get_name()``, as ``torch.nn.Module`` does.
        type_size: Multiplier applied to the total before dividing by 10**6.
            The default of 1 reports millions of parameters. Pass the bytes
            per element instead (e.g. 4 for float32) to get megabytes. The
            printed "M" suffix is the same in both cases.

    Returns:
        int: Total number of parameter elements that were listed.
    """
    rule = '-' * 90
    print(rule)
    # Header cells are 32/42/12 characters wide between the bars. They line
    # up with the rows below:
    # '| ' + 30-wide name + ' | ' + 40-wide shape + ' | ' + 10-wide count + ' |'
    print('|' + ' ' * 11 + 'weight name' + ' ' * 10 + '|'
          + ' ' * 15 + 'weight shape' + ' ' * 15 + '|'
          + ' ' * 3 + 'number' + ' ' * 3 + '|')
    print(rule)

    num_para = 0
    for key, w_variable in model.named_parameters():
        # numel() is the product of the shape's dimensions (1 for a 0-dim tensor).
        each_para = w_variable.numel()
        num_para += each_para
        # Left-justify each value into its fixed-width column. Values wider
        # than the column are printed in full, which shifts that row's bars.
        print('| {:<30} | {:<40} | {:<10} |'.format(
            key, str(w_variable.shape), str(each_para)))

    print(rule)
    print('The total number of parameters: ' + str(num_para))
    # '{:4f}' means minimum width 4 with the default 6 decimal places.
    print('The parameters of Model {}: {:4f}M'.format(
        model._get_name(), num_para * type_size / 1000 / 1000))
    print(rule)
    return num_para
运行结果（以 UNet_PNI 模型为例）：
------------------------------------------------------------------------------------------
| weight name | weight shape | number |
------------------------------------------------------------------------------------------
| embed_in.0.weight | torch.Size([28, 1, 1, 5, 5]) | 700 |
| embed_in.0.bias | torch.Size([28]) | 28 |
| downC.0.block1.0.weight | torch.Size([28, 28, 1, 3, 3]) | 7056 |
| downC.0.block1.1.weight | torch.Size([28]) | 28 |
| downC.0.block1.1.bias | torch.Size([28]) | 28 |
| downC.0.block2.0.weight | torch.Size([28, 28, 3, 3, 3]) | 21168 |
| downC.0.block2.1.weight | torch.Size([28]) | 28 |
| downC.0.block2.1.bias | torch.Size([28]) | 28 |
| downC.0.block2.3.weight | torch.Size([28, 28, 3, 3, 3]) | 21168 |
| downC.0.block3.weight | torch.Size([28]) | 28 |
| downC.0.block3.bias | torch.Size([28]) | 28 |
| downC.1.block1.0.weight | torch.Size([36, 28, 1, 3, 3]) | 9072 |
| downC.1.block1.1.weight | torch.Size([36]) | 36 |
| downC.1.block1.1.bias | torch.Size([36]) | 36 |
| downC.1.block2.0.weight | torch.Size([36, 36, 3, 3, 3]) | 34992 |
| downC.1.block2.1.weight | torch.Size([36]) | 36 |
| downC.1.block2.1.bias | torch.Size([36]) | 36 |
| downC.1.block2.3.weight | torch.Size([36, 36, 3, 3, 3]) | 34992 |
| downC.1.block3.weight | torch.Size([36]) | 36 |
| downC.1.block3.bias | torch.Size([36]) | 36 |
| downC.2.block1.0.weight | torch.Size([48, 36, 1, 3, 3]) | 15552 |
| downC.2.block1.1.weight | torch.Size([48]) | 48 |
| downC.2.block1.1.bias | torch.Size([48]) | 48 |
| downC.2.block2.0.weight | torch.Size([48, 48, 3, 3, 3]) | 62208 |
| downC.2.block2.1.weight | torch.Size([48]) | 48 |
| downC.2.block2.1.bias | torch.Size([48]) | 48 |
| downC.2.block2.3.weight | torch.Size([48, 48, 3, 3, 3]) | 62208 |
| downC.2.block3.weight | torch.Size([48]) | 48 |
| downC.2.block3.bias | torch.Size([48]) | 48 |
| center.block1.0.weight | torch.Size([64, 48, 1, 3, 3]) | 27648 |
| center.block1.1.weight | torch.Size([64]) | 64 |
| center.block1.1.bias | torch.Size([64]) | 64 |
| center.block2.0.weight | torch.Size([64, 64, 3, 3, 3]) | 110592 |
| center.block2.1.weight | torch.Size([64]) | 64 |
| center.block2.1.bias | torch.Size([64]) | 64 |
| center.block2.3.weight | torch.Size([64, 64, 3, 3, 3]) | 110592 |
| center.block3.weight | torch.Size([64]) | 64 |
| center.block3.bias | torch.Size([64]) | 64 |
| upS.0.1.weight | torch.Size([48, 64, 1, 1, 1]) | 3072 |
| upS.0.1.bias | torch.Size([48]) | 48 |
| upS.1.1.weight | torch.Size([36, 48, 1, 1, 1]) | 1728 |
| upS.1.1.bias | torch.Size([36]) | 36 |
| upS.2.1.weight | torch.Size([28, 36, 1, 1, 1]) | 1008 |
| upS.2.1.bias | torch.Size([28]) | 28 |
| upC.0.block1.0.weight | torch.Size([48, 48, 1, 3, 3]) | 20736 |
| upC.0.block1.1.weight | torch.Size([48]) | 48 |
| upC.0.block1.1.bias | torch.Size([48]) | 48 |
| upC.0.block2.0.weight | torch.Size([48, 48, 3, 3, 3]) | 62208 |
| upC.0.block2.1.weight | torch.Size([48]) | 48 |
| upC.0.block2.1.bias | torch.Size([48]) | 48 |
| upC.0.block2.3.weight | torch.Size([48, 48, 3, 3, 3]) | 62208 |
| upC.0.block3.weight | torch.Size([48]) | 48 |
| upC.0.block3.bias | torch.Size([48]) | 48 |
| upC.1.block1.0.weight | torch.Size([36, 36, 1, 3, 3]) | 11664 |
| upC.1.block1.1.weight | torch.Size([36]) | 36 |
| upC.1.block1.1.bias | torch.Size([36]) | 36 |
| upC.1.block2.0.weight | torch.Size([36, 36, 3, 3, 3]) | 34992 |
| upC.1.block2.1.weight | torch.Size([36]) | 36 |
| upC.1.block2.1.bias | torch.Size([36]) | 36 |
| upC.1.block2.3.weight | torch.Size([36, 36, 3, 3, 3]) | 34992 |
| upC.1.block3.weight | torch.Size([36]) | 36 |
| upC.1.block3.bias | torch.Size([36]) | 36 |
| upC.2.block1.0.weight | torch.Size([28, 28, 1, 3, 3]) | 7056 |
| upC.2.block1.1.weight | torch.Size([28]) | 28 |
| upC.2.block1.1.bias | torch.Size([28]) | 28 |
| upC.2.block2.0.weight | torch.Size([28, 28, 3, 3, 3]) | 21168 |
| upC.2.block2.1.weight | torch.Size([28]) | 28 |
| upC.2.block2.1.bias | torch.Size([28]) | 28 |
| upC.2.block2.3.weight | torch.Size([28, 28, 3, 3, 3]) | 21168 |
| upC.2.block3.weight | torch.Size([28]) | 28 |
| upC.2.block3.bias | torch.Size([28]) | 28 |
| embed_out.0.weight | torch.Size([28, 28, 1, 5, 5]) | 19600 |
| embed_out.0.bias | torch.Size([28]) | 28 |
| out_affs_2.0.weight | torch.Size([3, 28, 1, 1, 1]) | 84 |
| out_affs_2.0.bias | torch.Size([3]) | 3 |
------------------------------------------------------------------------------------------
The total number of parameters: 821531
The parameters of Model UNet_PNI: 0.821531M
------------------------------------------------------------------------------------------