001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.optimization.general;
019    
020    import org.apache.commons.math.ConvergenceException;
021    import org.apache.commons.math.FunctionEvaluationException;
022    import org.apache.commons.math.analysis.UnivariateRealFunction;
023    import org.apache.commons.math.analysis.solvers.BrentSolver;
024    import org.apache.commons.math.analysis.solvers.UnivariateRealSolver;
025    import org.apache.commons.math.exception.util.LocalizedFormats;
026    import org.apache.commons.math.optimization.GoalType;
027    import org.apache.commons.math.optimization.OptimizationException;
028    import org.apache.commons.math.optimization.RealPointValuePair;
029    import org.apache.commons.math.util.FastMath;
030    
031    /**
032     * Non-linear conjugate gradient optimizer.
033     * <p>
034     * This class supports both the Fletcher-Reeves and the Polak-Ribière
035     * update formulas for the conjugate search directions. It also supports
036     * optional preconditioning.
037     * </p>
038     *
039     * @version $Revision: 1070725 $ $Date: 2011-02-15 02:31:12 +0100 (mar. 15 f??vr. 2011) $
040     * @since 2.0
041     *
042     */
043    
044    public class NonLinearConjugateGradientOptimizer
045        extends AbstractScalarDifferentiableOptimizer {
046    
047        /** Update formula for the beta parameter. */
048        private final ConjugateGradientFormula updateFormula;
049    
050        /** Preconditioner (may be null). */
051        private Preconditioner preconditioner;
052    
053        /** solver to use in the line search (may be null). */
054        private UnivariateRealSolver solver;
055    
056        /** Initial step used to bracket the optimum in line search. */
057        private double initialStep;
058    
059        /** Simple constructor with default settings.
060         * <p>The convergence check is set to a {@link
061         * org.apache.commons.math.optimization.SimpleVectorialValueChecker}
062         * and the maximal number of iterations is set to
063         * {@link AbstractScalarDifferentiableOptimizer#DEFAULT_MAX_ITERATIONS}.
064         * @param updateFormula formula to use for updating the β parameter,
065         * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
066         * ConjugateGradientFormula#POLAK_RIBIERE}
067         */
068        public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
069            this.updateFormula = updateFormula;
070            preconditioner     = null;
071            solver             = null;
072            initialStep        = 1.0;
073        }
074    
075        /**
076         * Set the preconditioner.
077         * @param preconditioner preconditioner to use for next optimization,
078         * may be null to remove an already registered preconditioner
079         */
080        public void setPreconditioner(final Preconditioner preconditioner) {
081            this.preconditioner = preconditioner;
082        }
083    
084        /**
085         * Set the solver to use during line search.
086         * @param lineSearchSolver solver to use during line search, may be null
087         * to remove an already registered solver and fall back to the
088         * default {@link BrentSolver Brent solver}.
089         */
090        public void setLineSearchSolver(final UnivariateRealSolver lineSearchSolver) {
091            this.solver = lineSearchSolver;
092        }
093    
094        /**
095         * Set the initial step used to bracket the optimum in line search.
096         * <p>
097         * The initial step is a factor with respect to the search direction,
098         * which itself is roughly related to the gradient of the function
099         * </p>
100         * @param initialStep initial step used to bracket the optimum in line search,
101         * if a non-positive value is used, the initial step is reset to its
102         * default value of 1.0
103         */
104        public void setInitialStep(final double initialStep) {
105            if (initialStep <= 0) {
106                this.initialStep = 1.0;
107            } else {
108                this.initialStep = initialStep;
109            }
110        }
111    
112        /** {@inheritDoc} */
113        @Override
114        protected RealPointValuePair doOptimize()
115            throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
116            try {
117    
118                // initialization
119                if (preconditioner == null) {
120                    preconditioner = new IdentityPreconditioner();
121                }
122                if (solver == null) {
123                    solver = new BrentSolver();
124                }
125                final int n = point.length;
126                double[] r = computeObjectiveGradient(point);
127                if (goal == GoalType.MINIMIZE) {
128                    for (int i = 0; i < n; ++i) {
129                        r[i] = -r[i];
130                    }
131                }
132    
133                // initial search direction
134                double[] steepestDescent = preconditioner.precondition(point, r);
135                double[] searchDirection = steepestDescent.clone();
136    
137                double delta = 0;
138                for (int i = 0; i < n; ++i) {
139                    delta += r[i] * searchDirection[i];
140                }
141    
142                RealPointValuePair current = null;
143                while (true) {
144    
145                    final double objective = computeObjectiveValue(point);
146                    RealPointValuePair previous = current;
147                    current = new RealPointValuePair(point, objective);
148                    if (previous != null) {
149                        if (checker.converged(getIterations(), previous, current)) {
150                            // we have found an optimum
151                            return current;
152                        }
153                    }
154    
155                    incrementIterationsCounter();
156    
157                    double dTd = 0;
158                    for (final double di : searchDirection) {
159                        dTd += di * di;
160                    }
161    
162                    // find the optimal step in the search direction
163                    final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection);
164                    final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep));
165    
166                    // validate new point
167                    for (int i = 0; i < point.length; ++i) {
168                        point[i] += step * searchDirection[i];
169                    }
170                    r = computeObjectiveGradient(point);
171                    if (goal == GoalType.MINIMIZE) {
172                        for (int i = 0; i < n; ++i) {
173                            r[i] = -r[i];
174                        }
175                    }
176    
177                    // compute beta
178                    final double deltaOld = delta;
179                    final double[] newSteepestDescent = preconditioner.precondition(point, r);
180                    delta = 0;
181                    for (int i = 0; i < n; ++i) {
182                        delta += r[i] * newSteepestDescent[i];
183                    }
184    
185                    final double beta;
186                    if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
187                        beta = delta / deltaOld;
188                    } else {
189                        double deltaMid = 0;
190                        for (int i = 0; i < r.length; ++i) {
191                            deltaMid += r[i] * steepestDescent[i];
192                        }
193                        beta = (delta - deltaMid) / deltaOld;
194                    }
195                    steepestDescent = newSteepestDescent;
196    
197                    // compute conjugate search direction
198                    if ((getIterations() % n == 0) || (beta < 0)) {
199                        // break conjugation: reset search direction
200                        searchDirection = steepestDescent.clone();
201                    } else {
202                        // compute new conjugate search direction
203                        for (int i = 0; i < n; ++i) {
204                            searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
205                        }
206                    }
207    
208                }
209    
210            } catch (ConvergenceException ce) {
211                throw new OptimizationException(ce);
212            }
213        }
214    
215        /**
216         * Find the upper bound b ensuring bracketing of a root between a and b
217         * @param f function whose root must be bracketed
218         * @param a lower bound of the interval
219         * @param h initial step to try
220         * @return b such that f(a) and f(b) have opposite signs
221         * @exception FunctionEvaluationException if the function cannot be computed
222         * @exception OptimizationException if no bracket can be found
223         */
224        private double findUpperBound(final UnivariateRealFunction f,
225                                      final double a, final double h)
226            throws FunctionEvaluationException, OptimizationException {
227            final double yA = f.value(a);
228            double yB = yA;
229            for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
230                final double b = a + step;
231                yB = f.value(b);
232                if (yA * yB <= 0) {
233                    return b;
234                }
235            }
236            throw new OptimizationException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
237        }
238    
239        /** Default identity preconditioner. */
240        private static class IdentityPreconditioner implements Preconditioner {
241    
242            /** {@inheritDoc} */
243            public double[] precondition(double[] variables, double[] r) {
244                return r.clone();
245            }
246    
247        }
248    
249        /** Internal class for line search.
250         * <p>
251         * The function represented by this class is the dot product of
252         * the objective function gradient and the search direction. Its
253         * value is zero when the gradient is orthogonal to the search
254         * direction, i.e. when the objective function value is a local
255         * extremum along the search direction.
256         * </p>
257         */
258        private class LineSearchFunction implements UnivariateRealFunction {
259            /** Search direction. */
260            private final double[] searchDirection;
261    
262            /** Simple constructor.
263             * @param searchDirection search direction
264             */
265            public LineSearchFunction(final double[] searchDirection) {
266                this.searchDirection = searchDirection;
267            }
268    
269            /** {@inheritDoc} */
270            public double value(double x) throws FunctionEvaluationException {
271    
272                // current point in the search direction
273                final double[] shiftedPoint = point.clone();
274                for (int i = 0; i < shiftedPoint.length; ++i) {
275                    shiftedPoint[i] += x * searchDirection[i];
276                }
277    
278                // gradient of the objective function
279                final double[] gradient;
280                gradient = computeObjectiveGradient(shiftedPoint);
281    
282                // dot product with the search direction
283                double dotProduct = 0;
284                for (int i = 0; i < gradient.length; ++i) {
285                    dotProduct += gradient[i] * searchDirection[i];
286                }
287    
288                return dotProduct;
289    
290            }
291    
292        }
293    
294    }