欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

center loss训练与可视化代码解析

程序员文章站 2024-03-14 21:38:41
...

参考博客:https://blog.csdn.net/sinat_24143931/article/details/79033414  损失函数改进总结

代码:https://github.com/pangyupo/mxnet_center_loss

1,定义网络结构并训练:

主要在于新定义了centerloss这个operator,加到网络结构上。

两种loss来自同一个全连接层的输出embedding,用mx.symbol.Group([ce_loss, center_loss])组合。

知识点1:Mxnet自定义Op 

第一步继承CustomOp,重写方法forward()和backward(),然后继承CustomOpProp,重写成员方法,并在方法create_operator中

调用之前写好的Op,第三步调用operator.register()对操作进行注册。

2,训练与结果可视化

知识点2:MXnet获取特征输出

(1)加载预训练模型
model = mx.model.FeedForward.load('center_loss', 20, ctx=mx.cpu(0), numpy_batch_size=1) #加载预训练模型

(2)找到特征层,重建符号与模型

internals=model.symbol.get_internals()                           #列出所有的层
# internals.list_outputs()                                       #展示网络结构
embedding_layer = internals['embedding_output']                  #选择embedding层的特征    
feature_extractor = mx.model.FeedForward(ctx=mx.cpu(0), symbol=embedding_layer, numpy_batch_size=1,\
            arg_params = model.arg_params, aux_params=model.aux_params, allow_extra_params=True)
#feature_extractor.save('new_model_name',1)                       #保存新模型

(3)加载数据集

_, val = mnist_iterator(batch_size=100, input_shape=(1,28,28))

(4)提取特征并可视化

preds = feature_extractor.predict( i.data[0] )   #提取特征为二维
embeds.append( preds )
labels.append( i.label[0].asnumpy())
visual_feature_space(embeds, labels, 10, namedict)

可视化函数visual_feature_space画出二维点,描述分类结果

center loss训练与可视化代码解析


相关标签: loss