pytorch(四)——transforms详解
程序员文章站
2022-07-06 10:18:51
...
一、基础知识
1、计算机视觉工具包:torchvision
torchvision.transforms : 常用的图像预处理方法
torchvision.datasets : 常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等
torchvision.model : 常用的模型预训练,AlexNet,VGG, ResNet,GoogLeNet等
2、常用的图像预处理方法
数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、仿射变换、亮度、饱和度及对比度变换。同时transforms所处的位置为(在函数getitem中):
二、重点讲解
1、图像标准化:transforms.Normalize
功能:逐channel的对图像进行标准化、output = (input - mean) / std、 mean:各通道的均值、std:各通道的标准差、inplace:是否原地操作。图像标准化是将数据通过去均值实现中心化的处理,根据凸优化理论与数据概率分布相关知识,数据中心化符合数据分布规律,更容易取得训练之后的泛化效果, 数据标准化是数据预处理的常见方法之一。(这里的均值是像素均值)
transforms.Normalize(mean,std,inplace=False)
注:标准化的原理是我们默认自然图像是一类平稳的数据分布(即数据每一维的统计都服从相同分布),此时,在每个样本上减去数据的统计平均值可以移除共同的部分,凸显像素个体差异。 个人认为还可以去除图像的亮度信息,增加模型的泛化能力。
2、
三、代码补充
1、tensor.sub_:在操作中带下划线的都是直接对原始数据进行更改的操作。
下一篇: C语言入门教程-(1)简介及搭建环境