欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

PyTorch 计算模型的参数量

程序员文章站 2022-03-20 12:11:36
...

参考自：https://blog.csdn.net/jdzwanghao/article/details/84196239

def model_structure(model):
    """Print a per-parameter table for *model* and its total parameter count.

    ``model`` must provide ``named_parameters()``, yielding ``(name, param)``
    pairs whose ``param.shape`` is an iterable of ints, and ``_get_name()``.
    A ``torch.nn.Module`` provides both.

    Each row shows the parameter name, its shape and its element count.
    Columns are left-aligned and padded to 30 / 40 / 10 characters, so the
    table is 90 characters wide. A value longer than its column is printed
    in full and shifts the rest of that row.

    The summary line reports the total in *millions of parameter elements*
    (the trailing ``M``), not in megabytes.

    Returns:
        The total number of parameter elements (the same number that is
        printed). Returning it lets callers use the count without parsing
        stdout.
    """
    rule = '-' * 90
    print(rule)
    print('|' + ' ' * 11 + 'weight name' + ' ' * 10 + '|'
          + ' ' * 15 + 'weight shape' + ' ' * 15 + '|'
          + ' ' * 3 + 'number' + ' ' * 3 + '|')
    print(rule)

    num_para = 0
    for key, w_variable in model.named_parameters():
        # Element count is the product of the dimensions.
        # It is 1 for a 0-d (scalar) parameter.
        each_para = 1
        for dim in w_variable.shape:
            each_para *= dim
        num_para += each_para
        # str.ljust pads to the column width and leaves longer text as-is,
        # which keeps the fixed 30/40/10 column layout of the header.
        print('| {} | {} | {} |'.format(key.ljust(30),
                                         str(w_variable.shape).ljust(40),
                                         str(each_para).ljust(10)))

    print(rule)
    print('The total number of parameters: ' + str(num_para))
    # The value is millions of parameter elements. The previous "{:4f}" spec
    # only set a minimum width of 4, which never applies to these values.
    # Precision was, and still is, the default six decimals.
    print('The parameters of Model {}: {:f}M'.format(model._get_name(),
                                                     num_para / 1e6))
    print(rule)
    return num_para

运行结果（以 UNet_PNI 模型为例，末行中的 M 表示"百万个参数"）：

