package org.deeplearning4j.optimize.solvers;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.stepfunctions.NegativeGradientStepFunction;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.LineOptimizer;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/BackTrackLineSearch.class */
/**
 * Backtracking line search used to pick a step size along a given search direction.
 *
 * <p>Starting from a full step of {@code 1.0}, {@link #optimize} repeatedly applies the configured
 * {@link StepFunction} to a copy of the parameters, re-scores the model, and either accepts the
 * step (sufficient decrease/increase test scaled by {@link #ALF}) or shrinks it using a
 * polynomial interpolation of the last one or two trials.
 *
 * <p>Whether the score is minimised or maximised is decided on every call to {@link #optimize}
 * from the type of the step function: {@link NegativeDefaultStepFunction} and
 * {@link NegativeGradientStepFunction} mean "minimise", anything else means "maximise".
 */
public class BackTrackLineSearch implements LineOptimizer {
    private static final Logger log = LoggerFactory.getLogger(BackTrackLineSearch.class);

    /** Model whose parameters are set and whose score is evaluated for each trial step. */
    private final Model layer;
    /** Applies a step of a given size along the search direction to a parameter vector. */
    private final StepFunction stepFunction;
    /** Optimizer this line search was created for; stored but not read by this class. */
    private final ConvexOptimizer optimizer;
    /** Upper bound on the number of trial steps per {@link #optimize} call. */
    private int maxIterations;
    /** Largest allowed 2-norm of the search direction; longer directions are rescaled in place. */
    double stepMax = 100.0d;
    /** {@code true} when the score is being minimised; recomputed at the start of every optimize call. */
    private boolean minObjectiveFunction = true;
    /** Relative tolerance used to derive the smallest useful step (double value of float 1e-7). */
    private double relTolx = 1.0000000116860974E-7d;
    /** Absolute tolerance (double value of float 1e-4); settable but not read by this class. */
    private double absTolx = 9.999999747378752E-5d;
    /** Scale factor of the sufficient decrease/increase test (double value of float 1e-4). */
    protected final double ALF = 9.999999747378752E-5d;

    /**
     * Creates a line search for the given model using an explicit step function.
     *
     * @param model           model to score; also supplies the maximum number of line search
     *                        iterations via {@code model.conf().getMaxNumLineSearchIterations()}
     * @param stepFunction    step function applied to candidate parameters
     * @param convexOptimizer owning optimizer (only stored)
     */
    public BackTrackLineSearch(Model model, StepFunction stepFunction, ConvexOptimizer convexOptimizer) {
        // NOTE: ALF is initialised at its declaration only; a final field must not be assigned twice.
        this.layer = model;
        this.stepFunction = stepFunction;
        this.optimizer = convexOptimizer;
        this.maxIterations = model.conf().getMaxNumLineSearchIterations();
    }

    /**
     * Creates a line search that uses a {@link NegativeDefaultStepFunction}.
     *
     * @param model           model to score
     * @param convexOptimizer owning optimizer (only stored)
     */
    public BackTrackLineSearch(Model model, ConvexOptimizer convexOptimizer) {
        this(model, new NegativeDefaultStepFunction(), convexOptimizer);
    }

    /** @param d new maximum 2-norm allowed for the search direction */
    public void setStepMax(double d) {
        this.stepMax = d;
    }

    /** @return the maximum 2-norm allowed for the search direction */
    public double getStepMax() {
        return this.stepMax;
    }

    /** @param d new relative tolerance used to compute the minimum step */
    public void setRelTolx(double d) {
        this.relTolx = d;
    }

    /** @param d new absolute tolerance (stored only; not read by this class) */
    public void setAbsTolx(double d) {
        this.absTolx = d;
    }

    /** @return the maximum number of trial steps per line search */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** @param i new maximum number of trial steps per line search */
    public void setMaxIterations(int i) {
        this.maxIterations = i;
    }

    /**
     * Sets the given parameters on the model, recomputes gradient and score, and returns the score.
     *
     * @param iNDArray          parameters to install on the model
     * @param layerWorkspaceMgr workspace manager forwarded to the model
     * @return the model score for the given parameters
     */
    public double setScoreFor(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.layer.setParams(iNDArray);
        this.layer.computeGradientAndScore(layerWorkspaceMgr);
        return this.layer.score();
    }

