测试一个训练好的caffe模型
程序员文章站
2024-03-15 11:44:23
...
在学习caffe的过程中,训练了出了模型出来,出了当时的准确率和loss值,并没有看到给定输入看到真正的输出,这个时候需要测试一下训练出来的模型,实际查看一下效果,其中用到的配置文件和网络模型在caffe的目录下都有,自己测试自己的模型时需要修改为自己的*.prototxt和*.caffemodel
#!/usr/bin/env python
#coding=utf-8
#因为需要sudo权限,只能作如下处理,自己添加caffe的位置,编译之后,然后用sudo权限运行该.py程序
import sys
caffe_root='/home/dyh/caffe/caffe/'
sys.path.insert(0, caffe_root+'python')
import caffe
import numpy as np
'''
这个测试模型可以帮助自己测试自己训练出来的模型效果如何
'''
caffe.set_mode_gpu()
# caffe.set_device(0)
#deploy文件就是用来测试训练好的网络的,给其输入,自己写测试来输出类别
model_def = '/home/dyh/caffe/caffe/models/bvlc_reference_caffenet/deploy.prototxt'
# model_def = '/home/dyh/caffe-workspace/face_detect/deploy_full_conv.prototxt'
model_weights = '/home/dyh/caffe-workspace/caffe_case/caffe_case模板/bvlc_reference_caffenet.caffemodel'
# model_weights = '/home/dyh/caffe-workspace/face_detect/model/solver_iter_25000.caffemodel'
net = caffe.Net(model_def, #测试的模型,caffe已经给出了,照着用
model_weights,#训练好的参数
caffe.TEST) #使用的模式
#加载均值文件
mu = np.load('/home/dyh/caffe/caffe/python/caffe/imagenet/ilsvrc_2012_mean.npy')
mu = mu.mean(1).mean(1)
print 'mean-substracted values',zip('BGR',mu)
transformer = caffe.io.Transformer({'data':net.blobs['data'].data.shape})
#[h,w,c]->[c,h,w]
transformer.set_transpose('data', [2,0,1])
# transformer.set_mean('data', mu) #减均值
transformer.set_raw_scale('data', 255)#变换到[0-1]
transformer.set_channel_swap('data', [2,1,0])#RGB->BGR
#按照caffe的输入格式reshape输入
net.blobs['data'].reshape(1, #batch,想一张的测试
3, #channel
227, #height
227) #weight
img = caffe.io.load_image('/home/dyh/caffe/caffe/examples/images/cat.jpg')
# img = caffe.io.load_image('/home/dyh/caffe-workspace/face_detect/train/1/4_nonface_0image54477.jpg')
#将输入进行预处理达到ceffe的输入格式要求
transformer_img = transformer.preprocess('data',img)
#让deploy里面的数据层接收到输入的图片
net.blobs['data'].data[...] = transformer_img
#前向传播一次就行
output = net.forward()
#在网络的最后一个层是输出的每个类别的概率
output_pro = output['prob'][0]
#概率最大的就是
print 'predict class is:',output_pro.argmax()
lables_path = '/home/dyh/caffe/caffe/data/ilsvrc12/synset_words.txt'
lables = np.loadtxt(lables_path, str,delimiter= '\t')#一行一行的读取并转换为ndarray
print lables[output_pro.argmax()]#第xx行是类别
输入图像是一只猫,最终的结果如下
I0422 07:28:21.366673 14366 net.cpp:744] Ignoring source layer loss
mean-substracted values [('B', 104.0069879317889), ('G', 116.66876761696767), ('R', 122.6789143406786)]
predict class is: 282
n02123159 tiger cat