// 计算机视觉学习 基于OpenCV+svm的手写数字识别 (Computer vision study: handwritten digit recognition with OpenCV + SVM)
/* Includes ------------------------------------------------------------------*/
#include
#include
#include
#include
#include
#include “dtector.h”
#include “ui_dtector.h”
#include “subwinconfig/subwinconfig.h”
#include
#include
#include
#include
#include “dtector.h”
#include"opencv2/opencv.hpp"
#include<opencv2/core.hpp>
#include<opencv2/imgproc.hpp>
#include<opencv2/highgui.hpp>
#include<opencv2/ml.hpp>
#include
#include
#include
//https://blog.csdn.net/maweifei/article/details/60152242
// MNIST data file names. They are appended to
// QCoreApplication::applicationDirPath(), so each one starts with '/'.
QString trainImage = "/train-images.idx3-ubyte";
QString trainLabel = "/train-labels.idx1-ubyte";
QString testImage = "/t10k-images.idx3-ubyte";
QString testLabel = "/t10k-labels.idx1-ubyte";
// Label count of the most recently loaded label file (written by ReadMnistLabel).
int testLblsNum = 0;
// Length of one flattened image, rows * cols (written by ReadMnistImage).
int imgVectorLen = 0;
// OPENCV3 enables the SVM::create()-based training path.
#define OPENCV3 1
// Legacy feature-mode switches; enable at most one.
//#define SVM_HOG 1
//#define SVM_NORMAL 1
#define SVM_DILATE 1
/** ****************************************************************************
 * @remarks dtector::dtector(QWidget *parent)
 * @brief   Constructor: builds the UI, creates a 200x200 canvas filled with
 *          black, and groups the two SVM-mode radio buttons (ids 1 and 2).
 * @param   [IN] parent  parent widget pointer
 * @return  NONE
 * @attention QAbstractButton::setChecked() emits toggled(), not clicked(), so
 *          onRadioClick() does not run for the initial selection. checkedId is
 *          therefore read from the button group explicitly at the end.
 ********************************************************************************/
dtector::dtector(QWidget *parent) :
    CBaseSubWin(parent),
    ui(new Ui::dtector)
{
    ui->setupUi(this);
    isDrawing = false;

    // Drawing canvas: white strokes are painted onto this black image.
    image = QImage(200, 200, QImage::Format_RGB32);
    backColor = qRgb(0, 0, 0);
    image.fill(backColor);

    // Radio buttons selecting the feature mode: 1 = norRadio, 2 = theRadio.
    btnGroupSvm = new QButtonGroup(this);
    btnGroupSvm->addButton(ui->norRadio, 1);
    btnGroupSvm->addButton(ui->theRadio, 2);
    connect(ui->norRadio, SIGNAL(clicked()), this, SLOT(onRadioClick()));
    connect(ui->theRadio, SIGNAL(clicked()), this, SLOT(onRadioClick()));
    ui->norRadio->setChecked(true);

    // Sync with the initial selection; otherwise checkedId stays unset until
    // the user clicks a radio button.
    checkedId = btnGroupSvm->checkedId();
}
/** ****************************************************************************
 * @remarks dtector::~dtector()
 * @brief   Destructor: frees the Designer-generated UI object that the
 *          constructor allocated with new.
 * @param   NONE
 * @return  NONE
 ********************************************************************************/
dtector::~dtector()
{
    delete this->ui;
}
/** ****************************************************************************
 * @remarks void dtector::SetLabelTxt(QLabel*, const QString&, bool)
 * @brief   Sets a label's text and colours its background white (IsTrue) or
 *          red (!IsTrue).
 * @param   [IN] pLabel   label to update (must not be null)
 * @param   [IN] SetStrn  text to display
 * @param   [IN] IsTrue   selects the white (true) or red (false) background
 * @return  NONE
 ********************************************************************************/
void dtector::SetLabelTxt(QLabel *pLabel,
                          const QString& SetStrn,
                          bool IsTrue)
{
    pLabel->setText(SetStrn);
    pLabel->setStyleSheet(IsTrue ? "background-color:white;"
                                 : "background-color:red;");
}
/** ****************************************************************************
 * @remarks void dtector::SetLabelTxt(QLabel*, double, bool, char, uint32_t)
 * @brief   Shows a number in a label and colours its background white
 *          (IsTrue) or red (!IsTrue).
 * @param   [IN] pLabel  label to update (must not be null)
 * @param   [IN] Value   number to display
 * @param   [IN] IsTrue  selects the white (true) or red (false) background
 * @param   [IN] Format  QString::number format character ('f', 'e', 'g', ...)
 * @param   [IN] Perc    precision passed to QString::number
 * @return  NONE
 ********************************************************************************/
