pytorch中的 nn.ModuleList 和 nn.Sequential
程序员文章站
2022-06-12 22:43:42
...
nn.ModuleList 和 nn.Sequential都是用来组合深度网络中的nn.Module/block,从而创建一个新的网络用的,
但在使用上有所差异。
nn.Sequential 可以让你创作一个新的网络,只需要按照顺序排列它们即可,如下:
class Flatten(nn.Module):
def forward(self, x):
N, C, H, W = x.size() # read in N, C, H, W
return x.view(N, -1)
simple_cnn = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=7, stride=2),
nn.ReLU(inplace=True),
Flatten(),
nn.Linear(5408, 10),
)
nn.ModuleList则像是一个python里的list,可以把nn.Module向元素一样添加到list里面,从而构建一个新的网络,
其优点就是,这个网络可以动态构建,如下:
class LinearNet(nn.Module):
def __init__(self, input_size, num_layers, layers_size, output_size):
super(LinearNet, self).__init__()
self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
self.linears.append(nn.Linear(layers_size, output_size)