------------------------------------------------------------------------------------------
|           weight name          |               weight shape               |   number   |
------------------------------------------------------------------------------------------
| embed_in.0.weight              | torch.Size([28, 1, 1, 5, 5])             | 700        |
| embed_in.0.bias                | torch.Size([28])                         | 28         |
| downC.0.block1.0.weight        | torch.Size([28, 28, 1, 3, 3])            | 7056       |
| downC.0.block1.1.weight        | torch.Size([28])                         | 28         |
| downC.0.block1.1.bias          | torch.Size([28])                         | 28         |
| downC.0.block2.0.weight        | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| downC.0.block2.1.weight        | torch.Size([28])                         | 28         |
| downC.0.block2.1.bias          | torch.Size([28])                         | 28         |
| downC.0.block2.3.weight        | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| downC.0.block3.weight          | torch.Size([28])                         | 28         |
| downC.0.block3.bias            | torch.Size([28])                         | 28         |
| downC.1.block1.0.weight        | torch.Size([36, 28, 1, 3, 3])            | 9072       |
| downC.1.block1.1.weight        | torch.Size([36])                         | 36         |
| downC.1.block1.1.bias          | torch.Size([36])                         | 36         |
| downC.1.block2.0.weight        | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| downC.1.block2.1.weight        | torch.Size([36])                         | 36         |
| downC.1.block2.1.bias          | torch.Size([36])                         | 36         |
| downC.1.block2.3.weight        | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| downC.1.block3.weight          | torch.Size([36])                         | 36         |
| downC.1.block3.bias            | torch.Size([36])                         | 36         |
| downC.2.block1.0.weight        | torch.Size([48, 36, 1, 3, 3])            | 15552      |
| downC.2.block1.1.weight        | torch.Size([48])                         | 48         |
| downC.2.block1.1.bias          | torch.Size([48])                         | 48         |
| downC.2.block2.0.weight        | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| downC.2.block2.1.weight        | torch.Size([48])                         | 48         |
| downC.2.block2.1.bias          | torch.Size([48])                         | 48         |
| downC.2.block2.3.weight        | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| downC.2.block3.weight          | torch.Size([48])                         | 48         |
| downC.2.block3.bias            | torch.Size([48])                         | 48         |
| center.block1.0.weight         | torch.Size([64, 48, 1, 3, 3])            | 27648      |
| center.block1.1.weight         | torch.Size([64])                         | 64         |
| center.block1.1.bias           | torch.Size([64])                         | 64         |
| center.block2.0.weight         | torch.Size([64, 64, 3, 3, 3])            | 110592     |
| center.block2.1.weight         | torch.Size([64])                         | 64         |
| center.block2.1.bias           | torch.Size([64])                         | 64         |
| center.block2.3.weight         | torch.Size([64, 64, 3, 3, 3])            | 110592     |
| center.block3.weight           | torch.Size([64])                         | 64         |
| center.block3.bias             | torch.Size([64])                         | 64         |
| upS.0.1.weight                 | torch.Size([48, 64, 1, 1, 1])            | 3072       |
| upS.0.1.bias                   | torch.Size([48])                         | 48         |
| upS.1.1.weight                 | torch.Size([36, 48, 1, 1, 1])            | 1728       |
| upS.1.1.bias                   | torch.Size([36])                         | 36         |
| upS.2.1.weight                 | torch.Size([28, 36, 1, 1, 1])            | 1008       |
| upS.2.1.bias                   | torch.Size([28])                         | 28         |
| upC.0.block1.0.weight          | torch.Size([48, 48, 1, 3, 3])            | 20736      |
| upC.0.block1.1.weight          | torch.Size([48])                         | 48         |
| upC.0.block1.1.bias            | torch.Size([48])                         | 48         |
| upC.0.block2.0.weight          | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| upC.0.block2.1.weight          | torch.Size([48])                         | 48         |
| upC.0.block2.1.bias            | torch.Size([48])                         | 48         |
| upC.0.block2.3.weight          | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| upC.0.block3.weight            | torch.Size([48])                         | 48         |
| upC.0.block3.bias              | torch.Size([48])                         | 48         |
| upC.1.block1.0.weight          | torch.Size([36, 36, 1, 3, 3])            | 11664      |
| upC.1.block1.1.weight          | torch.Size([36])                         | 36         |
| upC.1.block1.1.bias            | torch.Size([36])                         | 36         |
| upC.1.block2.0.weight          | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| upC.1.block2.1.weight          | torch.Size([36])                         | 36         |
| upC.1.block2.1.bias            | torch.Size([36])                         | 36         |
| upC.1.block2.3.weight          | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| upC.1.block3.weight            | torch.Size([36])                         | 36         |
| upC.1.block3.bias              | torch.Size([36])                         | 36         |
| upC.2.block1.0.weight          | torch.Size([28, 28, 1, 3, 3])            | 7056       |
| upC.2.block1.1.weight          | torch.Size([28])                         | 28         |
| upC.2.block1.1.bias            | torch.Size([28])                         | 28         |
| upC.2.block2.0.weight          | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| upC.2.block2.1.weight          | torch.Size([28])                         | 28         |
| upC.2.block2.1.bias            | torch.Size([28])                         | 28         |
| upC.2.block2.3.weight          | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| upC.2.block3.weight            | torch.Size([28])                         | 28         |
| upC.2.block3.bias              | torch.Size([28])                         | 28         |
| embed_out.0.weight             | torch.Size([28, 28, 1, 5, 5])            | 19600      |
| embed_out.0.bias               | torch.Size([28])                         | 28         |
| out_affs_2.0.weight            | torch.Size([3, 28, 1, 1, 1])             | 84         |
| out_affs_2.0.bias              | torch.Size([3])                          | 3          |
------------------------------------------------------------------------------------------
The total number of parameters: 821531
The parameters of Model UNet_PNI: 0.821531M
------------------------------------------------------------------------------------------