001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.math.stat.regression; 018 019 import org.apache.commons.math.MathRuntimeException; 020 import org.apache.commons.math.linear.RealMatrix; 021 import org.apache.commons.math.linear.Array2DRowRealMatrix; 022 import org.apache.commons.math.linear.RealVector; 023 import org.apache.commons.math.linear.ArrayRealVector; 024 025 /** 026 * Abstract base class for implementations of MultipleLinearRegression. 027 * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $ 028 * @since 2.0 029 */ 030 public abstract class AbstractMultipleLinearRegression implements 031 MultipleLinearRegression { 032 033 /** X sample data. */ 034 protected RealMatrix X; 035 036 /** Y sample data. */ 037 protected RealVector Y; 038 039 /** 040 * Loads model x and y sample data from a flat array of data, overriding any previous sample. 041 * Assumes that rows are concatenated with y values first in each row. 042 * 043 * @param data input data array 044 * @param nobs number of observations (rows) 045 * @param nvars number of independent variables (columns, not counting y) 046 */ 047 public void newSampleData(double[] data, int nobs, int nvars) { 048 double[] y = new double[nobs]; 049 double[][] x = new double[nobs][nvars + 1]; 050 int pointer = 0; 051 for (int i = 0; i < nobs; i++) { 052 y[i] = data[pointer++]; 053 x[i][0] = 1.0d; 054 for (int j = 1; j < nvars + 1; j++) { 055 x[i][j] = data[pointer++]; 056 } 057 } 058 this.X = new Array2DRowRealMatrix(x); 059 this.Y = new ArrayRealVector(y); 060 } 061 062 /** 063 * Loads new y sample data, overriding any previous sample 064 * 065 * @param y the [n,1] array representing the y sample 066 */ 067 protected void newYSampleData(double[] y) { 068 this.Y = new ArrayRealVector(y); 069 } 070 071 /** 072 * Loads new x sample data, overriding any previous sample 073 * 074 * @param x the [n,k] array representing the x sample 075 */ 076 protected void newXSampleData(double[][] x) { 077 this.X = new Array2DRowRealMatrix(x); 078 } 079 080 /** 081 * Validates sample data. 082 * 083 * @param x the [n,k] array representing the x sample 084 * @param y the [n,1] array representing the y sample 085 * @throws IllegalArgumentException if the x and y array data are not 086 * compatible for the regression 087 */ 088 protected void validateSampleData(double[][] x, double[] y) { 089 if ((x == null) || (y == null) || (x.length != y.length)) { 090 throw MathRuntimeException.createIllegalArgumentException( 091 "dimension mismatch {0} != {1}", 092 (x == null) ? 0 : x.length, 093 (y == null) ? 0 : y.length); 094 } else if ((x.length > 0) && (x[0].length > x.length)) { 095 throw MathRuntimeException.createIllegalArgumentException( 096 "not enough data ({0} rows) for this many predictors ({1} predictors)", 097 x.length, x[0].length); 098 } 099 } 100 101 /** 102 * Validates sample data. 103 * 104 * @param x the [n,k] array representing the x sample 105 * @param covariance the [n,n] array representing the covariance matrix 106 * @throws IllegalArgumentException if the x sample data or covariance 107 * matrix are not compatible for the regression 108 */ 109 protected void validateCovarianceData(double[][] x, double[][] covariance) { 110 if (x.length != covariance.length) { 111 throw MathRuntimeException.createIllegalArgumentException( 112 "dimension mismatch {0} != {1}", x.length, covariance.length); 113 } 114 if (covariance.length > 0 && covariance.length != covariance[0].length) { 115 throw MathRuntimeException.createIllegalArgumentException( 116 "a {0}x{1} matrix was provided instead of a square matrix", 117 covariance.length, covariance[0].length); 118 } 119 } 120 121 /** 122 * {@inheritDoc} 123 */ 124 public double[] estimateRegressionParameters() { 125 RealVector b = calculateBeta(); 126 return b.getData(); 127 } 128 129 /** 130 * {@inheritDoc} 131 */ 132 public double[] estimateResiduals() { 133 RealVector b = calculateBeta(); 134 RealVector e = Y.subtract(X.operate(b)); 135 return e.getData(); 136 } 137 138 /** 139 * {@inheritDoc} 140 */ 141 public double[][] estimateRegressionParametersVariance() { 142 return calculateBetaVariance().getData(); 143 } 144 145 /** 146 * {@inheritDoc} 147 */ 148 public double[] estimateRegressionParametersStandardErrors() { 149 double[][] betaVariance = estimateRegressionParametersVariance(); 150 double sigma = calculateYVariance(); 151 int length = betaVariance[0].length; 152 double[] result = new double[length]; 153 for (int i = 0; i < length; i++) { 154 result[i] = Math.sqrt(sigma * betaVariance[i][i]); 155 } 156 return result; 157 } 158 159 /** 160 * {@inheritDoc} 161 */ 162 public double estimateRegressandVariance() { 163 return calculateYVariance(); 164 } 165 166 /** 167 * Calculates the beta of multiple linear regression in matrix notation. 168 * 169 * @return beta 170 */ 171 protected abstract RealVector calculateBeta(); 172 173 /** 174 * Calculates the beta variance of multiple linear regression in matrix 175 * notation. 176 * 177 * @return beta variance 178 */ 179 protected abstract RealMatrix calculateBetaVariance(); 180 181 /** 182 * Calculates the Y variance of multiple linear regression. 183 * 184 * @return Y variance 185 */ 186 protected abstract double calculateYVariance(); 187 188 /** 189 * Calculates the residuals of multiple linear regression in matrix 190 * notation. 191 * 192 * <pre> 193 * u = y - X * b 194 * </pre> 195 * 196 * @return The residuals [n,1] matrix 197 */ 198 protected RealVector calculateResiduals() { 199 RealVector b = calculateBeta(); 200 return Y.subtract(X.operate(b)); 201 } 202 203 }