/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.optimize;

import cc.mallet.optimize.InvalidOptimizableException;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.util.LinkedList;
import java.util.logging.Logger;

/**
 * Orthant-Wise Limited-memory Quasi-Newton (OWL-QN) optimizer: an L-BFGS variant that
 * can handle an L1 penalty on the parameters.
 *
 * <p>The wrapped {@link Optimizable.ByGradientValue} is treated as a function to
 * <em>maximize</em>. Internally this class minimizes
 * {@code -optimizable.getValue() + l1Weight * sum_i |w_i|}. Infinite parameters are left
 * out of the penalty and have their gradient zeroed. With {@code l1Weight == 0} it
 * behaves as plain L-BFGS with a backtracking (Armijo) line search.
 */
public class OrthantWiseLimitedMemoryBFGS
implements Optimizer {
    private static final Logger logger = MalletLogger.getLogger(OrthantWiseLimitedMemoryBFGS.class.getName());
    /** Set once a termination criterion (or the iteration cap) has been reached. */
    boolean converged = false;
    /** Objective being maximized; its parameters are updated in place. */
    Optimizable.ByGradientValue optimizable;
    /** Simple class name of the optimizable, used in log messages. */
    String optName;
    /** Cap on the total number of iterations, counted across all optimize() calls. */
    final int maxIterations = 1000;
    /** Relative tolerance of the value-change termination test. */
    final double tolerance = 1.0E-4;
    /** Threshold on the pseudo-gradient norm for the gradient termination test. */
    final double gradientTolerance = 0.001;
    /** Keeps the relative value-change test meaningful when both values are near zero. */
    final double eps = 1.0E-5;
    /** L1 penalty coefficient; 0 disables the penalty. */
    double l1Weight;
    /** Maximum number of (s, y) correction pairs kept in the L-BFGS history. */
    final int m = 4;
    /** Penalized objective value before the most recent line search. */
    double oldValue;
    /** Current penalized objective value (the quantity being minimized). */
    double value;
    /** y.y of the newest stored correction pair; scales the initial inverse Hessian. */
    double yDotY;
    /** Gradient of the negated objective (without the L1 term) at {@link #parameters}. */
    double[] grad;
    /** Copy of {@link #grad} from the start of the current iteration. */
    double[] oldGrad;
    /** Current search direction. */
    double[] direction;
    /** Negated pseudo-gradient computed at the start of the current iteration. */
    double[] steepestDescentDirection;
    /** Current parameter vector. */
    double[] parameters;
    /** Copy of {@link #parameters} from the start of the current iteration. */
    double[] oldParameters;
    /** Parameter differences s_k, oldest first. */
    LinkedList<double[]> s;
    /** Gradient differences y_k, oldest first. */
    LinkedList<double[]> y;
    /** s_k . y_k for each stored pair (the dot product itself, not its reciprocal). */
    LinkedList<Double> rhos;
    /** Scratch coefficients for the two-loop recursion, one per stored pair. */
    double[] alphas;
    /** Number of completed iterations. */
    int iterations;

    /**
     * Creates an optimizer without L1 regularization (plain L-BFGS).
     *
     * @param function objective to maximize; its current parameters are the starting point
     */
    public OrthantWiseLimitedMemoryBFGS(Optimizable.ByGradientValue function) {
        this(function, 0.0);
    }

    /**
     * Creates an OWL-QN optimizer and evaluates the penalized value and the gradient at
     * the function's current parameters.
     *
     * @param function objective to maximize; its current parameters are the starting point
     * @param l1wt     L1 penalty coefficient; 0 disables the penalty
     */
    public OrthantWiseLimitedMemoryBFGS(Optimizable.ByGradientValue function, double l1wt) {
        this.optimizable = function;
        this.l1Weight = l1wt;
        String[] parts = this.optimizable.getClass().getName().split("\\.");
        this.optName = parts[parts.length - 1];
        this.iterations = 0;
        this.s = new LinkedList<double[]>();
        this.y = new LinkedList<double[]>();
        this.rhos = new LinkedList<Double>();
        // Java zero-initializes new arrays, so no explicit fill is needed.
        this.alphas = new double[this.m];
        this.yDotY = 0.0;
        int numParameters = this.optimizable.getNumParameters();
        this.parameters = new double[numParameters];
        this.optimizable.getParameters(this.parameters);
        this.value = this.evalL1();
        this.grad = new double[numParameters];
        this.evalGradient();
        this.direction = new double[numParameters];
        this.steepestDescentDirection = new double[numParameters];
        this.oldParameters = new double[numParameters];
        this.oldGrad = new double[numParameters];
    }

    @Override
    public Optimizable getOptimizable() {
        return this.optimizable;
    }

    @Override
    public boolean isConverged() {
        return this.converged;
    }

    /** Returns the number of completed iterations across all optimize() calls. */
    public int getIteration() {
        return this.iterations;
    }

    /** Optimizes until a termination criterion or the internal iteration cap is reached. */
    @Override
    public boolean optimize() {
        return this.optimize(Integer.MAX_VALUE);
    }

    /**
     * Runs at most {@code numIterations} OWL-QN iterations.
     *
     * @param numIterations maximum number of iterations for this call
     * @return true if the value test, the pseudo-gradient test, or the overall
     *         {@link #maxIterations} cap stopped the run (and {@link #converged} is set);
     *         false if {@code numIterations} iterations ran without that
     * @throws InvalidOptimizableException if the chosen direction is not a descent
     *         direction or the observed curvature s.y is negative (both suggest a wrong gradient)
     */
    @Override
    public boolean optimize(int numIterations) {
        logger.fine("Entering OWL-BFGS.optimize(). L1 weight=" + this.l1Weight + " Initial Value=" + this.value);
        for (int iter = 0; iter < numIterations; ++iter) {
            // Steepest descent on the pseudo-gradient, reshaped by the L-BFGS
            // inverse-Hessian estimate, then projected back onto the current orthant.
            this.makeSteepestDescDir();
            this.mapDirByInverseHessian(this.yDotY);
            this.fixDirSigns();
            this.storeSrcInDest(this.parameters, this.oldParameters);
            this.storeSrcInDest(this.grad, this.oldGrad);
            this.backTrackingLineSearch();
            this.evalGradient();
            if (this.checkValueTerminationCondition()) {
                logger.info("Exiting OWL-BFGS on termination #1:");
                logger.info("value difference below tolerance (oldValue: " + this.oldValue + " newValue: " + this.value);
                this.converged = true;
                return true;
            }
            if (this.checkGradientTerminationCondition()) {
                logger.info("Exiting OWL-BFGS on termination #2:");
                logger.info("gradient=" + this.pseudoGradientNorm() + " < " + this.gradientTolerance);
                this.converged = true;
                return true;
            }
            this.yDotY = this.shift();
            ++this.iterations;
            // Stop after exactly maxIterations iterations (previously one extra ran).
            if (this.iterations >= this.maxIterations) {
                logger.info("Too many iterations in OWL-BFGS. Continuing with current parameters.");
                this.converged = true;
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the penalized objective {@code -getValue() + l1Weight * sum |w_i|} at the
     * optimizable's current parameters. Infinite parameters are excluded from the penalty.
     */
    private double evalL1() {
        double val = -this.optimizable.getValue();
        double sumAbsWt = 0.0;
        if (this.l1Weight > 0.0) {
            for (double param : this.parameters) {
                if (Double.isInfinite(param)) continue;
                sumAbsWt += Math.abs(param) * this.l1Weight;
            }
        }
        logger.info("getValue() (" + this.optName + ".getValue() = " + val + " + |w|=" + sumAbsWt + ") = " + (val + sumAbsWt));
        return val + sumAbsWt;
    }

    /** Stores into {@link #grad} the gradient of the negated objective, zeroed at infinite parameters. */
    private void evalGradient() {
        this.optimizable.getValueGradient(this.grad);
        this.adjustGradForInfiniteParams(this.grad);
        MatrixOps.timesEquals(this.grad, -1.0);
    }

    /**
     * Sets {@link #direction} to the negated pseudo-gradient and remembers it in
     * {@link #steepestDescentDirection} for the later orthant projection.
     */
    private void makeSteepestDescDir() {
        for (int i = 0; i < this.grad.length; ++i) {
            this.direction[i] = -this.pseudoGradient(i);
        }
        this.storeSrcInDest(this.direction, this.steepestDescentDirection);
    }

    /**
     * Returns component {@code i} of the OWL-QN pseudo-gradient of the penalized objective.
     * At w_i == 0 the penalty is not differentiable, so the component is the one-sided
     * derivative that decreases the objective, or 0 if neither side does.
     */
    private double pseudoGradient(int i) {
        double g = this.grad[i];
        if (this.l1Weight == 0.0) {
            return g;
        }
        double w = this.parameters[i];
        if (w < 0.0) {
            return g - this.l1Weight;
        }
        if (w > 0.0) {
            return g + this.l1Weight;
        }
        if (g < -this.l1Weight) {
            return g + this.l1Weight;
        }
        if (g > this.l1Weight) {
            return g - this.l1Weight;
        }
        return 0.0;
    }

    /**
     * Returns the Euclidean norm of the pseudo-gradient over finite parameters.
     * Infinite parameters are skipped because evalL1 and evalGradient both ignore them.
     */
    private double pseudoGradientNorm() {
        double sumSq = 0.0;
        for (int i = 0; i < this.grad.length; ++i) {
            if (Double.isInfinite(this.parameters[i])) continue;
            double pg = this.pseudoGradient(i);
            sumSq += pg * pg;
        }
        return Math.sqrt(sumSq);
    }

    /** Zeroes entries of {@code d} whose corresponding parameter is infinite. */
    private void adjustGradForInfiniteParams(double[] d) {
        for (int i = 0; i < this.parameters.length; ++i) {
            if (!Double.isInfinite(this.parameters[i])) continue;
            d[i] = 0.0;
        }
    }

    /**
     * Applies the L-BFGS two-loop recursion to {@link #direction} in place, using the
     * stored (s, y) pairs. Signs are arranged for a descent direction rather than a gradient.
     *
     * @param yDotY y.y of the newest pair; with its rho it gives the initial scaling s.y / y.y
     */
    private void mapDirByInverseHessian(double yDotY) {
        if (this.s.size() == 0) {
            return;
        }
        int count = this.s.size();
        // First loop: newest to oldest.
        for (int i = count - 1; i >= 0; --i) {
            this.alphas[i] = -MatrixOps.dotProduct(this.s.get(i), this.direction) / this.rhos.get(i);
            MatrixOps.plusEquals(this.direction, this.y.get(i), this.alphas[i]);
        }
        double scalar = this.rhos.get(count - 1) / yDotY;
        logger.fine("Direction multiplier = " + scalar);
        MatrixOps.timesEquals(this.direction, scalar);
        // Second loop: oldest to newest.
        for (int i = 0; i < count; ++i) {
            double beta = MatrixOps.dotProduct(this.y.get(i), this.direction) / this.rhos.get(i);
            MatrixOps.plusEquals(this.direction, this.s.get(i), -this.alphas[i] - beta);
        }
    }

    /**
     * With an L1 penalty, zeroes direction components whose sign disagrees with (or
     * vanishes against) the steepest-descent direction, keeping the step in one orthant.
     */
    private void fixDirSigns() {
        if (this.l1Weight > 0.0) {
            for (int i = 0; i < this.direction.length; ++i) {
                if (!(this.direction[i] * this.steepestDescentDirection[i] <= 0.0)) continue;
                this.direction[i] = 0.0;
            }
        }
    }

    /**
     * Returns the directional derivative of the penalized objective along
     * {@link #direction}. For a zero parameter, the direction's sign picks which
     * one-sided penalty slope applies.
     */
    private double dirDeriv() {
        if (this.l1Weight == 0.0) {
            return MatrixOps.dotProduct(this.direction, this.grad);
        }
        double val = 0.0;
        for (int i = 0; i < this.direction.length; ++i) {
            if (this.direction[i] == 0.0) continue;
            if (this.parameters[i] < 0.0) {
                val += this.direction[i] * (this.grad[i] - this.l1Weight);
                continue;
            }
            if (this.parameters[i] > 0.0) {
                val += this.direction[i] * (this.grad[i] + this.l1Weight);
                continue;
            }
            if (this.direction[i] < 0.0) {
                val += this.direction[i] * (this.grad[i] - this.l1Weight);
                continue;
            }
            if (!(this.direction[i] > 0.0)) continue;
            val += this.direction[i] * (this.grad[i] + this.l1Weight);
        }
        return val;
    }

    /**
     * Records the newest correction pair s = x_new - x_old and y = g_new - g_old.
     * When the history is full, it recycles the oldest pair's arrays.
     *
     * @return y.y of the newest pair still in the history
     * @throws InvalidOptimizableException if s.y is negative
     */
    private double shift() {
        double[] nextS;
        double[] nextY;
        if (this.s.size() < this.m) {
            nextS = new double[this.parameters.length];
            nextY = new double[this.parameters.length];
        } else {
            nextS = this.s.removeFirst();
            nextY = this.y.removeFirst();
            this.rhos.removeFirst();
        }
        double rho = 0.0;
        double yDotY = 0.0;
        for (int i = 0; i < this.parameters.length; ++i) {
            // Two same-signed infinities would subtract to NaN; treat them as no change.
            nextS[i] = Double.isInfinite(this.parameters[i]) && Double.isInfinite(this.oldParameters[i]) && this.parameters[i] * this.oldParameters[i] > 0.0 ? 0.0 : this.parameters[i] - this.oldParameters[i];
            nextY[i] = Double.isInfinite(this.grad[i]) && Double.isInfinite(this.oldGrad[i]) && this.grad[i] * this.oldGrad[i] > 0.0 ? 0.0 : this.grad[i] - this.oldGrad[i];
            rho += nextS[i] * nextY[i];
            yDotY += nextY[i] * nextY[i];
        }
        logger.fine("rho=" + rho);
        if (rho < 0.0) {
            throw new InvalidOptimizableException("rho = " + rho + " < 0: " + "Invalid hessian inverse. " + "Gradient change should be opposite of parameter change.");
        }
        if (!(rho > 0.0)) {
            // s.y is zero or NaN. Storing this pair would divide by zero in
            // mapDirByInverseHessian, producing a NaN/Infinity direction that the line
            // search cannot recover from. Skip the update and keep the scaling of the
            // newest pair that remains in the history.
            logger.fine("Skipping L-BFGS update: rho=" + rho);
            this.storeSrcInDest(this.parameters, this.oldParameters);
            this.storeSrcInDest(this.grad, this.oldGrad);
            return this.yDotY;
        }
        this.s.addLast(nextS);
        this.y.addLast(nextY);
        this.rhos.addLast(rho);
        this.storeSrcInDest(this.parameters, this.oldParameters);
        this.storeSrcInDest(this.grad, this.oldGrad);
        return yDotY;
    }

    /** Copies all of {@code src} into the start of {@code dest}. */
    private void storeSrcInDest(double[] src, double[] dest) {
        System.arraycopy(src, 0, dest, 0, src.length);
    }

    /**
     * Backtracks along {@link #direction} from {@link #oldParameters} until the Armijo
     * sufficient-decrease condition holds. On exit, {@link #value} and the parameters
     * hold the accepted point.
     *
     * @throws InvalidOptimizableException if the direction is not a descent direction
     */
    private void backTrackingLineSearch() {
        double origDirDeriv = this.dirDeriv();
        if (origDirDeriv >= 0.0) {
            throw new InvalidOptimizableException("L-BFGS chose a non-ascent direction: check your gradient!");
        }
        double alpha = 1.0;
        double backoff = 0.5;
        if (this.iterations == 0) {
            // Without curvature history the direction is unscaled; start with a unit-length step.
            double normDir = Math.sqrt(MatrixOps.dotProduct(this.direction, this.direction));
            alpha = 1.0 / normDir;
            backoff = 0.1;
        }
        final double c1 = 1.0E-4;
        this.oldValue = this.value;
        logger.fine("*** Starting line search iter=" + this.iterations);
        logger.fine("iter[" + this.iterations + "] Value at start of line search = " + this.value);
        while (true) {
            this.getNextPoint(alpha);
            this.value = this.evalL1();
            logger.fine("iter[" + this.iterations + "] Using alpha = " + alpha + " new value = " + this.value + " |grad|=" + MatrixOps.twoNorm(this.grad) + " |x|=" + MatrixOps.twoNorm(this.parameters));
            if (this.value <= this.oldValue + c1 * origDirDeriv * alpha) break;
            alpha *= backoff;
        }
    }

    /**
     * Sets parameters to {@code oldParameters + alpha * direction} and pushes them to the
     * optimizable. With an L1 penalty, a component that would change sign is clamped to 0.
     */
    private void getNextPoint(double alpha) {
        for (int i = 0; i < this.parameters.length; ++i) {
            this.parameters[i] = this.oldParameters[i] + this.direction[i] * alpha;
            if (!(this.l1Weight > 0.0) || !(this.oldParameters[i] * this.parameters[i] < 0.0)) continue;
            this.parameters[i] = 0.0;
        }
        this.optimizable.setParameters(this.parameters);
    }

    /** Returns true when the relative change in the penalized value is within {@link #tolerance}. */
    private boolean checkValueTerminationCondition() {
        return 2.0 * Math.abs(this.value - this.oldValue) <= this.tolerance * (Math.abs(this.value) + Math.abs(this.oldValue) + this.eps);
    }

    /**
     * Returns true when the pseudo-gradient norm is below {@link #gradientTolerance}.
     * Without an L1 penalty this is the ordinary gradient norm.
     */
    private boolean checkGradientTerminationCondition() {
        return this.pseudoGradientNorm() < this.gradientTolerance;
    }
}

