欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

基于Java实现的一层简单人工神经网络算法示例

程序员文章站 2024-02-23 21:27:46
本文实例讲述了基于java实现的一层简单人工神经网络算法。分享给大家供大家参考,具体如下: 先来看看笔者绘制的算法图: 2、数据类 import ja...

本文实例讲述了基于java实现的一层简单人工神经网络算法。分享给大家供大家参考,具体如下:

先来看看笔者绘制的算法图:

基于Java实现的一层简单人工神经网络算法示例

基于Java实现的一层简单人工神经网络算法示例

2、数据类

import java.util.arrays;
public class data {
  double[] vector;
  int dimention;
  int type;
  public double[] getvector() {
    return vector;
  }
  public void setvector(double[] vector) {
    this.vector = vector;
  }
  public int getdimention() {
    return dimention;
  }
  public void setdimention(int dimention) {
    this.dimention = dimention;
  }
  public int gettype() {
    return type;
  }
  public void settype(int type) {
    this.type = type;
  }
  public data(double[] vector, int dimention, int type) {
    super();
    this.vector = vector;
    this.dimention = dimention;
    this.type = type;
  }
  public data() {
  }
  @override
  public string tostring() {
    return "data [vector=" + arrays.tostring(vector) + ", dimention=" + dimention + ", type=" + type + "]";
  }
}

3、简单人工神经网络

