001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.optimization.direct;
019    
020    import java.util.Comparator;
021    
022    import org.apache.commons.math.FunctionEvaluationException;
023    import org.apache.commons.math.optimization.OptimizationException;
024    import org.apache.commons.math.optimization.RealPointValuePair;
025    
026    /**
027     * This class implements the Nelder-Mead direct search method.
028     * <p>At each iteration the worst vertex of the simplex is replaced by a
029     * point obtained by reflection, expansion or contraction about the
030     * centroid of the other vertices; when the contraction point is rejected,
031     * the whole simplex is shrunk towards its best vertex instead.</p>
032     *
033     * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $
034     * @see MultiDirectional
035     * @since 1.2
036     */
037    public class NelderMead extends DirectSearchOptimizer {
038    
039        /** Reflection coefficient (rho). */
040        private final double rho;
041    
042        /** Expansion coefficient (khi). */
043        private final double khi;
044    
045        /** Contraction coefficient (gamma), shared by both contraction kinds. */
046        private final double gamma;
047    
048        /** Shrinkage coefficient (sigma). */
049        private final double sigma;
050    
051        /** Build a Nelder-Mead optimizer using the classical coefficients.
052         * <p>This is the same as {@code new NelderMead(1.0, 2.0, 0.5, 0.5)},
053         * i.e. rho = 1.0, khi = 2.0, gamma = 0.5 and sigma = 0.5.</p>
054         */
055        public NelderMead() {
056            this(1.0, 2.0, 0.5, 0.5);
057        }
058    
059        /** Build a Nelder-Mead optimizer with caller-supplied coefficients.
060         * <p>The values are stored as given; they are not range-checked.</p>
061         * @param rho reflection coefficient
062         * @param khi expansion coefficient
063         * @param gamma contraction coefficient (outside and inside contractions)
064         * @param sigma shrinkage coefficient
065         */
066        public NelderMead(final double rho, final double khi,
067                          final double gamma, final double sigma) {
068            this.rho   = rho;
069            this.khi   = khi;
070            this.gamma = gamma;
071            this.sigma = sigma;
072        }
073    
074        /** {@inheritDoc}
075         * <p>Vertices are read as ordered from best (index 0) to worst
076         * (index n) with respect to {@code comparator}.</p>
077         */
078        @Override
079        protected void iterateSimplex(final Comparator<RealPointValuePair> comparator)
080            throws FunctionEvaluationException, OptimizationException {
081    
082            incrementIterationsCounter();
083    
084            // in dimension n the simplex holds n+1 vertices, the worst one last
085            final int n = simplex.length - 1;
086    
087            final RealPointValuePair best        = simplex[0];
088            final RealPointValuePair secondWorst = simplex[n - 1];
089            final RealPointValuePair worst       = simplex[n];
090            final double[] xWorst = worst.getPointRef();
091    
092            final double[] centroid = centroidOfBestVertices(n);
093    
094            // reflection: mirror the worst vertex through the centroid
095            final double[] xR = new double[n];
096            for (int k = 0; k < n; ++k) {
097                xR[k] = centroid[k] + rho * (centroid[k] - xWorst[k]);
098            }
099            final RealPointValuePair reflected = new RealPointValuePair(xR, evaluate(xR), false);
100    
101            // reflected point not better than the best, but better than the second worst: keep it
102            if ((comparator.compare(best, reflected) <= 0) &&
103                (comparator.compare(reflected, secondWorst) < 0)) {
104                replaceWorstPoint(reflected, comparator);
105                return;
106            }
107    
108            // reflected point beats the best vertex: try to move further along that direction
109            if (comparator.compare(reflected, best) < 0) {
110                final double[] xE = new double[n];
111                for (int k = 0; k < n; ++k) {
112                    xE[k] = centroid[k] + khi * (xR[k] - centroid[k]);
113                }
114                final RealPointValuePair expanded = new RealPointValuePair(xE, evaluate(xE), false);
115                final RealPointValuePair retained =
116                    (comparator.compare(expanded, reflected) < 0) ? expanded : reflected;
117                replaceWorstPoint(retained, comparator);
118                return;
119            }
120    
121            if (comparator.compare(reflected, worst) < 0) {
122                // outside contraction: between the centroid and the reflected point
123                final double[] xC = new double[n];
124                for (int k = 0; k < n; ++k) {
125                    xC[k] = centroid[k] + gamma * (xR[k] - centroid[k]);
126                }
127                final RealPointValuePair outContracted = new RealPointValuePair(xC, evaluate(xC), false);
128                if (comparator.compare(outContracted, reflected) <= 0) {
129                    replaceWorstPoint(outContracted, comparator);
130                    return;
131                }
132            } else {
133                // inside contraction: between the centroid and the worst vertex
134                final double[] xC = new double[n];
135                for (int k = 0; k < n; ++k) {
136                    xC[k] = centroid[k] - gamma * (centroid[k] - xWorst[k]);
137                }
138                final RealPointValuePair inContracted = new RealPointValuePair(xC, evaluate(xC), false);
139                if (comparator.compare(inContracted, worst) < 0) {
140                    replaceWorstPoint(inContracted, comparator);
141                    return;
142                }
143            }
144    
145            // the contraction point was rejected: shrink the simplex towards its best vertex
146            shrinkTowardsBest(n, comparator);
147    
148        }
149    
150        /** Compute the centroid of all vertices except the worst one.
151         * @param n problem dimension (vertices 0 to n-1 are averaged)
152         * @return a newly allocated array holding the centroid coordinates
153         */
154        private double[] centroidOfBestVertices(final int n) {
155            final double[] centroid = new double[n];
156            for (int i = 0; i < n; ++i) {
157                final double[] vertex = simplex[i].getPointRef();
158                for (int k = 0; k < n; ++k) {
159                    centroid[k] += vertex[k];
160                }
161            }
162            final double scaling = 1.0 / n;
163            for (int k = 0; k < n; ++k) {
164                centroid[k] *= scaling;
165            }
166            return centroid;
167        }
168    
169        /** Move every vertex but the best one halfway (by sigma) towards the
170         * best vertex, then re-evaluate the whole simplex.
171         * @param n problem dimension
172         * @param comparator comparator forwarded to the simplex evaluation
173         * @exception FunctionEvaluationException if the simplex evaluation throws one
174         * @exception OptimizationException if the simplex evaluation throws one
175         */
176        private void shrinkTowardsBest(final int n, final Comparator<RealPointValuePair> comparator)
177            throws FunctionEvaluationException, OptimizationException {
178            final double[] xBest = simplex[0].getPointRef();
179            for (int i = 1; i < simplex.length; ++i) {
180                final double[] x = simplex[i].getPoint();
181                for (int k = 0; k < n; ++k) {
182                    x[k] = xBest[k] + sigma * (x[k] - xBest[k]);
183                }
184                simplex[i] = new RealPointValuePair(x, Double.NaN, false);
185            }
186            evaluateSimplex(comparator);
187        }
188    
189    }