欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch flatten 和 merge

程序员文章站 2024-03-02 23:04:52
...

 

 

 

1 Flatten

Flatten就是将2D的特征图压扁为1D的特征向量,用于全连接层的输入。

 

# Flatten继承Module
class Flatten(nn.Module):
    # 构造函数,没有什么要做的
    def __init__(self):
        # 调用父类构造函数
        super(Flatten, self).__init__()

    # 实现forward函数
    def forward(self, input):
        # 保存batch维度,后面的维度全部压平,例如输入是28*28的特征图,压平后为784的向量
        return input.view(input.size(0), -1)

 

2 merge

    to be continue

 

 

 

 

Reference

[深度学习] pytorch学习笔记(4)(Module类、实现Flatten类、Module类作用、数据增强)

2 ...