package cn.edu.hbut.chenjie;
import java.util.arraylist;
import java.util.list;
import java.util.random;
import org.jfree.chart.chartfactory;
import org.jfree.chart.chartframe;
import org.jfree.chart.jfreechart;
import org.jfree.data.xy.defaultxydataset;
import org.jfree.ui.refineryutilities;
public class ann2 {
  private double eta;//学习率
  private int n_iter;//权重向量w[]训练次数
  private list<data> exercise;//训练数据集
  private double w0 = 0;//阈值
  private double x0 = 1;//固定值
  private double[] weights;//权重向量,其长度为训练数据维度+1,在本例中数据为2维,故长度为3
  private int testsum = 0;//测试数据总数
  private int error = 0;//错误次数
  defaultxydataset xydataset = new defaultxydataset();
  /**
   * 向图表中增加同类型的数据
   * @param type 类型
   * @param a 所有数据的第一个分量
   * @param b 所有数据的第二个分量
   */
  public void add(string type,double[] a,double[] b)
  {
    double[][] data = new double[2][a.length];
    for(int i=0;i<a.length;i++)
    {
      data[0][i] = a[i];
      data[1][i] = b[i];
    }
    xydataset.addseries(type, data);
  }
  /**
   * 画图
   */
  public void draw()
  {
    jfreechart jfreechart = chartfactory.createscatterplot("exercise", "x1", "x2", xydataset);
    chartframe frame = new chartframe("训练数据", jfreechart);
    frame.pack();
    refineryutilities.centerframeonscreen(frame);
    frame.setvisible(true);
  }
  public static void main(string[] args)
  {
    ann2 ann2 = new ann2(0.001,100);//构造人工神经网络
    list<data> exercise = new arraylist<data>();//构造训练集
    //人工模拟1000条训练数据 ,分界线为x2=x1+0.5
    for(int i=0;i<1000000;i++)
    {
      random rd = new random();
      double x1 = rd.nextdouble();//随机产生一个分量
      double x2 = rd.nextdouble();//随机产生另一个分量
      double[] da = {x1,x2};//产生数据向量
      data d = new data(da, 2, x2 > x1+0.5 ? 1 : -1);//构造数据
      exercise.add(d);//将训练数据加入训练集
    }
    int sum1 = 0;//记录类型1的训练记录数
    int sum2 = 0;//记录类型-1的训练记录数
    for(int i = 0; i < exercise.size(); i++)
    {
      if(exercise.get(i).gettype()==1)
        sum1++;
      else if(exercise.get(i).gettype()==-1)
        sum2++;
    }
    double[] x1 = new double[sum1];
    double[] y1 = new double[sum1];
    double[] x2 = new double[sum2];
    double[] y2 = new double[sum2];
    int index1 = 0;
    int index2 = 0;
    for(int i = 0; i < exercise.size(); i++)
    {
      if(exercise.get(i).gettype()==1)
      {
        x1[index1] = exercise.get(i).vector[0];
        y1[index1++] = exercise.get(i).vector[1];
      }
      else if(exercise.get(i).gettype()==-1)
      {
        x2[index2] = exercise.get(i).vector[0];
        y2[index2++] = exercise.get(i).vector[1];
      }
    }
    ann2.add("1", x1, y1);
    ann2.add("-1", x2, y2);
    ann2.draw();
    ann2.input(exercise);//将训练集输入人工神经网络
    ann2.fit();//训练
    ann2.showweigths();//显示权重向量
    //人工生成一千条测试数据
    for(int i=0;i<10000;i++)
    {
      random rd = new random();
      double x1_ = rd.nextdouble();
      double x2_ = rd.nextdouble();
      double[] da = {x1_,x2_};
      data test = new data(da, 2, x2_ > x1_+0.5 ? 1 : -1);
      ann2.predict(test);//测试
    }
    system.out.println("总共测试" + ann2.testsum + "条数据,有" + ann2.error + "条错误,错误率:" + ann2.error * 1.0 /ann2.testsum * 100 + "%");
  }
  /**
   *
   * @param eta 学习率
   * @param n_iter 权重分量学习次数
   */
  public ann2(double eta, int n_iter) {
    this.eta = eta;
    this.n_iter = n_iter;
  }
  /**
   * 输入训练集到人工神经网络
   * @param exercise
   */
  private void input(list<data> exercise) {
    this.exercise = exercise;//保存训练集
    weights = new double[exercise.get(0).dimention + 1];//初始化权重向量,其长度为训练数据维度+1
    weights[0] = w0;//权重向量第一个分量为w0
    for(int i = 1; i < weights.length; i++)
      weights[i] = 0;//其余分量初始化为0
  }
  private void fit() {
    for(int i = 0; i < n_iter; i++)//权重分量调整n_iter次
    {
      for(int j = 0; j < exercise.size(); j++)//对于训练集中的每条数据进行训练
      {
        int real_result = exercise.get(j).type;//y
        int calculate_result = calculateresult(exercise.get(j));//y'
        double delta0 = eta * (real_result - calculate_result);//计算阈值更新
        w0 += delta0;//阈值更新
        weights[0] = w0;//更新w[0]
        for(int k = 0; k < exercise.get(j).getdimention(); k++)//更新权重向量其它分量
        {
          double delta = eta * (real_result - calculate_result) * exercise.get(j).vector[k];
          //δw=η*(y-y')*x
          weights[k+1] += delta;
          //w=w+δw
        }
      }
    }
  }
  private int calculateresult(data data) {
    double z = w0 * x0;
    for(int i = 0; i < data.dimention; i++)
      z += data.vector[i] * weights[i+1];
    //z=w0x0+w1x1+...+wmxm
    //激活函数
    if(z>=0)
      return 1;
    else
      return -1;
  }
  private void showweigths()
  {
    for(double w : weights)
      system.out.println(w);
  }
  private void predict(data data) {
    int type = calculateresult(data);
    if(type == data.gettype())
    {
      //system.out.println("预测正确");
    }
    else
    {
      //system.out.println("预测错误");
      error ++;
    }
    testsum ++;
  }
}

运行结果:

-0.22000000000000017
-0.4416843982815453
0.442444202054685
总共测试10000条数据,有17条错误,错误率:0.16999999999999998%

基于Java实现的一层简单人工神经网络算法示例

更多关于java算法相关内容感兴趣的读者可查看本站专题:《java数据结构与算法教程》、《java操作dom节点技巧总结》、《java文件与目录操作技巧汇总》和《java缓存操作技巧汇总

希望本文所述对大家java程序设计有所帮助。