欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

libsvm支持向量机回归示例

程序员文章站 2024-02-26 22:14:58
libsvm支持向量机算法包的基本使用,此处演示的是支持向量回归机 复制代码 代码如下:import java.io.bufferedreader;import java...

libsvm支持向量机算法包的基本使用,此处演示的是支持向量回归机

复制代码 代码如下:

import java.io.bufferedreader;
import java.io.file;
import java.io.filereader;
import java.util.arraylist;
import java.util.list;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

public class svm {
 public static void main(string[] args) {
  // 定义训练集点a{10.0, 10.0} 和 点b{-10.0, -10.0},对应lable为{1.0, -1.0}
  list<double> label = new arraylist<double>();
  list<svm_node[]> nodeset = new arraylist<svm_node[]>();
  getdata(nodeset, label, "file/train.txt");

  int datarange=nodeset.get(0).length;
  svm_node[][] datas = new svm_node[nodeset.size()][datarange]; // 训练集的向量表
  for (int i = 0; i < datas.length; i++) {
   for (int j = 0; j < datarange; j++) {
    datas[i][j] = nodeset.get(i)[j];
   }
  }
  double[] lables = new double[label.size()]; // a,b 对应的lable
  for (int i = 0; i < lables.length; i++) {
   lables[i] = label.get(i);
  }

  // 定义svm_problem对象
  svm_problem problem = new svm_problem();
  problem.l = nodeset.size(); // 向量个数
  problem.x = datas; // 训练集向量表
  problem.y = lables; // 对应的lable数组

  // 定义svm_parameter对象
  svm_parameter param = new svm_parameter();
  param.svm_type = svm_parameter.epsilon_svr;
  param.kernel_type = svm_parameter.linear;
  param.cache_size = 100;
  param.eps = 0.00001;
  param.c = 1.9;
  // 训练svm分类模型
  system.out.println(svm.svm_check_parameter(problem, param));
  // 如果参数没有问题,则svm.svm_check_parameter()函数返回null,否则返回error描述。
  svm_model model = svm.svm_train(problem, param);
  // svm.svm_train()训练出svm分类模型

  // 获取测试数据
  list<double> testlabel = new arraylist<double>();
  list<svm_node[]> testnodeset = new arraylist<svm_node[]>();
  getdata(testnodeset, testlabel, "file/test.txt");

  svm_node[][] testdatas = new svm_node[testnodeset.size()][datarange]; // 训练集的向量表
  for (int i = 0; i < testdatas.length; i++) {
   for (int j = 0; j < datarange; j++) {
    testdatas[i][j] = testnodeset.get(i)[j];
   }
  }
  double[] testlables = new double[testlabel.size()]; // a,b 对应的lable
  for (int i = 0; i < testlables.length; i++) {
   testlables[i] = testlabel.get(i);
  }

  // 预测测试数据的lable
  double err = 0.0;
  for (int i = 0; i < testdatas.length; i++) {
   double truevalue = testlables[i];
   system.out.print(truevalue + " ");
   double predictvalue = svm.svm_predict(model, testdatas[i]);
   system.out.println(predictvalue);
   err += math.abs(predictvalue - truevalue);
  }
  system.out.println("err=" + err / datas.length);
 }

 public static void getdata(list<svm_node[]> nodeset, list<double> label,
   string filename) {
  try {

   filereader fr = new filereader(new file(filename));
   bufferedreader br = new bufferedreader(fr);
   string line = null;
   while ((line = br.readline()) != null) {
    string[] datas = line.split(",");
    svm_node[] vector = new svm_node[datas.length - 1];
    for (int i = 0; i < datas.length - 1; i++) {
     svm_node node = new svm_node();
     node.index = i + 1;
     node.value = double.parsedouble(datas[i]);
     vector[i] = node;
    }
    nodeset.add(vector);
    double lablevalue = double.parsedouble(datas[datas.length - 1]);
    label.add(lablevalue);
   }
  } catch (exception e) {
   e.printstacktrace();
  }

 }
}

训练数据,最后一列为目标值

复制代码 代码如下:

17.6,17.7,17.7,17.7,17.8
17.7,17.7,17.7,17.8,17.8
17.7,17.7,17.8,17.8,17.9
17.7,17.8,17.8,17.9,18
17.8,17.8,17.9,18,18.1
17.8,17.9,18,18.1,18.2
17.9,18,18.1,18.2,18.4
18,18.1,18.2,18.4,18.6
18.1,18.2,18.4,18.6,18.7
18.2,18.4,18.6,18.7,18.9
18.4,18.6,18.7,18.9,19.1
18.6,18.7,18.9,19.1,19.3

测试数据

复制代码 代码如下:

18.7,18.9,19.1,19.3,19.6
18.9,19.1,19.3,19.6,19.9
19.1,19.3,19.6,19.9,20.2
19.3,19.6,19.9,20.2,20.6
19.6,19.9,20.2,20.6,21
19.9,20.2,20.6,21,21.5
20.2,20.6,21,21.5,22


libsvm支持向量机回归示例