package org.deeplearning4j.nn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.clustering.kdtree.KDTree;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.LossFunction;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossCalculation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;

/**
 * Base implementation of an output layer: a layer that holds labels, computes a loss-based score
 * and produces the first backpropagation delta from {@code activation(preOutput) - labels}.
 *
 * <p>NOTE(review): this file also contains a single-type import of
 * {@code org.deeplearning4j.nn.conf.layers.BaseOutputLayer}, whose simple name clashes with this
 * class's own name (JLS 7.5.1 makes that a compile-time error). The code below always refers to
 * the configuration class by its fully qualified name, so that import line can simply be removed.
 *
 * @param <LayerConfT> the layer configuration type
 */
public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseOutputLayer> extends BaseLayer<LayerConfT> implements Serializable, Classifier {
    /** Labels for the current input; set via {@link #setLabels(INDArray)} or the fit methods. */
    protected INDArray labels;
    /** Created lazily by {@link #fit(INDArray, INDArray)}; reset by {@link #clear()}. Not serialized. */
    private transient Solver solver;
    /** Network-wide L1 term recorded by {@link #computeScore} and reused by {@link #setScore}. */
    private double fullNetworkL1;
    /** Network-wide L2 term recorded by {@link #computeScore} and reused by {@link #setScore}. */
    private double fullNetworkL2;

    public BaseOutputLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    public BaseOutputLayer(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Computes the score of this layer for the current input and labels and stores it in
     * {@code this.score}.
     *
     * @param fullNetworkL1 L1 regularization term for the whole network
     * @param fullNetworkL2 L2 regularization term for the whole network
     * @param training flag forwarded to {@link #preOutput2d(boolean)}
     * @return the computed score
     * @throws IllegalStateException if input or labels have not been set
     */
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels");
        }
        this.fullNetworkL1 = fullNetworkL1;
        this.fullNetworkL2 = fullNetworkL2;
        INDArray preOut = preOutput2d(training);
        LossFunctions.LossFunction lossFunction =
                ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer) this.conf.getLayer()).getLossFunction();
        boolean logLossOnSoftmax = (lossFunction == LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD
                || lossFunction == LossFunctions.LossFunction.MCXENT)
                && ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer) layerConf()).getActivationFunction().equals("softmax");
        if (logLossOnSoftmax) {
            // Only the pre-activations are handed to the loss calculation (z is null);
            // presumably so it can evaluate log-softmax directly — confirm in LossCalculation.
            setScore(null, preOut);
        } else {
            INDArray activated = Nd4j.getExecutioner().execAndReturn(
                    Nd4j.getOpFactory().createTransform(conf().getLayer().getActivationFunction(), preOut));
            setScoreWithZ(activated);
        }
        return this.score;
    }

    /**
     * Computes a per-example score (one value per row of the current input).
     *
     * @param fullNetworkL1 L1 regularization term for the whole network
     * @param fullNetworkL2 L2 regularization term for the whole network
     * @return the per-example scores as produced by {@code LossCalculation.scoreExamples()}
     * @throws IllegalStateException if input or labels have not been set
     */
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels");
        }
        INDArray preOut = preOutput2d(false);
        String activationFn = conf().getLayer().getActivationFunction();
        // dup() so the activation transform does not overwrite preOut, which is passed separately.
        INDArray activated = Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFn, preOut.dup()));
        return LossCalculation.builder()
                .l1(fullNetworkL1)
                .l2(fullNetworkL2)
                .labels(getLabels2d())
                .z(activated)
                .preOut(preOut)
                .activationFn(activationFn)
                .lossFunction(((org.deeplearning4j.nn.conf.layers.BaseOutputLayer) layerConf()).getLossFunction())
                .useRegularization(this.conf.isUseRegularization())
                .mask(this.maskArray)
                .build()
                .scoreExamples();
    }

    /**
     * Computes gradients and score for the current input and labels, storing the gradient in
     * {@code this.gradient}. Silently does nothing if input or labels are missing.
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void computeGradientAndScore() {
        if (this.input == null || this.labels == null) {
            return;
        }
        INDArray preOut = preOutput2d(true);
        Triple<Gradient, INDArray, INDArray> gradientsAndDelta = getGradientsAndDelta(preOut);
        this.gradient = gradientsAndDelta.getFirst();
        setScore(gradientsAndDelta.getThird(), preOut);
    }

    /** Sets the score from already-activated output {@code z} (no pre-output supplied). */
    @Override // org.deeplearning4j.nn.layers.BaseLayer
    protected void setScoreWithZ(INDArray z) {
        setScore(z, null);
    }

    /**
     * Computes {@code this.score} using the configured loss function.
     *
     * @param z activated output (may be null when only {@code preOut} is supplied)
     * @param preOut pre-activation output (may be null)
     */
    private void setScore(INDArray z, INDArray preOut) {
        org.deeplearning4j.nn.conf.layers.BaseOutputLayer outConf =
                (org.deeplearning4j.nn.conf.layers.BaseOutputLayer) layerConf();
        if (outConf.getLossFunction() != LossFunctions.LossFunction.CUSTOM) {
            this.score = LossCalculation.builder()
                    .l1(this.fullNetworkL1)
                    .l2(this.fullNetworkL2)
                    .labels(getLabels2d())
                    .z(z)
                    .preOut(preOut)
                    .activationFn(conf().getLayer().getActivationFunction())
                    .lossFunction(outConf.getLossFunction())
                    .miniBatch(this.conf.isMiniBatch())
                    .miniBatchSize(getInputMiniBatchSize())
                    .useRegularization(this.conf.isUseRegularization())
                    .mask(this.maskArray)
                    .build()
                    .score();
            return;
        }
        // NOTE(review): the custom loss op receives the layer *input* (not the labels) alongside z;
        // verify this pairing is intended.
        LossFunction customLoss = Nd4j.getOpFactory().createLossFunction(outConf.getCustomLossFunction(), this.input, z);
        customLoss.exec();
        this.score = customLoss.getFinalResult().doubleValue();
    }

    /** Returns the current gradient paired with the current score. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), Double.valueOf(score()));
    }

    /**
     * Computes this layer's gradient from the labels and returns it with the error signal for
     * the previous layer ({@code delta * W^T}). The incoming {@code epsilon} is not used: as the
     * output layer, the delta is derived from the labels instead.
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        Triple<Gradient, INDArray, INDArray> gradientsAndDelta = getGradientsAndDelta(preOutput2d(true));
        INDArray delta = gradientsAndDelta.getSecond();
        INDArray epsilonNext = this.params.get("W").mmul(delta.transpose()).transpose();
        return new Pair<>(gradientsAndDelta.getFirst(), epsilonNext);
    }

    /**
     * Returns the last computed gradient, after asserting that input and 2d labels have the
     * same number of rows.
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        LinAlgExceptions.assertRows(this.input, getLabels2d());
        return this.gradient;
    }

    /**
     * Computes the weight/bias gradients (written into the "W"/"b" gradient views), the delta and
     * the activated output for the given pre-output.
     *
     * @param preOut pre-activation output of this layer
     * @return (gradient, delta, activated output)
     * @throws IllegalStateException if the configured loss function is not supported here
     */
    private Triple<Gradient, INDArray, INDArray> getGradientsAndDelta(INDArray preOut) {
        INDArray output = Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(conf().getLayer().getActivationFunction(), preOut.dup()));
        INDArray outputMinusLabels = output.sub(getLabels2d());

        DefaultGradient gradient = new DefaultGradient();
        INDArray weightGrad = this.gradientViews.get("W");
        INDArray biasGrad = this.gradientViews.get("b");
        gradient.gradientForVariable().put("W", weightGrad);
        gradient.gradientForVariable().put("b", biasGrad);

        if (this.maskArray != null) {
            // Zero the error for masked-out rows (in place).
            outputMinusLabels.muliColumnVector(this.maskArray);
        }

        LossFunctions.LossFunction lossFunction =
                ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer) layerConf()).getLossFunction();
        // delta is returned to the caller and summed for the bias gradient; weightDelta is the
        // matrix multiplied with the input to form the weight gradient.
        INDArray delta = outputMinusLabels;
        INDArray weightDelta;
        switch (lossFunction) {
            case NEGATIVELOGLIKELIHOOD:
            case MCXENT:
                weightDelta = outputMinusLabels;
                break;
            case XENT:
                weightDelta = outputMinusLabels.div(output.mul(output.rsub(1)));
                break;
            case MSE:
                delta = outputMinusLabels.mul(derivativeActivation(preOut));
                weightDelta = delta;
                break;
            case EXPLL:
                // Use the 2d labels like every other branch so shapes match the 2d output.
                weightDelta = getLabels2d().rsub(1).divi(output);
                break;
            case RMSE_XENT:
                weightDelta = Transforms.sqrt(Transforms.pow(outputMinusLabels, Double.valueOf(2.0d)));
                break;
            case SQUARED_LOSS:
                weightDelta = outputMinusLabels.mul(outputMinusLabels);
                break;
            default:
                throw new IllegalStateException("Invalid loss function: " + lossFunction);
        }
        // weightGrad = input^T * weightDelta, written into the gradient view.
        Nd4j.gemm(this.input, weightDelta, weightGrad, true, false, 1.0d, 0.0d);
        biasGrad.assign(delta.sum(new int[]{0}));
        return new Triple<>(gradient, delta, output);
    }

    /** Sets the input and returns {@link #output(boolean)} with the given training flag. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray input, boolean training) {
        setInput(input);
        return output(training);
    }

    /** Sets the input and returns {@code output(true)}. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray input) {
        setInput(input);
        return output(true);
    }

    /** Returns {@code output(false)} for the current input. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        return output(false);
    }

    /** Sets the input and returns {@link #output(boolean)} with the given training flag. */
    public INDArray output(INDArray input, boolean training) {
        setInput(input);
        return output(training);
    }

    /** Sets the input and returns {@code output(false)}. */
    public INDArray output(INDArray input) {
        setInput(input);
        return output(false);
    }

    /**
     * Returns the activations of this layer for the current input.
     *
     * @throws IllegalArgumentException if no input has been set
     */
    public INDArray output(boolean training) {
        if (this.input == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        return super.activate(training);
    }

    /** F1 score of this layer's predictions on the given data set. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public double f1Score(DataSet data) {
        return f1Score(data.getFeatureMatrix(), data.getLabels());
    }

    /** F1 score of this layer's predictions for {@code examples} against {@code exampleLabels}. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public double f1Score(INDArray examples, INDArray exampleLabels) {
        Evaluation evaluation = new Evaluation();
        evaluation.eval(exampleLabels, labelProbabilities(examples));
        return evaluation.f1();
    }

    /**
     * Returns the number of label columns ({@code labels.size(1)}).
     *
     * @throws IllegalStateException if no labels have been set yet
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public int numLabels() {
        if (this.labels == null) {
            throw new IllegalStateException("Cannot determine number of labels: labels have not been set");
        }
        return this.labels.size(1);
    }

    /** Fits on every data set produced by the iterator, in order. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(DataSetIterator iterator) {
        while (iterator.hasNext()) {
            fit((DataSet) iterator.next());
        }
    }

    /** Returns, for each row of {@code examples}, the index of the largest output activation. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public int[] predict(INDArray examples) {
        INDArray output = output(examples);
        int[] predictions = new int[examples.rows()];
        for (int row = 0; row < predictions.length; row++) {
            predictions[row] = Nd4j.getBlasWrapper().iamax(output.getRow(row));
        }
        return predictions;
    }

    /**
     * Returns the predicted label name for each example, in example order.
     *
     * <p>The predicted class index is used to look up the name. It used to be used as the list
     * insertion position as well, which threw {@code IndexOutOfBoundsException} and scrambled
     * the order.
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public List<String> predict(DataSet dataSet) {
        int[] predictions = predict(dataSet.getFeatureMatrix());
        List<String> labelNames = new ArrayList<>(predictions.length);
        for (int prediction : predictions) {
            labelNames.add(dataSet.getLabelName(prediction));
        }
        return labelNames;
    }

    /** Returns the layer output for {@code examples} (no training mode). */
    @Override // org.deeplearning4j.nn.api.Classifier
    public INDArray labelProbabilities(INDArray examples) {
        return output(examples);
    }

    /**
     * Sets input and labels, applies dropout if configured, and runs the optimizer. On first
     * use, this creates the solver and allocates updater state for this layer.
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray examples, INDArray exampleLabels) {
        setInput(examples);
        setLabels(exampleLabels);
        applyDropOutIfNecessary(true);
        if (this.solver == null) {
            this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            Updater updater = this.solver.getOptimizer().getUpdater();
            int stateSize = updater.stateSizeForLayer(this);
            if (stateSize > 0) {
                updater.setStateViewArray(this, Nd4j.createUninitialized(new int[]{1, stateSize}, Nd4j.order().charValue()), true);
            }
        }
        this.solver.optimize();
    }

    /** Fits on the data set's feature matrix and labels. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(DataSet data) {
        fit(data.getFeatureMatrix(), data.getLabels());
    }

    /**
     * Fits using integer class labels, which are converted to a one-hot outcome matrix of width
     * {@link #numLabels()}. Labels must therefore already have been set once.
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray examples, int[] classIndices) {
        fit(examples, FeatureUtil.toOutcomeMatrix(classIndices, numLabels()));
    }

    /** Clears base-layer state, destroys the label buffer and drops the solver. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void clear() {
        super.clear();
        if (this.labels != null) {
            // NOTE(review): this destroys the buffer of the array passed to setLabels, which the
            // caller may still hold; confirm that is intended.
            this.labels.data().destroy();
            this.labels = null;
        }
        this.solver = null;
    }

    /** Unsupervised fit is a no-op for output layers. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void fit(INDArray data) {
    }

    /** Not supported by output layers. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public void iterate(INDArray data) {
        throw new UnsupportedOperationException();
    }

    public INDArray getLabels() {
        return this.labels;
    }

    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /** Hook for subclasses: returns the 2d pre-output (here, {@code preOutput(training)}). */
    protected INDArray preOutput2d(boolean training) {
        return preOutput(training);
    }

    /** Hook for subclasses: returns the 2d output (here, {@code output(input)}). */
    protected INDArray output2d(INDArray input) {
        return output(input);
    }

    /**
     * Returns the labels as a 2d array. Rank-2 labels are returned as-is.
     *
     * <p>NOTE(review): for rank &gt; 2 labels this reshapes to {@code [size(2), size(1)]}. That
     * discards dimension 0 and only works when {@code size(0) == 1}. Time-series subclasses
     * presumably override this; confirm.
     */
    protected INDArray getLabels2d() {
        return this.labels.rank() > 2 ? this.labels.reshape(this.labels.size(2), this.labels.size(1)) : this.labels;
    }
}