void dtector::SetLabelTxt(QLabel *pLabel,
                          double Value,
                          bool IsTrue,
                          char Format,
                          uint32_t Perc)
{
    pLabel->setText(QString::number(Value, Format, static_cast<int>(Perc)));
    pLabel->setStyleSheet(IsTrue ? "background-color:white;"
                                 : "background-color:red;");
}
/** ****************************************************************************
 * @remarks void dtector::SetLabelTxtColor(QLabel*, double, char, char, uint32_t)
 * @brief   Shows a number in a label and colours its background by a state
 *          code: 1 -> white, 3 -> green, any other value -> red.
 * @param   [IN] pLabel  label to update (must not be null)
 * @param   [IN] Value   number to display
 * @param   [IN] IsTrue  state code selecting the background colour (see above)
 * @param   [IN] Format  QString::number format character ('f', 'e', 'g', ...)
 * @param   [IN] Perc    precision passed to QString::number
 * @return  NONE
 ********************************************************************************/
void dtector::SetLabelTxtColor(QLabel *pLabel,
                               double Value,
                               char IsTrue,
                               char Format,
                               uint32_t Perc)
{
    pLabel->setText(QString::number(Value, Format, static_cast<int>(Perc)));
    switch (IsTrue)
    {
    case 1:
        pLabel->setStyleSheet("background-color:white;");
        break;
    case 3:
        pLabel->setStyleSheet("background-color:green;");
        break;
    default:
        pLabel->setStyleSheet("background-color:red;");
        break;
    }
}
/** ****************************************************************************
 * @remarks void dtector::SetLabelTxt(QLabel*, bool)
 * @brief   Shows a status word in a label: "正常" (normal) on a green
 *          background when IsTrue, otherwise "故障" (fault) on red.
 * @param   [IN] pLabel  label to update (must not be null)
 * @param   [IN] IsTrue  status to display
 * @return  NONE
 ********************************************************************************/
void dtector::SetLabelTxt(QLabel *pLabel,
                          bool IsTrue)
{
    if (IsTrue)
    {
        pLabel->setText("正常");
        pLabel->setStyleSheet("background-color:green;");
    }
    else
    {
        pLabel->setText("故障");
        pLabel->setStyleSheet("background-color:red;");
    }
}
// Repaint handler (scheduled by update()): copies the off-screen canvas
// `image` onto the widget, top-left aligned.
void dtector::paintEvent(QPaintEvent *)
{
    QPainter widgetPainter(this);
    widgetPainter.drawImage(0, 0, image);
}
// A left-button press starts a new stroke at the cursor position.
void dtector::mousePressEvent(QMouseEvent *event)
{
    if (event->button() != Qt::LeftButton)
        return;

    const QPoint pressPos = event->pos();
    lastPoint = pressPos;   // stroke starts here...
    endPoint = pressPos;    // ...and, for now, also ends here
    isDrawing = true;
}
// While the left button is held, extend the stroke to the cursor. paint()
// draws the lastPoint->endPoint segment onto `image` and calls update(),
// which later triggers paintEvent() to show the canvas.
void dtector::mouseMoveEvent(QMouseEvent *event)
{
    // buttons() (plural) reports every button held during the move.
    if (!(event->buttons() & Qt::LeftButton))
        return;

    endPoint = event->pos();
    paint(image);
}
// Finishes the current stroke. Only a left-button release that ends a stroke
// started in mousePressEvent() draws anything. Previously any release (e.g. a
// right click) painted a zero-length segment at the stale end point, leaving a
// stray dot. The release position is now connected to the stroke so its final
// segment is not lost.
void dtector::mouseReleaseEvent(QMouseEvent *event)
{
    if (event->button() != Qt::LeftButton || !isDrawing)
        return;

    endPoint = event->pos();
    paint(image);
    isDrawing = false;
}
// Draws the segment lastPoint -> endPoint onto theImage (callers pass the
// member canvas `image`) with a 10 px white pen, then schedules a widget
// repaint. Mouse-move events arrive close together, so chaining these short
// segments traces the cursor path.
void dtector::paint(QImage &theImage)
{
    QPen strokePen;
    strokePen.setWidth(10);
    strokePen.setColor(Qt::white);

    QPainter canvasPainter(&theImage);
    canvasPainter.setPen(strokePen);
    canvasPainter.drawLine(lastPoint, endPoint);

    lastPoint = endPoint;   // the next segment continues from here
    update();               // paintEvent() will show the updated canvas
}
/** ****************************************************************************
 * @remarks Mat dtector::ReadMnistImage(QString pathName)
 * @brief   Loads an MNIST IDX3 image file.
 * @param   [IN] pathName  path of the *-images.idx3-ubyte file
 * @return  An (imageCount x rows*cols) CV_32FC1 matrix with one flattened
 *          image per row and pixel bytes scaled from 0..255 to 0..1. Returns
 *          an empty Mat if the file cannot be opened, the header is truncated
 *          or is not an image header (magic 2051), or the pixel data is
 *          truncated.
 * @attention On success the global imgVectorLen is set to rows*cols.
 ********************************************************************************/
