package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeGradientStepFunction;
import org.deeplearning4j.optimize.terminations.EpsTermination;
import org.deeplearning4j.optimize.terminations.ZeroDirection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/BaseOptimizer.class */
/**
 * Common base for {@link ConvexOptimizer} implementations.
 *
 * <p>Holds the model being optimized, its configuration, the iteration listeners, the
 * termination conditions and a {@link BackTrackLineSearch}, and provides an
 * {@link #optimize()} loop that repeatedly runs a line search and applies the resulting
 * step through the configured {@link StepFunction}. Subclasses customise the loop through
 * the {@link #preProcessLine()} and {@link #postStep(INDArray)} hooks.
 */
public abstract class BaseOptimizer implements ConvexOptimizer {
    protected static final Logger log = LoggerFactory.getLogger(BaseOptimizer.class);

    /** Key under which the current gradient is kept in {@link #searchState}. */
    public static final String GRADIENT_KEY = "g";
    /** Key under which the score captured by {@link #setupSearchState(Pair)} is kept. */
    public static final String SCORE_KEY = "score";
    /** Key under which the copy of the model parameters is kept in {@link #searchState}. */
    public static final String PARAMS_KEY = "params";
    /** Key from which {@link #optimize()} reads the search direction. */
    public static final String SEARCH_DIR = "searchDirection";

    protected NeuralNetConfiguration conf;
    /** Number of loop iterations completed by {@link #optimize()} so far. */
    protected int iteration = 0;
    protected StepFunction stepFunction;
    protected Collection<IterationListener> iterationListeners;
    protected Collection<TerminationCondition> terminationConditions;
    protected Model model;
    protected BackTrackLineSearch lineMaximizer;
    protected Updater updater;
    protected ComputationGraphUpdater computationGraphUpdater;
    /** Step size returned by the most recent line search (0.0 when the search failed). */
    protected double step;
    private int batchSize;
    /** Score obtained by the most recent {@link #gradientAndScore()} call. */
    protected double score;
    /** Value {@link #score} had before the most recent {@link #gradientAndScore()} call. */
    protected double oldScore;
    /** Maximum step handed to the line search. */
    protected double stepMax = Double.MAX_VALUE;
    /** Mutable state shared between the iterations of {@link #optimize()}. */
    protected Map<String, Object> searchState = new ConcurrentHashMap<String, Object>();

    /**
     * Creates an optimizer that uses a {@link ZeroDirection} and an {@link EpsTermination}
     * as its termination conditions.
     *
     * @param neuralNetConfiguration configuration of the network being optimized
     * @param stepFunction           step function to use, or {@code null} for the default of
     *                               the concrete optimizer class
     * @param collection             iteration listeners, or {@code null} for none
     * @param model                  model to optimize
     */
    public BaseOptimizer(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<IterationListener> collection, Model model) {
        this(neuralNetConfiguration, stepFunction, collection, Arrays.<TerminationCondition>asList(new ZeroDirection(), new EpsTermination()), model);
    }

    /**
     * Creates an optimizer with explicit termination conditions.
     *
     * @param neuralNetConfiguration configuration of the network being optimized; also
     *                               supplies the maximum number of line search iterations
     * @param stepFunction           step function to use, or {@code null} for the default of
     *                               the concrete optimizer class
     * @param collection             iteration listeners, or {@code null} for none
     * @param collection2            termination conditions, or {@code null} for none
     * @param model                  model to optimize
     */
    public BaseOptimizer(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<IterationListener> collection, Collection<TerminationCondition> collection2, Model model) {
        this.conf = neuralNetConfiguration;
        this.stepFunction = stepFunction != null ? stepFunction : getDefaultStepFunctionForOptimizer(getClass());
        this.iterationListeners = collection != null ? collection : new ArrayList<IterationListener>();
        // A null collection would otherwise fail later, when optimize() and
        // checkTerminalConditions() iterate it; treat it as "no conditions".
        this.terminationConditions = collection2 != null ? collection2 : new ArrayList<TerminationCondition>();
        this.model = model;
        this.lineMaximizer = new BackTrackLineSearch(model, this.stepFunction, this);
        this.lineMaximizer.setStepMax(this.stepMax);
        this.lineMaximizer.setMaxIterations(neuralNetConfiguration.getMaxNumLineSearchIterations());
    }

