Inception-ResNet Model Framework (PyTorch)
I. Introduction
In the paper "Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning", besides introducing the v4 version of the Inception network, the authors also combine Inception with ResNet-style residual connections, proposing two models: Inception-ResNet-v1 and Inception-ResNet-v2.
II. Model Architecture
Note: Inception-ResNet-v1 and Inception-ResNet-v2 share the same overall skeleton, but the internal structure of each component differs. The component diagrams themselves can be found in the paper; a shape summary of the v1 pipeline follows the list below.
1. Inception-ResNet-v1
1.1 Stem
1.2 Inception-ResNet-A
1.3 Inception-ResNet-B
1.4 Inception-ResNet-C
1.5 Reduction-A
1.6 Reduction-B
2. Inception-ResNet-v2
2.1 Stem
2.2 Inception-ResNet-A
2.3 Inception-ResNet-B
2.4 Inception-ResNet-C
2.5 Reduction-A
2.6 Reduction-B
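For orientation, the v1 pipeline transforms a 299x299x3 input as follows. The spatial sizes and channel counts are those given in the paper and are reproduced by the code in Section III:

Input (299 x 299 x 3)
  -> Stem                      -> 35 x 35 x 256
  -> 5 x Inception-ResNet-A    -> 35 x 35 x 256
  -> Reduction-A               -> 17 x 17 x 896
  -> 10 x Inception-ResNet-B   -> 17 x 17 x 896
  -> Reduction-B               -> 8 x 8 x 1792
  -> 5 x Inception-ResNet-C    -> 8 x 8 x 1792
  -> AvgPool, Dropout, Linear  -> 1000-way output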
III. Code Reproduction
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv2d(nn.Module):
    # conv + BN helper; the ReLU is applied externally by the calling block
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # the bias is redundant under the following BatchNorm, so it is disabled
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return out
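Note that BasicConv2d deliberately leaves out the ReLU: the Inception-ResNet blocks below apply it externally, which lets the final 1x1 projection before each residual addition stay linear, as in the paper. A quick sanity check with a dummy input (the output size follows from standard valid-convolution arithmetic):

x = torch.randn(1, 3, 299, 299)
print(BasicConv2d(3, 32, kernel_size=3, stride=2)(x).shape)  # torch.Size([1, 32, 149, 149])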
class InceptionResNetA(nn.Module):
    # 35x35 grid block; the residual sum requires in_channels == 256
    def __init__(self, in_channels):
        super(InceptionResNetA, self).__init__()
        # branch1: conv1x1(32)
        self.b1 = BasicConv2d(in_channels, 32, kernel_size=1)
        # branch2: conv1x1(32) --> conv3x3(32)
        self.b2_1 = BasicConv2d(in_channels, 32, kernel_size=1)
        self.b2_2 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        # branch3: conv1x1(32) --> conv3x3(32) --> conv3x3(32)
        self.b3_1 = BasicConv2d(in_channels, 32, kernel_size=1)
        self.b3_2 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.b3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        # linear 1x1 projection back to 256 channels; no ReLU before the residual add
        self.tb = BasicConv2d(96, 256, kernel_size=1)

    def forward(self, x):
        # the input is already activated by the previous layer, so no extra ReLU here
        b_out1 = F.relu(self.b1(x))
        b_out2 = F.relu(self.b2_2(F.relu(self.b2_1(x))))
        b_out3 = F.relu(self.b3_3(F.relu(self.b3_2(F.relu(self.b3_1(x))))))
        b_out = torch.cat([b_out1, b_out2, b_out3], 1)
        b_out = self.tb(b_out)
        y = b_out + x  # residual connection
        out = F.relu(y)
        return out
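One detail from the paper that this implementation omits: with very wide networks the authors found training unstable, and they recommend scaling the residual branch by a factor between 0.1 and 0.3 before the addition (Section 3.3, "Scaling of the Residuals"). A minimal sketch of the change, assuming a hypothetical scale argument passed to the constructor:

# in __init__: self.scale = scale     # hypothetical argument, e.g. 0.17
# in forward, replacing y = b_out + x:
y = x + self.scale * b_out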
class InceptionResNetB(nn.Module):
    # 17x17 grid block; the residual sum requires in_channels == 896
    def __init__(self, in_channels):
        super(InceptionResNetB, self).__init__()
        # branch1: conv1x1(128)
        self.b1 = BasicConv2d(in_channels, 128, kernel_size=1)
        # branch2: conv1x1(128) --> conv1x7(128) --> conv7x1(128)
        self.b2_1 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.b2_2 = BasicConv2d(128, 128, kernel_size=(1, 7), padding=(0, 3))
        self.b2_3 = BasicConv2d(128, 128, kernel_size=(7, 1), padding=(3, 0))
        # linear 1x1 projection back to 896 channels; no ReLU before the residual add
        self.tb = BasicConv2d(256, 896, kernel_size=1)

    def forward(self, x):
        b_out1 = F.relu(self.b1(x))
        b_out2 = F.relu(self.b2_3(F.relu(self.b2_2(F.relu(self.b2_1(x))))))
        b_out = torch.cat([b_out1, b_out2], 1)
        b_out = self.tb(b_out)
        y = b_out + x  # residual connection
        out = F.relu(y)
        return out
class InceptionResNetC(nn.Module):
    # 8x8 grid block; the residual sum requires in_channels == 1792
    def __init__(self, in_channels):
        super(InceptionResNetC, self).__init__()
        # branch1: conv1x1(192)
        self.b1 = BasicConv2d(in_channels, 192, kernel_size=1)
        # branch2: conv1x1(192) --> conv1x3(192) --> conv3x1(192)
        self.b2_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.b2_2 = BasicConv2d(192, 192, kernel_size=(1, 3), padding=(0, 1))
        self.b2_3 = BasicConv2d(192, 192, kernel_size=(3, 1), padding=(1, 0))
        # linear 1x1 projection back to 1792 channels; no ReLU before the residual add
        self.tb = BasicConv2d(384, 1792, kernel_size=1)

    def forward(self, x):
        b_out1 = F.relu(self.b1(x))
        b_out2 = F.relu(self.b2_3(F.relu(self.b2_2(F.relu(self.b2_1(x))))))
        b_out = torch.cat([b_out1, b_out2], 1)
        b_out = self.tb(b_out)
        y = b_out + x  # residual connection
        out = F.relu(y)
        return out
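Each Inception-ResNet block preserves the shape of its input, which is what allows several of them to be stacked back to back. A quick check at each block's native v1 resolution:

print(InceptionResNetA(256)(torch.randn(1, 256, 35, 35)).shape)   # torch.Size([1, 256, 35, 35])
print(InceptionResNetB(896)(torch.randn(1, 896, 17, 17)).shape)   # torch.Size([1, 896, 17, 17])
print(InceptionResNetC(1792)(torch.randn(1, 1792, 8, 8)).shape)   # torch.Size([1, 1792, 8, 8])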
class ReductionA(nn.Module):
    # downsamples the 35x35 grid to 17x17; output channels = in_channels + n + m
    def __init__(self, in_channels, k, l, m, n):
        super(ReductionA, self).__init__()
        # branch1: maxpool3x3(stride 2, valid)
        self.b1 = nn.MaxPool2d(kernel_size=3, stride=2)
        # branch2: conv3x3(n, stride 2, valid)
        self.b2 = BasicConv2d(in_channels, n, kernel_size=3, stride=2)
        # branch3: conv1x1(k) --> conv3x3(l) --> conv3x3(m, stride 2, valid)
        self.b3_1 = BasicConv2d(in_channels, k, kernel_size=1)
        self.b3_2 = BasicConv2d(k, l, kernel_size=3, padding=1)
        self.b3_3 = BasicConv2d(l, m, kernel_size=3, stride=2)

    def forward(self, x):
        y1 = self.b1(x)
        y2 = F.relu(self.b2(x))
        y3 = F.relu(self.b3_3(F.relu(self.b3_2(F.relu(self.b3_1(x))))))
        return torch.cat([y1, y2, y3], 1)
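The (k, l, m, n) arguments mirror Table 1 of the paper, which lists different filter counts for each network variant; Inception-ResNet-v1 uses (192, 192, 256, 384). With those values the 35x35x256 grid becomes 17x17x896 (256 + 384 + 256 = 896):

redA = ReductionA(256, 192, 192, 256, 384)      # (k, l, m, n) for Inception-ResNet-v1
print(redA(torch.randn(1, 256, 35, 35)).shape)  # torch.Size([1, 896, 17, 17])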
class ReductionB(nn.Module):
    # downsamples the 17x17 grid to 8x8; output channels = in_channels + 384 + 256 + 256
    def __init__(self, in_channels):
        super(ReductionB, self).__init__()
        # branch1: maxpool3x3(stride 2, valid)
        self.b1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        # branch2: conv1x1(256) --> conv3x3(384, stride 2, valid)
        self.b2_1 = BasicConv2d(in_channels, 256, kernel_size=1)
        self.b2_2 = BasicConv2d(256, 384, kernel_size=3, stride=2)
        # branch3: conv1x1(256) --> conv3x3(256, stride 2, valid)
        self.b3_1 = BasicConv2d(in_channels, 256, kernel_size=1)
        self.b3_2 = BasicConv2d(256, 256, kernel_size=3, stride=2)
        # branch4: conv1x1(256) --> conv3x3(256) --> conv3x3(256, stride 2, valid)
        self.b4_1 = BasicConv2d(in_channels, 256, kernel_size=1)
        self.b4_2 = BasicConv2d(256, 256, kernel_size=3, padding=1)
        self.b4_3 = BasicConv2d(256, 256, kernel_size=3, stride=2)

    def forward(self, x):
        y1 = self.b1(x)
        y2 = F.relu(self.b2_2(F.relu(self.b2_1(x))))
        y3 = F.relu(self.b3_2(F.relu(self.b3_1(x))))
        y4 = F.relu(self.b4_3(F.relu(self.b4_2(F.relu(self.b4_1(x))))))
        return torch.cat([y1, y2, y3, y4], 1)
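For v1, the channel arithmetic gives 896 + 384 + 256 + 256 = 1792, taking the 17x17x896 grid to the 8x8x1792 input expected by Inception-ResNet-C:

redB = ReductionB(896)
print(redB(torch.randn(1, 896, 17, 17)).shape)  # torch.Size([1, 1792, 8, 8])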
class StemForIR1(nn.Module):
    # stem of Inception-ResNet-v1: 299x299x3 -> 35x35x256
    def __init__(self, in_channels):
        super(StemForIR1, self).__init__()
        # conv3x3(32, stride 2, valid): 299 -> 149
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        # conv3x3(32, valid): 149 -> 147
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        # conv3x3(64, same): 147 -> 147
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        # maxpool3x3(stride 2, valid): 147 -> 73
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        # conv1x1(80)
        self.conv4 = BasicConv2d(64, 80, kernel_size=1)
        # conv3x3(192, valid): 73 -> 71
        self.conv5 = BasicConv2d(80, 192, kernel_size=3)
        # conv3x3(256, stride 2, valid): 71 -> 35
        self.conv6 = BasicConv2d(192, 256, kernel_size=3, stride=2)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = self.maxpool1(out)
        out = F.relu(self.conv4(out))
        out = F.relu(self.conv5(out))
        out = F.relu(self.conv6(out))
        return out
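Tracing the spatial size through the stem gives 299 -> 149 -> 147 -> 147 -> 73 -> 73 -> 71 -> 35, matching the 35x35x256 input that the Inception-ResNet-A blocks expect:

stem = StemForIR1(3)
print(stem(torch.randn(1, 3, 299, 299)).shape)  # torch.Size([1, 256, 35, 35])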
class InceptionResNetv1(nn.Module):
    def __init__(self, num_classes=1000):
        super(InceptionResNetv1, self).__init__()
        self.stem = StemForIR1(3)
        # the paper stacks 5 A-blocks, 10 B-blocks and 5 C-blocks; each block
        # preserves its input shape, so separate instances can be chained directly
        self.irA = nn.Sequential(*[InceptionResNetA(256) for _ in range(5)])
        self.redA = ReductionA(256, 192, 192, 256, 384)  # (k, l, m, n) for v1
        self.irB = nn.Sequential(*[InceptionResNetB(896) for _ in range(10)])
        self.redB = ReductionB(896)
        self.irC = nn.Sequential(*[InceptionResNetC(1792) for _ in range(5)])
        self.avgpool = nn.AvgPool2d(kernel_size=8)  # average pooling, per the paper
        self.dropout = nn.Dropout(p=0.2)  # the paper keeps 0.8, i.e. drops 0.2
        self.linear = nn.Linear(1792, num_classes)

    def forward(self, x):
        out = self.stem(x)
        out = self.irA(out)
        out = self.redA(out)
        out = self.irB(out)
        out = self.redB(out)
        out = self.irC(out)
        out = self.avgpool(out)
        out = self.dropout(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
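A minimal end-to-end smoke test, assuming 299x299 RGB inputs as in the paper:

model = InceptionResNetv1()
x = torch.randn(2, 3, 299, 299)  # dummy batch of two images
print(model(x).shape)            # torch.Size([2, 1000])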