欢迎您访问程序员文章站！本站旨在为大家分享程序员计算机编程知识！
您现在的位置是: 首页

Java 实现逻辑回归

程序员文章站 2022-03-22 17:31:29
...

pom.xml

		<!-- Matrix operations (UJMP), used by LogisticRegression for training -->
                <dependency>
			<groupId>org.ujmp</groupId>
			<artifactId>ujmp-core</artifactId>
			<version>0.3.0</version>
		</dependency>
                <!-- Scatter-plot rendering (JFreeChart), used by the demo to display the data and decision boundary -->
		<dependency>
			<groupId>org.jfree</groupId>
			<artifactId>jfreechart</artifactId>
			<version>1.5.0</version>
		</dependency>

 

    LogisticRegression 主类

package logisticregression;

import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

public class LogisticRegression {

	public static double[] train(double[][] data, double[] classValues) {

		if (data != null && classValues != null && data.length == classValues.length) {
			Matrix matrWeights = DenseMatrix.Factory.zeros(data[0].length + 1, 1);
			Matrix matrData = DenseMatrix.Factory.zeros(data.length, data[0].length + 1);
			Matrix matrLable = DenseMatrix.Factory.zeros(data.length, 1);
			for (int i = 0; i < data.length; i++) {
				matrData.setAsDouble(1.0, i, 0);
				matrLable.setAsDouble(classValues[i], i, 0);
				for (int j = 0; j < data[0].length; j++) {
					matrData.setAsDouble(data[i][j], i, j + 1);
					if (i == 0) {
						matrWeights.setAsDouble(1.0, j, 0);

					}
				}
			}
			matrWeights.setAsDouble(-0.5, data[0].length, 0);

			double step = 0.01;
			int maxCycle = 5000000;

			for (int i = 0; i < maxCycle; i++) {
				Matrix h = sigmoid(matrData.mtimes(matrWeights));
				Matrix difference = matrLable.minus(h);
				matrWeights = matrWeights.plus(matrData.transpose().mtimes(difference).times(step));
			}

			double[] rtn = new double[(int) matrWeights.getRowCount()];
			for (long i = 0; i < matrWeights.getRowCount(); i++) {
				rtn[(int) i] = matrWeights.getAsDouble(i, 0);
			}

			return rtn;

		}

		return null;
	}

	public static Matrix sigmoid(Matrix sourceMatrix) {
		Matrix rtn = DenseMatrix.Factory.zeros(sourceMatrix.getRowCount(), sourceMatrix.getColumnCount());
		for (int i = 0; i < sourceMatrix.getRowCount(); i++) {
			for (int j = 0; j < sourceMatrix.getColumnCount(); j++) {
				rtn.setAsDouble(sigmoid(sourceMatrix.getAsDouble(i, j)), i, j);
			}

		}

		return rtn;
	}

	public static double sigmoid(double source) {
		return 1.0 / (1 + Math.exp(-1 * source));
	}

	public static double getValue(double[] sourceData, double[] model) {
		double logisticRegressionValue = model[0];
		for (int i = 0; i < sourceData.length; i++) {
			logisticRegressionValue = logisticRegressionValue + sourceData[i] * model[i + 1];
		}
		logisticRegressionValue = sigmoid(logisticRegressionValue);

		return logisticRegressionValue;
	}

}

 

    逻辑回归测试类

package logisticregression;

import common.ScatterPlot;

/**
 * Demo driver: trains {@link LogisticRegression} on a small hard-coded 2-D data set, prints
 * the learned weights and the prediction for (0, 0), and plots the data with the decision boundary.
 */
public class LogisicRegressionTest {

	/** Number of evenly spaced x samples used to draw the decision boundary. */
	private static final int BOUNDARY_SAMPLES = 21;

	/** Left end of the x range over which the decision boundary is drawn. */
	private static final double BOUNDARY_FROM = -1;

	/** Right end of the x range over which the decision boundary is drawn. */
	private static final double BOUNDARY_TO = 1;

	public static void main(String[] args) {
		double[][] sourceData = new double[][] { { -1, 1 }, { 0, 1 }, { 1, -1 }, { 1, 0 }, { 0, 0.1 }, { 0, -0.1 }, { -1, -1.1 }, { 1, 0.9 } };
		double[] classValue = new double[] { 1, 1, 0, 0, 1, 0, 0, 0 };
		double[] model = LogisticRegression.train(sourceData, classValue);
		double logicValue = LogisticRegression.getValue(new double[] { 0, 0 }, model);
		System.out.println("---model---");
		for (double weight : model) {
			System.out.println(weight);
		}
		System.out.println("-----------");
		System.out.println(logicValue);

		// Series are ordered to match the legend labels below: class 1, class 0, boundary line "L".
		// NOTE(review): assumes ScatterPlot pairs names[k] with series[k] — confirm.
		String[] c = new String[] { "1", "0", "L" };
		double[][][] chartData = new double[3][][];
		chartData[0] = pointsOfClass(sourceData, classValue, 1);
		chartData[1] = pointsOfClass(sourceData, classValue, 0);
		chartData[2] = decisionBoundary(model, BOUNDARY_FROM, BOUNDARY_TO, BOUNDARY_SAMPLES);

		ScatterPlot.showScatterPlotChart("LogisticRegression", c, chartData);
	}

	/**
	 * Collects the samples whose label equals {@code label}, preserving input order, in a
	 * 2 x n layout: row 0 holds the x coordinates and row 1 the y coordinates.
	 */
	private static double[][] pointsOfClass(double[][] data, double[] labels, double label) {
		int count = 0;
		for (double value : labels) {
			if (value == label) {
				count++;
			}
		}
		double[][] points = new double[2][count];
		int k = 0;
		for (int i = 0; i < data.length; i++) {
			if (labels[i] == label) {
				points[0][k] = data[i][0];
				points[1][k] = data[i][1];
				k++;
			}
		}
		return points;
	}

	/**
	 * Samples the line where the model's probability is 0.5 ({@code w0 + w1*x + w2*y = 0})
	 * at {@code samples} evenly spaced x values in [from, to], in the same 2 x n layout.
	 * Uses an integer index so the number of points does not depend on floating-point rounding.
	 */
	private static double[][] decisionBoundary(double[] model, double from, double to, int samples) {
		double[][] line = new double[2][samples];
		double dx = (to - from) / (samples - 1);
		for (int k = 0; k < samples; k++) {
			double x = from + k * dx;
			line[0][k] = x;
			line[1][k] = (-model[0] - model[1] * x) / model[2];
		}
		return line;
	}

}

 

 

 

 

 

 

相关标签: 机器学习