java commons-math3 计算cook距离 (基于最小二乘法的二元线性回归)
程序员文章站
2022-04-03 21:41:05
...
最近在利用java开发数据分析相关的web应用,在开发中遇到需要通过计算cook距离判断数据点是否为离异点的功能(cook距离自行百度,
数学这块不是很懂,模型是别人做的……)。
在python中statsmodels模块可以一句代码搞定,但是java好像没有现成的实现,在摸索下发现可以利用common-math3开源包实现,以下是具体代码:
在python中,cook距离计算有现成的包:
from statsmodels.formula.api import ols
import statsmodels.api as sm
X = [-3, -2, -1, 0, 1, 2, 3]
y = [4, 2, 3, 0, -1, -2, -5]
X = sm.add_constant(X)
lm = sm.OLS(y, X)
results = lm.fit()
infl = results.get_influence()
cook_d = infl.cooks_distance[0]
print(infl.summary_table())
Cook’s d - cook距离
Student residual - 学生化(内)残差
hat diag - 帽子矩阵对角线元素
以下是通过java common-math3 实现Cook距离计算
Maven中引入common-math3依赖
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.5</version>
</dependency>
java代码如下:
public class Test {
public static void main(String[] args){
// 声明线性回归模型
OLSMultipleLinearRegression OLS =
new OLSMultipleLinearRegression();
// 添加测试数据
double[] y = new double[]{4,2,3,0,-1,-2,-5};
double[][] x = new double[7][];
x[0] = new double[]{-3};
x[1] = new double[]{-2};
x[2] = new double[]{-1};
x[3] = new double[]{0};
x[4] = new double[]{1};
x[5] = new double[]{2};
x[6] = new double[]{3};
OLS.newSampleData(y, x);
// 帽子矩阵
RealMatrix hat = OLS.calculateHat();
// 斜率
double slope = OLS.estimateRegressionParameters()[1];
// 截距
double intercept = OLS.estimateRegressionParameters()[0];
// 多项式数量(本例是二元线性回归,k_var = 2)
double k_var = OLS.estimateRegressionParameters().length;
// 标准残差
double[] residuals = OLS.estimateResiduals();
// 均方误差MSE
double MSE = OLS.estimateErrorVariance();
// 计算并打印结果
System.out.println("X:\t\tY:\t\tCook's d\tstudent residual\that diag");
int idx = 0;
while(idx < y.length){
// 帽子矩阵对角线第idx个元素
double hatDiag = hat.getRow(idx)[idx];
// 学生化(内)残差
double ResidStudentizedInternal =
residuals[idx] / Math.sqrt(MSE)/Math.sqrt(1.0d - hatDiag);
// cook距离
double cookD =
ResidStudentizedInternal * ResidStudentizedInternal / k_var;
cookD *= hatDiag / (1 - hatDiag);
// 打印结果
System.out.println((x[idx][0] >= 0 ? " " + x[idx][0]: x[idx][0] + "") + "\t"
+ (y[idx] >= 0 ? " " + y[idx]: y[idx] + "") + "\t\t"
+ ((double) Math.round(cookD * 1000) / 1000) + "\t\t"
+ ((double) Math.round(ResidStudentizedInternal * 1000) / 1000) + "\t\t\t\t"
+ ((double) Math.round(hatDiag * 1000) / 1000));
idx++;
}
}
}
结果如下:
上一篇: ???永远不要相信广告拍的照片
下一篇: 以父爱的名义