    /**
     * Searches for an acceptable step size along the search direction.
     *
     * <p>Side effects: the search direction is rescaled in place when its 2-norm exceeds
     * {@code stepMax}; the model's parameters and score are overwritten for every trial. When
     * {@code 0.0} is returned the model has been reset to the original parameters. When the best
     * step is returned after exhausting {@code maxIterations}, the model still holds the
     * parameters of the last trial evaluated, not those of the best one.
     *
     * @param iNDArray          current parameters (never modified; each trial works on a copy)
     * @param iNDArray2         presumably the gradient at the current parameters; used for the
     *                          slope and for the minimum step size
     * @param iNDArray3         search direction; may be rescaled in place
     * @param layerWorkspaceMgr workspace manager forwarded to the model when scoring
     * @return the accepted step size, or {@code 0.0} if no improving step was found
     * @throws InvalidStepException declared by the interface; not thrown directly by this method
     * @throws IllegalArgumentException if the step size did not change between two iterations
     * @throws IllegalStateException    if an accepted step moved the score in the wrong direction,
     *                                  or the interpolation would divide by zero
     */
    @Override // org.deeplearning4j.optimize.api.LineOptimizer
    public double optimize(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, LayerWorkspaceMgr layerWorkspaceMgr) throws InvalidStepException {
        // Readable aliases for the (externally fixed) parameter names.
        final INDArray parameters = iNDArray;
        final INDArray gradients = iNDArray2;
        final INDArray searchDirection = iNDArray3;

        this.minObjectiveFunction = (this.stepFunction instanceof NegativeDefaultStepFunction) || (this.stepFunction instanceof NegativeGradientStepFunction);
        Level1 l1Blas = Nd4j.getBlasWrapper().level1();
        double directionNorm = l1Blas.nrm2(searchDirection);
        double slope = (-1.0d) * Nd4j.getBlasWrapper().dot(searchDirection, gradients);
        log.debug("slope = {}", slope);

        // |parameters| passed through ScalarSetValue(…, 1); presumably this lifts entries below 1
        // up to 1 so the division below stays bounded — TODO confirm against ScalarSetValue.
        INDArray maxOldParams = Transforms.abs(parameters);
        Nd4j.getExecutioner().exec(new ScalarSetValue(maxOldParams, 1));

        double step = 1.0d;
        // Smallest step worth trying: relative tolerance over the largest |gradient| / maxOldParams entry.
        double stepMin = this.relTolx / Transforms.abs(gradients).divi(maxOldParams).max(Integer.MAX_VALUE).getDouble(0L);
        double oldStep = 0.0d;
        double step2 = 0.0d;
        double scoreAtStart = this.layer.score();
        double score2 = scoreAtStart;
        double bestScore = scoreAtStart;
        double bestStepSize = 1.0d;

        if (log.isTraceEnabled()) {
            double norm1 = l1Blas.asum(searchDirection);
            double infNorm = FastMath.max(Double.NEGATIVE_INFINITY, searchDirection.getDouble(l1Blas.iamax(searchDirection)));
            log.trace("ENTERING BACKTRACK\n");
            log.trace("Entering BackTrackLineSearch, value = " + scoreAtStart + ",\ndirection.oneNorm:" + norm1 + "  direction.infNorm:" + infNorm);
        }
        if (directionNorm > this.stepMax) {
            log.warn("Attempted step too big. scaling: sum= {}, stepMax= {}", directionNorm, this.stepMax);
            searchDirection.muli(Double.valueOf(this.stepMax / directionNorm));
        }

        for (int iteration = 0; iteration < this.maxIterations; iteration++) {
            if (log.isTraceEnabled()) {
                log.trace("BackTrack loop iteration {} : step={}, oldStep={}", iteration, step, oldStep);
                log.trace("before step, x.1norm: {} \nstep: {} \noldStep: {}", parameters.norm1(Integer.MAX_VALUE), step, oldStep);
            }
            if (step == oldStep) {
                throw new IllegalArgumentException("Current step == oldStep");
            }

            // Take the trial step on a copy so the caller's parameters stay untouched.
            INDArray candidate = parameters.dup('f');
            this.stepFunction.step(candidate, searchDirection, step);
            oldStep = step;
            if (log.isTraceEnabled()) {
                log.trace("after step, x.1norm: " + l1Blas.asum(candidate));
            }

            // Give up when the step is below stepMin or the Eps comparison reports every entry of
            // the candidate as equal to the original parameters.
            if (step < stepMin || Nd4j.getExecutioner().execAndReturn((TransformOp) new Eps(parameters, candidate, Shape.toOffsetZeroCopy(candidate, 'f'), candidate.length())).sum(Integer.MAX_VALUE).getDouble(0L) == candidate.length()) {
                double resetScore = setScoreFor(parameters, layerWorkspaceMgr);
                log.debug("EXITING BACKTRACK: Jump too small (stepMin = {}). Exiting and using original params. Score = {}", stepMin, resetScore);
                return 0.0d;
            }

            double score = setScoreFor(candidate, layerWorkspaceMgr);
            log.debug("Model score after step = {}", score);

            // Remember the best trial in case the iteration budget runs out.
            if ((this.minObjectiveFunction && score < bestScore) || (!this.minObjectiveFunction && score > bestScore)) {
                bestScore = score;
                bestStepSize = step;
            }

            // Sufficient decrease test when minimising.
            if (this.minObjectiveFunction && score <= scoreAtStart + (ALF * step * slope)) {
                log.debug("Sufficient decrease (Wolfe cond.), exiting backtrack on iter {}: score={}, scoreAtStart={}", iteration, score, scoreAtStart);
                if (score > scoreAtStart) {
                    throw new IllegalStateException("Function did not decrease: score = " + score + " > " + scoreAtStart + " = oldScore");
                }
                return step;
            }
            // Sufficient increase test when maximising.
            if (!this.minObjectiveFunction && score >= scoreAtStart + (ALF * step * slope)) {
                log.debug("Sufficient increase (Wolfe cond.), exiting backtrack on iter {}: score={}, bestScore={}", iteration, score, scoreAtStart);
                if (score < scoreAtStart) {
                    throw new IllegalStateException("Function did not increase: score = " + score + " < " + scoreAtStart + " = scoreAtStart");
                }
                return step;
            }

            double tmpStep;
            if (Double.isInfinite(score) || Double.isInfinite(score2) || Double.isNaN(score) || Double.isNaN(score2)) {
                // Non-finite score: interpolation is meaningless, so just shrink the step sharply.
                log.warn("Value is infinite after jump. oldStep={}. score={}, score2={}. Scaling back step size...", oldStep, score, score2);
                tmpStep = 0.2d * step;
                if (step < stepMin) {
                    double resetScore = setScoreFor(parameters, layerWorkspaceMgr);
                    log.warn("EXITING BACKTRACK: Jump too small (step={} < stepMin={}). Exiting and using previous parameters. Value={}", step, stepMin, resetScore);
                    return 0.0d;
                }
            } else {
                // Maximisation uses the same formulas with the score differences negated.
                double scoreDelta = this.minObjectiveFunction ? score - scoreAtStart : scoreAtStart - score;
                double scoreDelta2 = this.minObjectiveFunction ? score2 - scoreAtStart : scoreAtStart - score2;
                tmpStep = interpolateStep(step, step2, slope, scoreDelta, scoreDelta2);
            }

            step2 = step;
            score2 = score;
            log.debug("tmpStep: {}", tmpStep);
            // Never shrink by more than a factor of ~10 per iteration (0.1f widened to double).
            step = Math.max(tmpStep, 0.10000000149011612d * step);
        }

        boolean improved = this.minObjectiveFunction ? bestScore < scoreAtStart : bestScore > scoreAtStart;
        if (improved) {
            log.debug("Exited line search after maxIterations termination condition; bestStepSize={}, bestScore={}, scoreAtStart={}", bestStepSize, bestScore, scoreAtStart);
            return bestStepSize;
        }
        log.debug("Exited line search after maxIterations termination condition; score did not improve (bestScore={}, scoreAtStart={}). Resetting parameters", bestScore, scoreAtStart);
        setScoreFor(parameters, layerWorkspaceMgr);
        return 0.0d;
    }

