欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

java commons-math3 计算cook距离 (基于最小二乘法的二元线性回归)

程序员文章站 2022-04-03 21:41:05
...
最近在利用java开发数据分析相关的web应用,在开发中遇到需要通过计算cook距离判断数据点是否为离异点的功能(cook距离自行百度,
数学这块不是很懂,模型是别人做的……)。
在python中statsmodels模块可以一句代码搞定,但是java好像没有现成的实现,在摸索下发现可以利用common-math3开源包实现,以下是具体代码:

在python中,cook距离计算有现成的包:

from statsmodels.formula.api import ols
import statsmodels.api as sm

X = [-3, -2, -1, 0, 1, 2, 3]
y = [4, 2, 3, 0, -1, -2, -5]
X = sm.add_constant(X)
lm = sm.OLS(y, X)
results = lm.fit()
infl = results.get_influence()
cook_d = infl.cooks_distance[0]
print(infl.summary_table())

java commons-math3 计算cook距离 (基于最小二乘法的二元线性回归)
Cook’s d - cook距离
Student residual - 学生化(内)残差
hat diag - 帽子矩阵对角线元素

以下是通过java common-math3 实现Cook距离计算

Maven中引入common-math3依赖

<dependency>
   <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.5</version>
</dependency>

java代码如下:

public class Test {
    public static void main(String[] args){
    	// 声明线性回归模型
        OLSMultipleLinearRegression OLS = 
        	  new OLSMultipleLinearRegression();
        // 添加测试数据
        double[] y = new double[]{4,2,3,0,-1,-2,-5};
        double[][] x = new double[7][];
        x[0] = new double[]{-3};
        x[1] = new double[]{-2};
        x[2] = new double[]{-1};
        x[3] = new double[]{0};
        x[4] = new double[]{1};
        x[5] = new double[]{2};
        x[6] = new double[]{3};
        OLS.newSampleData(y, x);
		// 帽子矩阵
        RealMatrix hat = OLS.calculateHat();
        // 斜率
        double slope = OLS.estimateRegressionParameters()[1];
        // 截距
        double intercept = OLS.estimateRegressionParameters()[0];
        // 多项式数量(本例是二元线性回归,k_var = 2)
        double k_var = OLS.estimateRegressionParameters().length;
        // 标准残差
        double[] residuals = OLS.estimateResiduals();
        // 均方误差MSE
        double MSE = OLS.estimateErrorVariance();

		// 计算并打印结果
        System.out.println("X:\t\tY:\t\tCook's d\tstudent residual\that diag");
        int idx = 0;
        while(idx < y.length){
        	// 帽子矩阵对角线第idx个元素
            double hatDiag = hat.getRow(idx)[idx];
            // 学生化(内)残差
            double ResidStudentizedInternal =
              residuals[idx] / Math.sqrt(MSE)/Math.sqrt(1.0d - hatDiag);
            // cook距离
            double cookD = 
              ResidStudentizedInternal * ResidStudentizedInternal / k_var;
            cookD *= hatDiag / (1 - hatDiag);

			// 打印结果
            System.out.println((x[idx][0] >= 0 ? " " + x[idx][0]: x[idx][0] + "") + "\t"
                    + (y[idx] >= 0 ? " " + y[idx]: y[idx] + "") + "\t\t"
                    + ((double) Math.round(cookD * 1000) / 1000) + "\t\t"
                    + ((double) Math.round(ResidStudentizedInternal * 1000) / 1000) + "\t\t\t\t"
                    + ((double) Math.round(hatDiag * 1000) / 1000));
            idx++;
        }
    }
}

结果如下:
java commons-math3 计算cook距离 (基于最小二乘法的二元线性回归)

相关标签: 数据分析 java