Mat dtector::ReadMnistImage(QString pathName){
    ifstream file(pathName.toStdString(), ios::binary);
    if (!file.is_open()){
        qDebug()<<" open failed "<<pathName;
        return Mat();
    }
    qDebug()<<pathName<<" open ok ";

    // Header: magic number, image count, rows, cols. These are four 32-bit
    // integers stored MSB-first (IDX format), hence CvtToLittleEndian().
    int magicNumber = 0;
    int imageNumber = 0;
    int rows = 0;
    int cols = 0;
    file.read(reinterpret_cast<char*>(&magicNumber), sizeof(magicNumber));
    file.read(reinterpret_cast<char*>(&imageNumber), sizeof(imageNumber));
    file.read(reinterpret_cast<char*>(&rows), sizeof(rows));
    file.read(reinterpret_cast<char*>(&cols), sizeof(cols));
    if (!file){
        qDebug()<<" header truncated "<<pathName;
        return Mat();
    }
    magicNumber = CvtToLittleEndian(magicNumber);
    qDebug() << "图像数据库的magic number为:" << magicNumber ;
    imageNumber = CvtToLittleEndian(imageNumber);
    qDebug() << "图像数据库的图像总数为:" << imageNumber;
    rows = CvtToLittleEndian(rows);
    qDebug() << "图像数据库的图像维度row为:" << rows ;
    cols = CvtToLittleEndian(cols);
    qDebug() << "图像数据库的图像维度col为:" << cols ;

    // 2051 == 0x00000803: unsigned-byte data with 3 dimensions. Rejecting
    // anything else avoids allocating a matrix from garbage header values.
    if (magicNumber != 2051 || imageNumber <= 0 || rows <= 0 || cols <= 0){
        qDebug()<<" not an MNIST image file "<<pathName;
        return Mat();
    }

    // Read all pixel bytes in one call (the old loop issued one read per byte).
    Mat raw(imageNumber, rows * cols, CV_8UC1);
    const std::streamsize byteCount = static_cast<std::streamsize>(raw.total());
    file.read(reinterpret_cast<char*>(raw.data), byteCount);
    if (file.gcount() != byteCount){
        qDebug()<<" pixel data truncated "<<pathName;
        return Mat();
    }
    imgVectorLen = rows * cols;

    Mat dataMat;
    raw.convertTo(dataMat, CV_32FC1, 1.0 / 255.0);
    return dataMat;
}
/** ****************************************************************************
 * @remarks Mat dtector::ReadMnistLabel(QString pathName)
 * @brief   Loads an MNIST IDX1 label file.
 * @param   [IN] pathName  path of the *-labels.idx1-ubyte file
 * @return  A (labelCount x 1) CV_32SC1 matrix of labels. Returns an empty Mat
 *          if the file cannot be opened, the header is truncated or is not a
 *          label header (magic 2049), or the label data is truncated.
 * @attention On success the global testLblsNum is set to the label count.
 ********************************************************************************/
