06Attention-transfer学习笔记
文章目录
0、注意力动机
并不是首次提出,已经应用效果好;
来自知识蒸馏,只学习最后层的知识,本研究用在训练过程特征图中
1、 注意力转移
1、注意力机制
视觉注意力机制是人类视觉所特有的大脑信号处理机制。人类是决通过快速扫描全局图像,获得需要关注的目标区域,然后重点获取所需关注的目标信息,抑制其他无用信息!
3个类型
空间域spatial domain:更多关注特征空间信息,决定空间中哪些区域重要,那些不重要
论文:dynamic
通道域channel domain:关注通道信息
论文:squeeze and excitation networks
SENet:squeeze压缩;excitation激励;–降维再升维;scale重定向
混合域mixed domain:同时关注
CBAM
4个应用
2、基于**的注意力转移,前馈过程
3、基于梯度的注意力转移,反馈过程
关注那些对输出影响大的区域,例如只改变输入图像的很小一部分,就能使得网络发生巨大变化,改变的这部分就值得我们关注
1、对教师网和学生网的交叉熵损失函数分别求梯度,作为梯度注意力转移的损失函数,教师网的不更新
2、转移教师网中的变化较大的部分,构造损失函数
3、若教师网的参数和输入均给定,则对学生网的参数求导
提出迫使梯度注意力图的水平翻转不变的损失函数,把原图的水平翻转的图都做传播,然后加上得到的注意力图和输出的L2损失,做第二次传播—直接一点:就是使得一张图翻转后和翻转前的梯度图接近,因为是同一个目标
将输入的梯度作为注意力图,转移到学生网络
2、轻量化改进手段
知识蒸馏、剪枝、量化、低秩分解
人工模型设计:squeezenet、mobilenet
NAS:强化学习得到
3、研究意义
引入注意力机制,提升学生网络性能
基于**的注意力和基于梯度的注意力两种方式,过程的知识
启发了知识蒸馏,知识蒸馏是在最后一层,输出的知识
4、总结
1、网络不同层有不同的注意力**区域,会产生不同的注意力图,前面层注意力比较均匀,中间层对最具有判别性的区域**程度最高,最高层反应目标整体的特征区域
2、
3、AT有效的降低了识别错误率,AT和KD结合,进一步降低错误率,与全部**都转移相比,任何一种注意力转移的方式都降低错误率
基于**的,
三种**都降低了错误率,求和再平方的效果最好。
基于梯度的,
没有基于**的效果好,但是也变好了
5、论文总结
知识的表现形式:
蒸馏:最后层的知识,**知识,梯度知识
判定某个区域重要性的指标:特征图值的大小
基于**的注意力**和基于梯度的注意力**
创新:
将注意力机制引入知识蒸馏,提升CNN模型性能
设计了两类注意力转移方法
提供了多种注意力图生成方式
启发:
6、最后
构建学生、教师网络
利用注意力转移进行知识蒸馏
在SVHN数据集上进行验证
:::彩蛋:::
模型训练:
在网络训练过程中,都是计算前向误差,反向误差,更新权重
只是模型、数据,评价函数不同
torch的engine给训练过程提供了一个模板,该模板建立model,datasetiterator,criterion,meter之间的关系
将训练过程进行包装,抽象成一个类,提供train和test方法,外部可以通过state变量和Engine训练过程交互
1.engine.train(network,iterator,maxepoch,optimizer):训练的网络,数据迭代,最大轮数,优化方式
2.engine.test(network,iterator):训练的网络,当前迭代
3.state = {
['network'] = network, --设置了model
['criterion'] = criterion, -- 设置损失函数
['iterator'] = iterator, -- 数据迭代器
['lr'] = lr, -- 学习率
['maxepoch'] = maxepoch, --最大epoch数
['sample'] = {}, -- 当前采集的样本,可以在onSample中通过该阈值查看采样样本
['epoch'] = 0 , -- 当前的epoch
['t'] = 0, -- 已经训练样本的个数
['training'] = true -- 训练过程
}
4.hooks = {
['onStart'] = function() end, --用于训练开始前的设置和初始化
['onStartEpoch'] = function() end, -- 每一个epoch前的操作
['onSample'] = function() end, -- 每次采样一个样本之后的操作
['onForward'] = function() end, -- 在model:forward()之后的操作
['onForwardCriterion'] = function() end, -- 前向计算损失函数之后的操作
['onBackwardCriterion'] = function() end, -- 反向计算损失误差之后的操作
['onBackward'] = function() end, -- 反向传递误差之后的操作
['onUpdate'] = function() end, -- 权重参数更新之后的操作
['onEndEpoch'] = function() end, -- 每一个epoch结束时的操作
['onEnd'] = function() end, -- 整个训练过程结束后的收拾现场
}
5.Meter
add() reset() value()
per_acc=tnt.meter.APMeter()#每一类样本的平均准确率
meter_loss = tnt.meter.AverageValueMeter()#平均损失
classacc = tnt.meter.ClassErrorMeter(accuracy=True)#返回top-k准确率
timer_train = tnt.meter.TimeMeter('s')#用于统计事件之间的时间间隔,也可以用来统计batch数据的平均处理时间
上一篇: LeetCode——两数之和
下一篇: MySQL学习——添加新用户(1)