    /**
     * Asks the model to recompute its gradient and score, then returns the model's score.
     *
     * @return the score reported by the model after recomputation
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public double score() {
        this.model.computeGradientAndScore();
        return this.model.score();
    }

    /**
     * Returns the updater, creating it from the model on first use.
     *
     * @return the updater held by this optimizer
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public Updater getUpdater() {
        if (this.updater == null) {
            this.updater = UpdaterCreator.getUpdater(this.model);
        }
        return this.updater;
    }

    /**
     * Replaces the updater held by this optimizer.
     *
     * @param updater the updater to use from now on
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void setUpdater(Updater updater) {
        this.updater = updater;
    }

    /**
     * Returns the computation graph updater, creating it on first use when the model is a
     * {@link ComputationGraph}.
     *
     * @return the graph updater, or {@code null} if none has been set and the model is not
     *         a {@link ComputationGraph}
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public ComputationGraphUpdater getComputationGraphUpdater() {
        if (this.computationGraphUpdater == null && (this.model instanceof ComputationGraph)) {
            this.computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) this.model);
        }
        return this.computationGraphUpdater;
    }

    /**
     * Replaces the computation graph updater held by this optimizer.
     *
     * @param computationGraphUpdater the graph updater to use from now on
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void setUpdaterComputationGraph(ComputationGraphUpdater computationGraphUpdater) {
        this.computationGraphUpdater = computationGraphUpdater;
    }

    /**
     * Replaces the iteration listeners.
     *
     * @param collection the new listeners; {@code null} is stored as an empty list
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void setListeners(Collection<IterationListener> collection) {
        if (collection == null) {
            this.iterationListeners = Collections.emptyList();
        } else {
            this.iterationListeners = collection;
        }
    }

    /**
     * @return the configuration this optimizer was created with
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public NeuralNetConfiguration getConf() {
        return this.conf;
    }

    /**
     * Recomputes gradient and score on the model and passes the gradient through the updater.
     *
     * <p>The previous {@link #score} is moved to {@link #oldScore}, every listener that is a
     * {@link TrainingListener} is notified of the gradient calculation, and the new score is
     * stored in {@link #score}.
     *
     * @return the gradient/score pair obtained from the model
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public Pair<Gradient, Double> gradientAndScore() {
        this.oldScore = this.score;
        this.model.computeGradientAndScore();
        if (this.iterationListeners != null) {
            for (IterationListener listener : this.iterationListeners) {
                if (listener instanceof TrainingListener) {
                    ((TrainingListener) listener).onGradientCalculation(this.model);
                }
            }
        }
        Pair<Gradient, Double> result = this.model.gradientAndScore();
        this.score = result.getSecond().doubleValue();
        updateGradientAccordingToParams(result.getFirst(), this.model, this.model.batchSize());
        return result;
    }

    /**
     * Runs the optimization loop for the number of iterations given by the configuration.
     *
     * <p>Before the loop the termination conditions are evaluated once against the initial
     * gradient (with both scores passed as 0.0); if any of them fires the method returns
     * immediately. Each loop iteration runs a line search, applies the resulting step (if
     * non-zero) to the parameters, recomputes gradient and score, calls
     * {@link #postStep(INDArray)}, notifies the listeners and evaluates the termination
     * conditions.
     *
     * @return always {@code true}
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public boolean optimize() {
        this.model.validateInput();
        Pair<Gradient, Double> initial = gradientAndScore();

        // The gradient is stored on every call; the remaining search state is only
        // initialised the first time, while the map is still empty.
        boolean firstCall = this.searchState.isEmpty();
        this.searchState.put(GRADIENT_KEY, initial.getFirst().gradient());
        if (firstCall) {
            setupSearchState(initial);
        }

        for (TerminationCondition condition : this.terminationConditions) {
            if (condition.terminate(0.0d, 0.0d, new Object[]{initial.getFirst().gradient()})) {
                log.info("Hit termination condition {}", condition.getClass().getName());
                return true;
            }
        }

        preProcessLine();

        int numIterations = this.conf.getNumIterations();
        for (int i = 0; i < numIterations; i++) {
            INDArray gradient = (INDArray) this.searchState.get(GRADIENT_KEY);
            INDArray searchDirection = (INDArray) this.searchState.get(SEARCH_DIR);
            INDArray parameters = (INDArray) this.searchState.get(PARAMS_KEY);

            try {
                this.step = this.lineMaximizer.optimize(parameters, gradient, searchDirection);
            } catch (InvalidStepException e) {
                // A failed line search is not fatal: treat it as a zero step and carry on.
                log.warn("Invalid step...continuing another iteration: {}", e.getMessage());
                this.step = 0.0d;
            }

            if (this.step != 0.0d) {
                this.stepFunction.step(parameters, searchDirection, this.step);
                this.model.setParams(parameters);
            } else {
                log.debug("Step size returned by line search is 0.0.");
            }

            Pair<Gradient, Double> current = gradientAndScore();
            postStep(current.getFirst().gradient());
            notifyIterationDone(i);

            // NOTE(review): the boolean result is not used to leave the loop; only the side
            // effects of checkTerminalConditions apply here. Confirm this is intended.
            checkTerminalConditions(current.getFirst().gradient(), this.oldScore, this.score, i);
            this.iteration++;
        }
        return true;
    }

    /** Notifies every registered listener that iteration {@code i} has finished. */
    private void notifyIterationDone(int i) {
        if (this.iterationListeners == null) {
            return;
        }
        for (IterationListener listener : this.iterationListeners) {
            listener.iterationDone(this.model, i);
        }
    }

