欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch中的nn.module类常见成员函数

程序员文章站 2022-06-12 22:50:34
...


参考
【PyTorch】torch.nn.Module 源码分析

多达48个函数,这里简单记录一下常见函数的作用
先创建一个Module,以这个为例

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

net = Model()

一、cpu(self)

将所有的参数和内存放在cpu上

net.cpu()  # 将所有的参数和内存放在cpu上

二、cuda(self, device=None)

将所有的参数和内存放在gpu上

net.gpu("cuda:0")  # 将所有的参数和内存放在gpu上

三、apply(self, fn)

将Module及其所有的SubModule传进给定的fn函数操作一遍。举个例子,我们可以用这个函数来对Module的网络模型参数用指定的方法初始化。

def init_weights(m): # 将所有子模型的linear参数赋值为1
     print(m)
     if type(m) == nn.Linear:
        m.weight.data.fill_(1.0)
        print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)

四、type(self, dst_type)

type函数是将所有parameters和buffers都转成指定的目标类型dst_type

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.type(dst_type=torch.float16)
for model in net.children():
    print(model.weight.data)

五、float(self)、double(self)、half(self)、bfloat16(self)

float、double和half这三个函数是将所有floating point parameters分别转成float datatype、double datatype和half datatype。torch.Tensor.float即torch.float32;torch.Tensor.double即torch.float64;torch.Tensor.half即torch.float16。

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.half()
for model in net.children():
     print(model.weight.data)

六、to(self, *args, **kwargs)

函数to的作用是原地 ( in-place ) 修改Module,它可以当成三种函数来使用:function:: to(device=None, dtype=None, non_blocking=False); function:: to(dtype, non_blocking=False); function:: to(tensor, non_blocking=False)。下边展示的是使用方法。
这里直接拷贝官方例子

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

七、state_dict(self, destination=None, prefix=’’, keep_vars=False)

函数state_dict的作用是返回一个包含module的所有state的dictionary,而这个字典的Keys对应的就是parameter和buffer的名字names。该函数的源码部分有一个循环可以递归遍历Module中所有的SubModule。

>>> net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
>>> net.state_dict()
OrderedDict([('0.weight', tensor([[ 0.4792,  0.5772], [ 0.1039, -0.0552]])), 
        ('0.bias', tensor([-0.5175, -0.6469])), 
        ('1.weight', tensor([[-0.5346, -0.0173], [-0.2092,  0.0794]])), 
        ('1.bias', tensor([-0.2150,  0.2323]))])
>>> net.state_dict().keys()
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias'])

八、def train(self, mode=True)和 eval(self)

函数train和函数eval的作用是将Module及其SubModule分别设置为training mode和evaluation mode。这两个函数只对特定的Module有影响,例如Class Dropout、Class BatchNorm。

其他更多见
https://zhuanlan.zhihu.com/p/88712978