    /**
     * Proposes the next trial step from the most recent trial(s).
     *
     * <p>When {@code step} is exactly {@code 1.0} only the current trial is used (quadratic
     * formula, result not capped). Otherwise the current and previous trials are combined
     * (cubic-model formula) and the result is capped at {@code 0.5 * step}.
     *
     * @param step        step size of the current trial
     * @param step2       step size of the previous trial
     * @param slope       slope along the search direction at the starting point
     * @param scoreDelta  signed score change of the current trial relative to the start
     * @param scoreDelta2 signed score change of the previous trial relative to the start
     * @return the proposed next step size (before the caller's lower bound is applied)
     * @throws IllegalStateException if {@code step == step2} in the two-trial case
     */
    private static double interpolateStep(double step, double step2, double slope, double scoreDelta, double scoreDelta2) {
        if (step == 1.0d) {
            return (-slope) / (2.0d * (scoreDelta - slope));
        }
        double rhs1 = scoreDelta - (step * slope);
        double rhs2 = scoreDelta2 - (step2 * slope);
        if (step == step2) {
            throw new IllegalStateException("FAILURE: dividing by step-step2 which equals 0. step=" + step);
        }
        double stepSq = step * step;
        double step2Sq = step2 * step2;
        double a = ((rhs1 / stepSq) - (rhs2 / step2Sq)) / (step - step2);
        double b = ((((-step2) * rhs1) / stepSq) + ((step * rhs2) / step2Sq)) / (step - step2);

        double tmpStep;
        if (a == 0.0d) {
            tmpStep = (-slope) / (2.0d * b);
        } else {
            double disc = (b * b) - ((3.0d * a) * slope);
            if (disc < 0.0d) {
                tmpStep = 0.5d * step;
            } else if (b <= 0.0d) {
                tmpStep = ((-b) + FastMath.sqrt(disc)) / (3.0d * a);
            } else {
                tmpStep = (-slope) / (b + FastMath.sqrt(disc));
            }
        }
        if (tmpStep > 0.5d * step) {
            tmpStep = 0.5d * step;
        }
        return tmpStep;
    }
}