    /**
     * Hook for subclasses; the base implementation does nothing.
     *
     * @param iNDArray array handed to the hook by the caller
     */
    protected void postFirstStep(INDArray iNDArray) {
    }

    /**
     * Evaluates the termination conditions in order and stops at the first one that fires.
     *
     * <p>When the firing condition is an {@link EpsTermination}, the configuration has a
     * layer and its learning rate policy is {@link LearningRatePolicy#Score}, the model's
     * learning rate score decay is applied before returning.
     *
     * @param iNDArray gradient passed to each condition
     * @param d        previous score
     * @param d2       current score
     * @param i        iteration index, used for logging only
     * @return {@code true} if any condition fired, {@code false} otherwise
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public boolean checkTerminalConditions(INDArray iNDArray, double d, double d2, int i) {
        for (TerminationCondition condition : this.terminationConditions) {
            if (!condition.terminate(d2, d, new Object[]{iNDArray})) {
                continue;
            }
            log.debug("Hit termination condition on iteration {}: score={}, oldScore={}, condition={}", i, d2, d, condition);
            boolean decayLearningRate = condition instanceof EpsTermination
                    && this.conf.getLayer() != null
                    && this.conf.getLearningRatePolicy() == LearningRatePolicy.Score;
            if (decayLearningRate) {
                this.model.applyLearningRateScoreDecay();
            }
            return true;
        }
        return false;
    }

    /**
     * @return the batch size last set through {@link #setBatchSize(int)}
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public int batchSize() {
        return this.batchSize;
    }

    /**
     * @param i the batch size to record
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void setBatchSize(int i) {
        this.batchSize = i;
    }

    /**
     * Hook called by {@link #optimize()} once before its loop; the base implementation
     * does nothing.
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void preProcessLine() {
    }

    /**
     * Hook called by {@link #optimize()} after each step with the freshly computed gradient;
     * the base implementation does nothing.
     *
     * @param iNDArray gradient computed after the step
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void postStep(INDArray iNDArray) {
    }

    /**
     * Passes the gradient through the updater matching the model type, creating that
     * updater on first use.
     *
     * <p>A {@link ComputationGraph} is handled by the {@link ComputationGraphUpdater}; any
     * other model is cast to {@link Layer} and handled by the {@link Updater}.
     *
     * @param gradient gradient to update
     * @param model    model the gradient belongs to
     * @param i        batch size forwarded to the updater
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void updateGradientAccordingToParams(Gradient gradient, Model model, int i) {
        if (model instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) model;
            if (this.computationGraphUpdater == null) {
                this.computationGraphUpdater = new ComputationGraphUpdater(graph);
            }
            this.computationGraphUpdater.update(graph, gradient, this.iteration, i);
            return;
        }
        if (this.updater == null) {
            this.updater = UpdaterCreator.getUpdater(model);
        }
        this.updater.update((Layer) model, gradient, this.iteration, i);
    }

    /**
     * Initialises {@link #searchState} with the gradient (requested for the configuration's
     * variables), the score and a copy of the model's current parameters.
     *
     * @param pair gradient/score pair to take the gradient and score from
     */
    @Override // org.deeplearning4j.optimize.api.ConvexOptimizer
    public void setupSearchState(Pair<Gradient, Double> pair) {
        INDArray gradient = pair.getFirst().gradient(this.conf.variables());
        INDArray paramsCopy = this.model.params().dup();
        this.searchState.put(GRADIENT_KEY, gradient);
        this.searchState.put(SCORE_KEY, pair.getSecond());
        this.searchState.put(PARAMS_KEY, paramsCopy);
    }

    /**
     * Picks the default step function for an optimizer class.
     *
     * @param cls concrete optimizer class
     * @return a {@link NegativeGradientStepFunction} when {@code cls} is exactly
     *         {@code StochasticGradientDescent}, otherwise a
     *         {@link NegativeDefaultStepFunction}
     */
    public static StepFunction getDefaultStepFunctionForOptimizer(Class<? extends ConvexOptimizer> cls) {
        if (cls == StochasticGradientDescent.class) {
            return new NegativeGradientStepFunction();
        }
        return new NegativeDefaultStepFunction();
    }
}