Mat dtector::ReadMnistLabel(QString pathName){
    ifstream file(pathName.toStdString(), ios::binary);
    if (!file.is_open()){
        qDebug()<<pathName <<" open failed ";
        return Mat();
    }
    qDebug()<<pathName<<" open ok ";

    // Header: magic number and label count, both 32-bit MSB-first integers.
    // They are initialised so that a failed read never leaves garbage.
    int magicNumber = 0;
    int labelNumber = 0;
    file.read(reinterpret_cast<char*>(&magicNumber), sizeof(magicNumber));
    file.read(reinterpret_cast<char*>(&labelNumber), sizeof(labelNumber));
    if (!file){
        qDebug()<<" header truncated "<<pathName;
        return Mat();
    }
    magicNumber = CvtToLittleEndian(magicNumber);
    qDebug() << "图像标签数据库的magic number为:" << magicNumber ;
    labelNumber = CvtToLittleEndian(labelNumber);
    qDebug() << "图像标签数据库的标签总数为:" << labelNumber ;

    // 2049 == 0x00000801: unsigned-byte data with 1 dimension.
    if (magicNumber != 2049 || labelNumber <= 0){
        qDebug()<<" not an MNIST label file "<<pathName;
        return Mat();
    }

    // One unsigned byte per label. Read them all at once, then widen to
    // CV_32S, the type the matrix was created with (the old code wrote it
    // through at<unsigned int>).
    Mat raw(labelNumber, 1, CV_8UC1);
    file.read(reinterpret_cast<char*>(raw.data), static_cast<std::streamsize>(labelNumber));
    if (file.gcount() != static_cast<std::streamsize>(labelNumber)){
        qDebug()<<" label data truncated "<<pathName;
        return Mat();
    }
    testLblsNum = labelNumber;

    Mat labelMat;
    raw.convertTo(labelMat, CV_32SC1);
    return labelMat;
}
/** ****************************************************************************
 * @remarks int dtector::CvtToLittleEndian(int i)
 * @brief   Reverses the byte order of a 32-bit value (0xAABBCCDD -> 0xDDCCBBAA).
 *          ReadMnistImage/ReadMnistLabel pass the MSB-first IDX header fields
 *          through it, which yields native values on a little-endian host.
 * @param   [IN] i  raw 32-bit value as read from the file
 * @return  the byte-swapped value
 * @attention Despite its name this is an unconditional byte swap and assumes
 *          a little-endian host. The swap is done on uint32_t; the old code
 *          shifted a signed int into the sign bit, which is undefined
 *          behaviour before C++14.
 ********************************************************************************/
