pytorch中的nn.module类常见成员函数
文章目录
参考
【PyTorch】torch.nn.Module 源码分析
多达48个函数,这里简单记录一下常见函数的作用
先创建一个Module,以这个为例
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
net = Model()
一、cpu(self)
将所有的参数和内存放在cpu上
net.cpu() # 将所有的参数和内存放在cpu上
二、cuda(self, device=None)
将所有的参数和内存放在gpu上
net.gpu("cuda:0") # 将所有的参数和内存放在gpu上
三、apply(self, fn)
将Module及其所有的SubModule传进给定的fn函数操作一遍。举个例子,我们可以用这个函数来对Module的网络模型参数用指定的方法初始化。
def init_weights(m): # 将所有子模型的linear参数赋值为1
print(m)
if type(m) == nn.Linear:
m.weight.data.fill_(1.0)
print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
四、type(self, dst_type)
type函数是将所有parameters和buffers都转成指定的目标类型dst_type
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.type(dst_type=torch.float16)
for model in net.children():
print(model.weight.data)
五、float(self)、double(self)、half(self)、bfloat16(self)
float、double和half这三个函数是将所有floating point parameters分别转成float datatype、double datatype和half datatype。torch.Tensor.float即torch.float32;torch.Tensor.double即torch.float64;torch.Tensor.half即torch.float16。
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.half()
for model in net.children():
print(model.weight.data)
六、to(self, *args, **kwargs)
函数to的作用是原地 ( in-place ) 修改Module,它可以当成三种函数来使用:function:: to(device=None, dtype=None, non_blocking=False); function:: to(dtype, non_blocking=False); function:: to(tensor, non_blocking=False)。下边展示的是使用方法。
这里直接拷贝官方例子
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
[-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
[-0.5112, -0.2324]], dtype=torch.float16)
七、state_dict(self, destination=None, prefix=’’, keep_vars=False)
函数state_dict的作用是返回一个包含module的所有state的dictionary,而这个字典的Keys对应的就是parameter和buffer的名字names。该函数的源码部分有一个循环可以递归遍历Module中所有的SubModule。
>>> net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
>>> net.state_dict()
OrderedDict([('0.weight', tensor([[ 0.4792, 0.5772], [ 0.1039, -0.0552]])),
('0.bias', tensor([-0.5175, -0.6469])),
('1.weight', tensor([[-0.5346, -0.0173], [-0.2092, 0.0794]])),
('1.bias', tensor([-0.2150, 0.2323]))])
>>> net.state_dict().keys()
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias'])
八、def train(self, mode=True)和 eval(self)
函数train和函数eval的作用是将Module及其SubModule分别设置为training mode和evaluation mode。这两个函数只对特定的Module有影响,例如Class Dropout、Class BatchNorm。
上一篇: 用单元格格式设置功能让Excel待统计的单元格显示为横杠
下一篇: Python3 函数