欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

java实现随机森林RandomForest的示例代码

程序员文章站 2024-02-22 12:07:40
随机森林是由多棵树组成的分类或回归方法。主要思想来源于bagging算法,bagging技术思想主要是给定一弱分类器及训练集,让该学习算法训练多轮,每轮的训练集由原始训练集...

随机森林是由多棵树组成的分类或回归方法。主要思想来源于bagging算法,bagging技术思想主要是给定一弱分类器及训练集,让该学习算法训练多轮,每轮的训练集由原始训练集中有放回的随机抽取,大小一般跟原始训练集相当,这样依次训练多个弱分类器,最终的分类由这些弱分类器组合,对于分类问题一般采用多数投票法,对于回归问题一般采用简单平均法。随机森林在bagging的基础上,每个弱分类器都是决策树,决策树的生成过程中中,在属性的选择上增加了依一定概率选择属性,在这些属性中选择最佳属性及分割点,传统做法一般是全部属性中去选择最佳属性,这样随机森林有了样本选择的随机性,属性选择的随机性,这样一来增加了每个分类器的差异性、不稳定性及一定程度上避免每个分类器的过拟合(一般决策树有过拟合现象),由此组合分类器增加了最终的泛化能力。下面是代码的简单实现

/**
 * 随机森林 回归问题
 * @author ysh  1208706282
 *
 */
public class randomforest {
  list<sample> msamples;
  list<cart> mcarts;
  double mfeaturerate;
  int mmaxdepth;
  int mminleaf;
  random mrandom;
  /**
   * 加载数据  回归树
   * @param path
   * @param regex
   * @throws exception
   */
  public void loaddata(string path,string regex) throws exception{
    msamples = new arraylist<sample>();
    bufferedreader reader = new bufferedreader(new filereader(path));
    string line = null;
    string splits[] = null;
    sample sample = null;
    while(null != (line=reader.readline())){
      splits = line.split(regex);
      sample = new sample();
      sample.label = double.valueof(splits[0]);
      sample.feature = new arraylist<double>(splits.length-1);
      for(int i=0;i<splits.length-1;i++){
        sample.feature.add(new double(splits[i+1]));
      }
      msamples.add(sample);
    }
    reader.close();
  }
  public void train(int iters){
    mcarts = new arraylist<cart>(iters);
    cart cart = null;
    for(int iter=0;iter<iters;iter++){
      cart = new cart();
      cart.mfeaturerate = mfeaturerate;
      cart.mmaxdepth = mmaxdepth;
      cart.mminleaf = mminleaf;
      cart.mrandom = mrandom;
      list<sample> s = new arraylist<sample>(msamples.size());
      for(int i=0;i<msamples.size();i++){
        s.add(msamples.get(cart.mrandom.nextint(msamples.size())));
      }
      cart.setdata(s);
      cart.train();
      mcarts.add(cart);
      system.out.println("iter: "+iter);
      s = null;
    }
  }
  /**
   * 回归问题简单平均法 分类问题多数投票法
   * @param sample
   * @return
   */
  public double classify(sample sample){
    double val = 0;
    for(cart cart:mcarts){
      val += cart.classify(sample);
    }
    return val/mcarts.size();
  }
  /**
   * @param args
   * @throws exception 
   */
  public static void main(string[] args) throws exception {
    // todo auto-generated method stub
    randomforest forest = new randomforest();
    forest.loaddata("f:/2016-contest/20161001/train_data_1.csv", ",");
    forest.mfeaturerate = 0.8;
    forest.mmaxdepth = 3;
    forest.mminleaf = 1;
    forest.mrandom = new random();
    forest.mrandom.setseed(100);
    forest.train(100);
    
    list<sample> samples = cart.loadtestdata("f:/2016-contest/20161001/valid_data_1.csv", true, ",");
    double sum = 0;
    for(sample s:samples){
      double val = forest.classify(s);
      sum += (val-s.label)*(val-s.label);
      system.out.println(val+" "+s.label);
    }
    system.out.println(sum/samples.size()+" "+sum);
    system.out.println(system.currenttimemillis());
  }

}

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。