center loss训练与可视化代码解析
程序员文章站
2024-03-14 21:38:41
...
参考博客:https://blog.csdn.net/sinat_24143931/article/details/79033414 损失函数改进总结
代码:https://github.com/pangyupo/mxnet_center_loss
1,定义网络结构并训练:
主要在于新定义了centerloss这个operator,加到网络结构上。
两种loss来自同一个全连接层的输出embedding,用mx.symbol.Group([ce_loss, center_loss])组合。
知识点1:Mxnet自定义Op
第一步继承CustomOp,重写方法forward()和backward(),然后继承CustomOpProp,重写成员方法,并在方法create_operator中
调用之前写好的Op,第三步调用operator.register()对操作进行注册。
2,训练与结果可视化
知识点2:MXnet获取特征输出
(1)加载预训练模型
model = mx.model.FeedForward.load('center_loss', 20, ctx=mx.cpu(0), numpy_batch_size=1) #加载预训练模型
(2)找到特征层,重建符号与模型
internals=model.symbol.get_internals() #列出所有的层
# internals.list_outputs() #展示网络结构
embedding_layer = internals['embedding_output'] #选择embedding层的特征
feature_extractor = mx.model.FeedForward(ctx=mx.cpu(0), symbol=embedding_layer, numpy_batch_size=1,\
arg_params = model.arg_params, aux_params=model.aux_params, allow_extra_params=True)
#feature_extractor.save('new_model_name',1) #保存新模型
(3)加载数据集
_, val = mnist_iterator(batch_size=100, input_shape=(1,28,28))
(4)提取特征并可视化
preds = feature_extractor.predict( i.data[0] ) #提取特征为二维
embeds.append( preds )
labels.append( i.label[0].asnumpy())
visual_feature_space(embeds, labels, 10, namedict)
可视化函数visual_feature_space画出二维点,描述分类结果
上一篇: MAC安装JDK