欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch中的 nn.ModuleList 和 nn.Sequential

程序员文章站 2022-06-12 22:43:42
...

nn.ModuleList 和 nn.Sequential都是用来组合深度网络中的nn.Module/block,从而创建一个新的网络用的,

但在使用上有所差异。

nn.Sequential 可以让你创作一个新的网络,只需要按照顺序排列它们即可,如下:

class Flatten(nn.Module):
  def forward(self, x):
    N, C, H, W = x.size() # read in N, C, H, W
    return x.view(N, -1)

simple_cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=7, stride=2),
            nn.ReLU(inplace=True),
            Flatten(), 
            nn.Linear(5408, 10),
          )

nn.ModuleList则像是一个python里的list,可以把nn.Module向元素一样添加到list里面,从而构建一个新的网络,

其优点就是,这个网络可以动态构建,如下:

class LinearNet(nn.Module):
  def __init__(self, input_size, num_layers, layers_size, output_size):
     super(LinearNet, self).__init__()

     self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
     self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
     self.linears.append(nn.Linear(layers_size, output_size)

 

 

 

相关标签: 深度学习算法