最近公司要做多元线性回归,计算的是某一列与多列的关系,入参是一个一维数组和一个二维数组。
首先,通过 Maven 引入依赖。这是 Apache 提供的通用数学计算库(Commons Math),包含了各种各样的计算模型:
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.6.1</version>
</dependency>

代码片段:
package success; import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; /** * 线性回归 */ public class ClineNone { //测试 public static void main(String[] args) { double [] y ={114, 49, 84, 79, 87, 74, 77, 82, 80, 88, 123, 82, 98, 65, 61, 78, 51, 121, 78, 50, 75, 65, 113, 122, 78, 119, 45, 89, 102, 75}; //标准化y y = dataStandardization(y); double[][] x ={{38,13,27,25,18,29,30,20,23,32,38,28,34,19,20,25,16,36,25,17,24,18,30,35,22,34,12,26,29,21}, {37,15,22,21,29,24,26,27,17,28,34,25,26,21,18,21,16,30,15,14,22,18,32,40,25,34,15,26,32,27}, {12,13,21,20,20,12,8,17,19,12,25,15,19,11,11,18,13,25,17,12,15,16,24,21,15,25,7,20,21,12}, {31,29,44,23,21,34,37,36,26,26,18,34,22,30,34,29,50,14,26,36,27,33,23,25,31,18,35,20,21,29}, {31,29,19,24,26,18,27,24,25,29,27,27,26,32,31,28,35,27,17,25,17,31,22,26,20,22,30,23,23,24}}; x = dataStandardizationDouble(x); OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); regression.newSampleData(y,x); double rSquared = regression.calculateRSquared(); //R方 System.out.println("R方"+rSquared); double[] doubles = regression.estimateRegressionParameters(); for (double d : doubles) { System.out.println("打印: " + d); } double f = getF(x, y, doubles); System.out.println("F:"+f); } /** * 数组进行数据标准化 * */ public static double[] dataStandardization(double array[]){ StandardDeviation deviation =new StandardDeviation(); double sum = 0; for(double i : array){ sum += i; } //均值 double avg = sum / array.length; //标准差 double evaluate = deviation.evaluate(array); //进行标准化 for(int i=0;i<array.length;i++){ array[i]=(array[i] - avg)/evaluate; } return array; } /** * 标准化多维数组 * */ public static double[][] dataStandardizationDouble(double arrays[][]){ double [][] result = new double[arrays[0].length][arrays.length]; for(int i=0;i<arrays.length;i++){ double[] doubles = dataStandardization(arrays[i]); for(int k=0;k<result.length;k++){ result[k][i]=doubles[k]; } } 
return result; } /** * 线性回归方程拿到F值 * @param x * @param y * @param back * @return */ public static double getF(double[][] x, double[] y,double[] back){ if(x.length!=y.length){ System.out.println("数组不相等"); } double sumY = 0; for(double d : y){ sumY += d; } double avgY = sumY/y.length; //回归的平方和 double SSR = 0; //残差的平方和 double SSE = 0; //y的估值 double yTemp = 0; for(int k=0;k<x.length;k++){ double temp = 0; for(int j =0; j<x[k].length; j++){ if(j==0){ temp += back[j]; } temp += x[k][j] * back[j + 1]; } yTemp = temp; //回归平方 temp = Math.pow(temp - avgY,2); SSR += temp; //残差平方 SSE += Math.pow(y[k]-yTemp,2); } //p值,自变量的个数 int p = x[0].length; //n值,为观测总值的个数 int n = y.length; //求F的计算公式 f=(SSR/p)/(SSE/(n-p-1)) double f = (SSR / p) / (SSE / (n - p - 1)); return f; } }