int dtector::CvtToLittleEndian(int i){
    const uint32_t u = static_cast<uint32_t>(i);
    const uint32_t swapped = (u >> 24)
                           | ((u >> 8) & 0x0000FF00u)
                           | ((u << 8) & 0x00FF0000u)
                           | (u << 24);
    return static_cast<int>(swapped);
}
// Disabled legacy code (compiled out by #if 0), kept for reference:
//  - test(): loads a model from "SVM_HOG.xml", computes HOG descriptors of
//    "Win32/Debug/demo.png" with a 128x128 window, and predicts a label.
//  - clearPaint(): resets the canvas to a 128x128 black image (the live
//    clear handler, on_clearscreen_clicked, uses 200x200).
// NOTE(review): this code does not compile as written (`Ptr svm` lacks its
// template argument and one string literal uses typographic quotes); fix both
// before re-enabling.
#if 0
void dtector::test()
{
// create the SVM classifier
Ptr svm = SVM::create();
// load the trained model file
svm = SVM::load(“SVM_HOG.xml”);
if (!svm)
{
// ui.loadXmllabel->setText("load file faile...");
}
Mat test;
test = imread("Win32/Debug/demo.png");
//imshow("image", test);
// ui.loadImagelabel->setText(“succcess”);
// HOG: window 128x128 (an older note said 64x128, which does not match the
// code), block 16x16, block stride 8x8, cell 8x8, 9 histogram bins.
HOGDescriptor hog(Size(128, 128), Size(16, 16), Size(8, 8), Size(8, 8), 9);
vector<float> descriptors;// HOG descriptor vector
hog.compute(test, descriptors, Size(8, 8));// compute HOG descriptors, window stride (8,8)
int r = svm->predict(descriptors); // classify the descriptor vector
// ui.testResultlabel->setText("The number is " + QString::number(r) + ".");
//ui.TestButton->setText(“success”);
}
void dtector::clearPaint()
{
image = QImage(128, 128, QImage::Format_RGB32);
backColor = qRgb(0, 0, 0);
image.fill(backColor);
// ui.loadXmllabel->setText(“Waitting…”);
// ui.loadImagelabel->setText(“Waitting…”);
// ui.testResultlabel->setText(“Waitting…”);
update();
}
#endif
void dtector::on_starttrain_clicked()
{
double consumeTime = 0;
std::clock_t startTime = 0;
std::clock_t endTime = 0;
cv::Mat trainData;
cv::Mat trainDataLabels;
QString runPath = QCoreApplication::applicationDirPath();
//【1】读入训练样本
trainData = ReadMnistImage( runPath+trainImage);
trainDataLabels = ReadMnistLabel(runPath+trainLabel);
//【2】设置支持向量机的参数,SVM中的参数有很多,但是与C_SVC有关的就只有gamma和C,所以只要设置好这两个就可以了
// 其实,很多资料将gamma设置为0.01,这样训练的收敛速度就会快很多
#if OPECV2
CvSVMParams params;
params.svm_type = SVM::C_SVC;
params.kernel_type = SVM::RBF;
params.degree = 10.0;
params.gamma = 0.01;
params.coef0 = 1.0;
params.C = 10.0;
params.nu = 0.5;
params.p = 0.1;
params.term_crit = cv::TermCriteria(CV_TERMCRIT_EPS,1000,FLT_EPSILON);
//【3】训练SVM
std::cout<<"[NOTICE]Starting training process!"<<std::endl;
startTime = std::clock();
CvSVM svm;
svm.train(trainData,trainDataLabels,cv::Mat(),cv::Mat(),params);
endTime = std::clock();
consumeTime = (endTime - startTime);
qDebug()<<"[NOTICE]Finished training process…consumeTime = "<<consumeTime<<“ms”;
svm.save(“mnist_dataset/mnist_svm.xml”);
#endif
qDebug()<<“current path :”<<runPath;
#if OPENCV3
Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);//用于多类分类
svm->setKernel(SVM::RBF);//采用高斯核函数
svm->setDegree(10.0);//高斯核的参数设置
svm->setGamma(0.01);
svm->setCoef0(1.0);
svm->setC(10.0);
svm->setNu(0.5);
svm->setP(0.1);
//训练
//【3】训练SVM
#if EPS
svm->setTermCriteria(cv::TermCriteria(CV_TERMCRIT_EPS, 200, FLT_EPSILON));//
#endif
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 200, FLT_EPSILON));
//训练开始
trainData.convertTo(trainData,CV_32FC1);
trainDataLabels.convertTo(trainDataLabels, CV_32SC1);
qDebug()<<"[NOTICE]Starting training process!";
this->ui->statu1->setText("训练需要二、三十分钟");
startTime = std::clock();
svm->train(trainData, ROW_SAMPLE, trainDataLabels);
// Ptr tData = TrainData::create(trainData, ROW_SAMPLE, trainDataLabels);
// svm->train(tData);
endTime = std::clock();
consumeTime = (endTime - startTime);
qDebug()<<"[NOTICE]Finished training process…consumeTime = “<<consumeTime<<“ms”;
QString xmlPath = runPath +”/SVM_DATA.xml";
svm->save(xmlPath.toStdString());
//https://blog.csdn.net/LIT_Elric/article/details/79202469
#endif
}
void dtector::on_pushButton_2_clicked()
{
double consumeTime = 0;
std::clock_t startTime = 0;
std::clock_t endTime = 0;
cv::Mat testDataimage;
cv::Mat testDataLabels;
QString runPath = QCoreApplication::applicationDirPath();
//【1】读入训练样本
testDataimage = ReadMnistImage( runPath+testImage);
testDataLabels = ReadMnistLabel(runPath+testLabel);
testDataimage.convertTo(testDataimage,CV_32FC1);
testDataLabels.convertTo(testDataLabels, CV_32SC1);
qDebug()<<"[NOTICE]Starting predict process!";
this->ui->statu1->setText("预测需要二、三分钟");
//载入训练好的SVM模型
QString xmlPath = runPath +"/SVM_DATA.xml";
Ptr<SVM> svm = SVM::load(xmlPath.toStdString());
int sum = 0;
startTime = std::clock();
//对每一个测试图像进行SVM分类预测
for (int i = 0; i < testLblsNum; i++)
{
Mat predict_mat = Mat::zeros(1, imgVectorLen, CV_32FC1);
memcpy(predict_mat.data, testDataimage.data + i*imgVectorLen * sizeof(float), imgVectorLen * sizeof(float));
//预测
float predict_label = svm->predict(predict_mat);
//真实的样本标签
float truth_label = testDataLabels.at<int>(i);
//比较判定是否预测正确
if ((int)predict_label == (int)truth_label)
{
sum++;
}
}
endTime = std::clock();
consumeTime = (endTime - startTime);
SetLabelTxt(this->ui->timeplay,consumeTime,true);
qDebug() << "预测准确率为:"<<(double)sum / (double)testLblsNum ;
double recRate= (double)sum / (double)testLblsNum ;
QString b= QString::number(recRate,10,5);
this->ui->statu1->setText("预测准确率为:"+b);
//blog.csdn.net/wblgers1234/article/details/73477860
}
//https://blog.csdn.net/Almost_Miao/article/details/79132319
// Slot shared by both SVM-mode radio buttons. It stores the id of the button
// now checked in btnGroupSvm (1 = norRadio, 2 = theRadio) in checkedId.
void dtector::onRadioClick()
{
    const int selectedId = btnGroupSvm->checkedId();
    checkedId = selectedId;
}
void dtector::on_pushButton_4_clicked()
{
QString path = QString("%1/demot.bmp").arg(QApplication::applicationDirPath());
QString xmlPath = QString("%1/SVM_DATA.xml").arg(QApplication::applicationDirPath());
Ptr<SVM> svm = SVM::load(xmlPath.toStdString());
if (!svm)
{
qDebug() << "解压模型falie";
ui->statu1->setText("load file faile...");
}else {
ui->statu1->setText("解压模型成功");
qDebug() << "解压模型成功";
}
if(checkedId ==1){
qDebug() << "pic path "<<path;
Mat testIma = imread(path.toStdString(),IMREAD_COLOR);
Mat gray;
cvtColor(testIma, gray, CV_BGR2GRAY);
// testIma.
qDebug() << "加载图像 深度 : " << testIma.depth();
Mat imgReadScal = Mat::zeros(28, 28, CV_8UC1);
Mat color_mat = Mat::zeros(28, 28, CV_32FC1);
Mat show_mat = Mat::zeros(28, 28, CV_32FC1);
// cvtColor(testIma,color_mat,CV_BGR5552GRAY);
cv::resize(gray,imgReadScal,imgReadScal.size());
imgReadScal.convertTo(show_mat, CV_32FC1);
// cvtColor(show_mat,show_mat,CV_RGB2GRAY);
// cv::imshow(“my”,testIma);
// testIma.convertTo(testIma,CV_32FC1);
show_mat = show_mat / 255;
Mat predict_mat = Mat::zeros(1, 28*28, CV_32FC1);
memcpy(predict_mat.data, show_mat.data, 28*28 * sizeof(float));
imshow("test", testIma);
float predict_label = svm->predict(predict_mat);
#if 0
Mat predict_mat = Mat::zeros(1, 2828, CV_32FC1);
memcpy(predict_mat.data, testIma.data , 2828 * sizeof(float));
float predict_label = svm->predict(predict_mat);
//真实的样本标签
#endif
qDebug() << “预测结果为:”<<predict_label;
QString b= QString::number(predict_label,10,0);
this->ui->result->setText(b);
// this->ui->statu1->setText(“预测准确率为:”+b);
//blog.csdn.net/wblgers1234/article/details/73477860
}else if(checkedId == 2){
Mat test;
test = imread(path.toStdString(),0);
Mat show_mat = Mat::zeros(28, 28, CV_8UC3);
// test.convertTo(show_mat, CV_32FC1);
//检测窗口(28,28),块尺寸(14,14),块步长(7,7),cell尺寸(7,7),直方图bin个数9
HOGDescriptor hog(Size(28, 28), Size(14, 14), Size(7, 7), Size(7, 7), 9);
vector<float> descriptors;//HOG描述子向量
hog.compute(test, descriptors, Size(7, 7));//计算HOG描述子,检测窗口移动步长(8,8)
// CvMat* SVMtrainMat=cvCreateMat(1,descriptors.size(),CV_32FC1);
Mat predict_mat = Mat::zeros(1, 28*28, CV_32FC1);
memcpy(predict_mat.data,descriptors.data(),28*28* sizeof(float));
predict_mat=predict_mat/255;
// for(vector::iterator iter=descriptors.begin();iter!=descriptors.end();iter++)
// {
// // mats(SVMtrainMat,0,n,*iter);
// predict_mat.data ++ = *iter;
// n++;
// }
imshow(“HOG”, predict_mat);
float predict_label = svm->predict(predict_mat); //对所有行进行预测
qDebug() << "预测结果为:"<<predict_label;
QString b= QString::number(predict_label,10,0);
this->ui->result->setText(b);
//blog.csdn.net/Almost_Miao/article/details/79132319
}else {
Mat src, gray,medblurImg,threImg;
Mat structElem,eroImg,dilateImg,cannyImg;
src = imread(path.toStdString(),IMREAD_COLOR);
qDebug() << "加载图像 深度 : " << src.depth()<< src.channels()<<src.type() <<src.rows;
cvtColor(src, gray, CV_BGR2GRAY);
// imshow("gray", gray);
qDebug() << "huidu图像 深度 : " << gray.depth()<< gray.channels()<<gray.type() << gray.rows;
medianBlur(gray, medblurImg, 3);
threshold(medblurImg, threImg, 125, 255, CV_THRESH_BINARY);
structElem = getStructuringElement(CV_SHAPE_RECT, Size(3,3));
dilate(threImg, dilateImg, structElem);
erode(dilateImg, eroImg, structElem);
imshow("dilateImg", threImg);
Mat contourImg = Mat::zeros(eroImg.size(),CV_8UC3);
Mat testData = Mat::zeros(28,28, CV_32FC1);
threImg.convertTo(testData,CV_32FC1,1.0/255.0);//归一化
vector<vector<Point>> contours;
findContours(eroImg, contours, RETR_EXTERNAL, CHAIN_APPROX_NONE);
vector<Rect> numRect;
int testDataRow = 0;
int addpixel = 10;
Mat predict_mat = Mat::zeros(1, 28*28, CV_32FC1);
memcpy(predict_mat.data, testData.data, 28*28 * sizeof(float));
float predict_label = svm->predict(predict_mat);
qDebug() << "预测结果为:"<< predict_label;
QString b= QString::number(predict_label,10,0);
this->ui->result->setText(b);
//blog.csdn.net/qq_29441995/article/details/82887475
}
}
// "Clear" button: replaces the canvas with a fresh 200x200 image filled with
// black and repaints the widget.
void dtector::on_clearscreen_clicked()
{
    backColor = qRgb(0, 0, 0);
    image = QImage(200, 200, QImage::Format_RGB32);
    image.fill(backColor);
    update();
}
void dtector::on_pushButton_clicked()
{
QString path = QString("%1/demot.bmp").arg(QApplication::applicationDirPath());
#ifdef SVM_NORMAL
QImage newimg3 = image.scaled(28,28,Qt::KeepAspectRatio);
newimg3.save(path);
#elif SVM_HOG
QImage newimg3 = image.scaled(28,28,Qt::KeepAspectRatio);
newimg3.allGray();
newimg3.save(path);
#endif
#if SVM_DILATE
QImage newimg3 = image.scaled(28,28,Qt::KeepAspectRatio);
newimg3.save(path);
#endif
this->ui->statu1->setText(“图片保存成功”);
}
/* 推荐阅读 (Recommended reading):
 * - 计算机视觉学习 基于OpenCV+svm的手写数字识别
 * - 基于MNIST手写数字数据集的数字识别小程序
 * - 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)
 * - 【机器学习】【KNN】线性扫描算法,python实现识别手写数字的系统
 * - 机器学习实战学习笔记(二)-KNN算法(2)-使用KNN算法进行手写数字的识别
 * - 机器学习_KNN实验(手写数字的识别)
 * - 基于MNIST手写数字数据集的数字识别小程序
 * - 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)
 * - 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)
 * - 深度学习与计算机视觉(12)_tensorflow实现基于深度学习的图像补全
 */