001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.apache.commons.math.optimization.general; 019 020 import org.apache.commons.math.ConvergenceException; 021 import org.apache.commons.math.FunctionEvaluationException; 022 import org.apache.commons.math.analysis.UnivariateRealFunction; 023 import org.apache.commons.math.analysis.solvers.BrentSolver; 024 import org.apache.commons.math.analysis.solvers.UnivariateRealSolver; 025 import org.apache.commons.math.optimization.GoalType; 026 import org.apache.commons.math.optimization.OptimizationException; 027 import org.apache.commons.math.optimization.DifferentiableMultivariateRealOptimizer; 028 import org.apache.commons.math.optimization.RealPointValuePair; 029 030 /** 031 * Non-linear conjugate gradient optimizer. 032 * <p> 033 * This class supports both the Fletcher-Reeves and the Polak-Ribière 034 * update formulas for the conjugate search directions. It also supports 035 * optional preconditioning. 036 * </p> 037 * 038 * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $ 039 * @since 2.0 040 * 041 */ 042 043 public class NonLinearConjugateGradientOptimizer 044 extends AbstractScalarDifferentiableOptimizer 045 implements DifferentiableMultivariateRealOptimizer { 046 047 /** Update formula for the beta parameter. */ 048 private final ConjugateGradientFormula updateFormula; 049 050 /** Preconditioner (may be null). */ 051 private Preconditioner preconditioner; 052 053 /** solver to use in the line search (may be null). */ 054 private UnivariateRealSolver solver; 055 056 /** Initial step used to bracket the optimum in line search. */ 057 private double initialStep; 058 059 /** Simple constructor with default settings. 060 * <p>The convergence check is set to a {@link 061 * org.apache.commons.math.optimization.SimpleVectorialValueChecker} 062 * and the maximal number of iterations is set to 063 * {@link AbstractScalarDifferentiableOptimizer#DEFAULT_MAX_ITERATIONS}. 064 * @param updateFormula formula to use for updating the β parameter, 065 * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link 066 * ConjugateGradientFormula#POLAK_RIBIERE} 067 */ 068 public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) { 069 this.updateFormula = updateFormula; 070 preconditioner = null; 071 solver = null; 072 initialStep = 1.0; 073 } 074 075 /** 076 * Set the preconditioner. 077 * @param preconditioner preconditioner to use for next optimization, 078 * may be null to remove an already registered preconditioner 079 */ 080 public void setPreconditioner(final Preconditioner preconditioner) { 081 this.preconditioner = preconditioner; 082 } 083 084 /** 085 * Set the solver to use during line search. 086 * @param lineSearchSolver solver to use during line search, may be null 087 * to remove an already registered solver and fall back to the 088 * default {@link BrentSolver Brent solver}. 089 */ 090 public void setLineSearchSolver(final UnivariateRealSolver lineSearchSolver) { 091 this.solver = lineSearchSolver; 092 } 093 094 /** 095 * Set the initial step used to bracket the optimum in line search. 096 * <p> 097 * The initial step is a factor with respect to the search direction, 098 * which itself is roughly related to the gradient of the function 099 * </p> 100 * @param initialStep initial step used to bracket the optimum in line search, 101 * if a non-positive value is used, the initial step is reset to its 102 * default value of 1.0 103 */ 104 public void setInitialStep(final double initialStep) { 105 if (initialStep <= 0) { 106 this.initialStep = 1.0; 107 } else { 108 this.initialStep = initialStep; 109 } 110 } 111 112 /** {@inheritDoc} */ 113 @Override 114 protected RealPointValuePair doOptimize() 115 throws FunctionEvaluationException, OptimizationException, IllegalArgumentException { 116 try { 117 118 // initialization 119 if (preconditioner == null) { 120 preconditioner = new IdentityPreconditioner(); 121 } 122 if (solver == null) { 123 solver = new BrentSolver(); 124 } 125 final int n = point.length; 126 double[] r = computeObjectiveGradient(point); 127 if (goal == GoalType.MINIMIZE) { 128 for (int i = 0; i < n; ++i) { 129 r[i] = -r[i]; 130 } 131 } 132 133 // initial search direction 134 double[] steepestDescent = preconditioner.precondition(point, r); 135 double[] searchDirection = steepestDescent.clone(); 136 137 double delta = 0; 138 for (int i = 0; i < n; ++i) { 139 delta += r[i] * searchDirection[i]; 140 } 141 142 RealPointValuePair current = null; 143 while (true) { 144 145 final double objective = computeObjectiveValue(point); 146 RealPointValuePair previous = current; 147 current = new RealPointValuePair(point, objective); 148 if (previous != null) { 149 if (checker.converged(getIterations(), previous, current)) { 150 // we have found an optimum 151 return current; 152 } 153 } 154 155 incrementIterationsCounter(); 156 157 double dTd = 0; 158 for (final double di : searchDirection) { 159 dTd += di * di; 160 } 161 162 // find the optimal step in the search direction 163 final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection); 164 final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep)); 165 166 // validate new point 167 for (int i = 0; i < point.length; ++i) { 168 point[i] += step * searchDirection[i]; 169 } 170 r = computeObjectiveGradient(point); 171 if (goal == GoalType.MINIMIZE) { 172 for (int i = 0; i < n; ++i) { 173 r[i] = -r[i]; 174 } 175 } 176 177 // compute beta 178 final double deltaOld = delta; 179 final double[] newSteepestDescent = preconditioner.precondition(point, r); 180 delta = 0; 181 for (int i = 0; i < n; ++i) { 182 delta += r[i] * newSteepestDescent[i]; 183 } 184 185 final double beta; 186 if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) { 187 beta = delta / deltaOld; 188 } else { 189 double deltaMid = 0; 190 for (int i = 0; i < r.length; ++i) { 191 deltaMid += r[i] * steepestDescent[i]; 192 } 193 beta = (delta - deltaMid) / deltaOld; 194 } 195 steepestDescent = newSteepestDescent; 196 197 // compute conjugate search direction 198 if ((getIterations() % n == 0) || (beta < 0)) { 199 // break conjugation: reset search direction 200 searchDirection = steepestDescent.clone(); 201 } else { 202 // compute new conjugate search direction 203 for (int i = 0; i < n; ++i) { 204 searchDirection[i] = steepestDescent[i] + beta * searchDirection[i]; 205 } 206 } 207 208 } 209 210 } catch (ConvergenceException ce) { 211 throw new OptimizationException(ce); 212 } 213 } 214 215 /** 216 * Find the upper bound b ensuring bracketing of a root between a and b 217 * @param f function whose root must be bracketed 218 * @param a lower bound of the interval 219 * @param h initial step to try 220 * @return b such that f(a) and f(b) have opposite signs 221 * @exception FunctionEvaluationException if the function cannot be computed 222 * @exception OptimizationException if no bracket can be found 223 */ 224 private double findUpperBound(final UnivariateRealFunction f, 225 final double a, final double h) 226 throws FunctionEvaluationException, OptimizationException { 227 final double yA = f.value(a); 228 double yB = yA; 229 for (double step = h; step < Double.MAX_VALUE; step *= Math.max(2, yA / yB)) { 230 final double b = a + step; 231 yB = f.value(b); 232 if (yA * yB <= 0) { 233 return b; 234 } 235 } 236 throw new OptimizationException("unable to bracket optimum in line search"); 237 } 238 239 /** Default identity preconditioner. */ 240 private static class IdentityPreconditioner implements Preconditioner { 241 242 /** {@inheritDoc} */ 243 public double[] precondition(double[] variables, double[] r) { 244 return r.clone(); 245 } 246 247 } 248 249 /** Internal class for line search. 250 * <p> 251 * The function represented by this class is the dot product of 252 * the objective function gradient and the search direction. Its 253 * value is zero when the gradient is orthogonal to the search 254 * direction, i.e. when the objective function value is a local 255 * extremum along the search direction. 256 * </p> 257 */ 258 private class LineSearchFunction implements UnivariateRealFunction { 259 /** Search direction. */ 260 private final double[] searchDirection; 261 262 /** Simple constructor. 263 * @param searchDirection search direction 264 */ 265 public LineSearchFunction(final double[] searchDirection) { 266 this.searchDirection = searchDirection; 267 } 268 269 /** {@inheritDoc} */ 270 public double value(double x) throws FunctionEvaluationException { 271 272 // current point in the search direction 273 final double[] shiftedPoint = point.clone(); 274 for (int i = 0; i < shiftedPoint.length; ++i) { 275 shiftedPoint[i] += x * searchDirection[i]; 276 } 277 278 // gradient of the objective function 279 final double[] gradient = computeObjectiveGradient(shiftedPoint); 280 281 // dot product with the search direction 282 double dotProduct = 0; 283 for (int i = 0; i < gradient.length; ++i) { 284 dotProduct += gradient[i] * searchDirection[i]; 285 } 286 287 return dotProduct; 288 289 } 290 291 } 292 293 }