package org.deeplearning4j.nn.layers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/BaseOutputLayer.class */
/**
 * Base implementation for output layers: layers that hold the labels and compute the score
 * (loss), the loss gradient and the parameter gradients.
 *
 * <p>The loss function, activation function and bias flag are read from the layer configuration
 * ({@code org.deeplearning4j.nn.conf.layers.BaseOutputLayer}). Subclasses supply the labels in 2d
 * form via {@link #getLabels2d(LayerWorkspaceMgr, ArrayType)} and may override
 * {@link #preOutput2d(boolean, LayerWorkspaceMgr)}; the default simply returns {@code preOutput}.
 *
 * @param <LayerConfT> the output-layer configuration type
 */
public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseOutputLayer>
        extends BaseLayer<LayerConfT> implements Serializable, IOutputLayer {

    /** Labels for the current minibatch; {@code null} until {@link #setLabels(INDArray)} is called. */
    protected INDArray labels;
    private transient Solver solver;
    /** Regularization term passed to the most recent {@link #computeScore} call; reused by computeGradientAndScore. */
    private double fullNetRegTerm;
    protected INDArray inputMaskArray;
    protected MaskState inputMaskArrayState;

    public BaseOutputLayer(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
    }

    /**
     * Computes the score of this layer for the current input and labels, plus a regularization term.
     *
     * <p>The loss is computed un-averaged; when the configuration is in minibatch mode it is then
     * divided by the input minibatch size. The result is also stored as this layer's score.
     *
     * @param fullNetRegTerm regularization term (for the whole network) added to the loss
     * @param training whether the pre-activations are computed in training mode
     * @param workspaceMgr workspace manager for intermediate arrays
     * @return the (possibly minibatch-averaged) loss plus {@code fullNetRegTerm}
     * @throws IllegalStateException if the input or the labels have not been set
     */
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        this.fullNetRegTerm = fullNetRegTerm;
        LayerConfT outputConf = layerConf();
        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray preOut = preOutput2d(training, workspaceMgr);
        double lossScore = outputConf.getLossFn()
                .computeScore(labels2d, preOut, outputConf.getActivationFn(), this.maskArray, false);
        if (conf().isMiniBatch()) {
            lossScore /= getInputMiniBatchSize();
        }
        this.score = lossScore + fullNetRegTerm;
        return this.score;
    }

    /** Output layers always require labels. */
    @Override
    public boolean needsLabels() {
        return true;
    }

    /**
     * Computes the per-example score array (loss for each example, not averaged).
     *
     * @param fullNetRegTerm regularization term added to every example's score (skipped when 0)
     * @param workspaceMgr workspace manager; the result is leveraged to the ACTIVATIONS workspace
     * @return per-example scores
     * @throws IllegalStateException if the input or the labels have not been set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        LayerConfT outputConf = layerConf();
        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
        INDArray preOut = preOutput2d(false, workspaceMgr);
        INDArray scoreArray = outputConf.getLossFn()
                .computeScoreArray(labels2d, preOut, outputConf.getActivationFn(), this.maskArray);
        if (fullNetRegTerm != 0.0) {
            scoreArray.addi(fullNetRegTerm);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreArray);
    }

    /**
     * Computes the parameter gradients and the score, storing both on this layer.
     * Does nothing when the input or the labels have not been set.
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            return;
        }
        this.gradient = getGradientsAndDelta(preOutput2d(true, workspaceMgr), workspaceMgr).getFirst();
        // Reuses the regularization term from the most recent computeScore call.
        this.score = computeScore(this.fullNetRegTerm, true, workspaceMgr);
    }

    /** Not supported for output layers; the score is computed from the labels instead. */
    @Override
    protected void setScoreWithZ(INDArray z) {
        throw new RuntimeException("Not supported - " + layerId());
    }

    /** Returns the currently stored gradient together with the currently stored score. */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(gradient(), score());
    }

    /**
     * Backpropagates through this output layer.
     *
     * <p>The incoming {@code epsilon} is not used. The error signal (delta) is computed directly
     * from the loss function and the labels.
     *
     * @param epsilon ignored
     * @param workspaceMgr workspace manager for gradient arrays
     * @return the parameter gradients and the epsilon for the layer below, with dropout backprop
     *     applied if present
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        Pair<Gradient, INDArray> gradientAndDelta =
                getGradientsAndDelta(preOutput2d(true, workspaceMgr), workspaceMgr);
        INDArray delta = gradientAndDelta.getSecond();
        INDArray weights = getParamWithNoise("W", true, workspaceMgr);
        // epsilonNext = (W * delta^T)^T. The product is written into an 'f'-ordered buffer of
        // shape [W.size(0), delta.size(0)] and then transposed.
        INDArray productBuffer = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.dataType(),
                new long[] {weights.size(0), delta.size(0)}, 'f');
        INDArray epsilonNext = weights.mmuli(delta.transpose(), productBuffer).transpose();
        return new Pair<>(gradientAndDelta.getFirst(), backpropDropOutIfPresent(epsilonNext));
    }

    /** Returns the gradient stored by the most recent gradient computation. */
    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    /**
     * Computes the parameter gradients and the loss gradient with respect to the pre-activations.
     *
     * @param preOut 2d pre-activation output of this layer
     * @param workspaceMgr workspace manager for intermediate arrays
     * @return the parameter gradients ("W", plus "b" when the layer has a bias) and the delta
     *     (loss gradient w.r.t. {@code preOut}) leveraged to the ACTIVATION_GRAD workspace
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        LayerConfT outputConf = layerConf();
        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM);
        INDArray delta = outputConf.getLossFn()
                .computeGradient(labels2d, preOut, outputConf.getActivationFn(), this.maskArray);

        DefaultGradient paramGradients = new DefaultGradient();
        INDArray weightGradView = this.gradientViews.get("W");
        // weightGrad = input^T * delta, written in place into the gradient view (beta = 0 overwrites it).
        Nd4j.gemm(this.input.castTo(weightGradView.dataType()), delta, weightGradView, true, false, 1.0, 0.0);
        paramGradients.gradientForVariable().put("W", weightGradView);

        if (hasBias()) {
            INDArray biasGradView = this.gradientViews.get("b");
            // Sum delta over dimension 0, writing the result into the bias gradient view.
            delta.sum(biasGradView, 0);
            paramGradients.gradientForVariable().put("b", biasGradView);
        }
        return new Pair<>(paramGradients, workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta));
    }

    /** Sets {@code input} as this layer's input and returns the activations. */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        return activate(training, workspaceMgr);
    }

    /** Returns the F1 score of this layer's predictions on the given data set. */
    @Override
    public double f1Score(DataSet dataSet) {
        return f1Score(dataSet.getFeatures(), dataSet.getLabels());
    }

    /**
     * Returns the F1 score of this layer's predictions for {@code examples} against {@code labels}.
     * Note that this sets {@code examples} as the layer input.
     */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        Evaluation evaluation = new Evaluation();
        evaluation.eval(labels, activate(examples, false, LayerWorkspaceMgr.noWorkspacesImmutable()));
        return evaluation.f1();
    }

    /**
     * Returns the number of label columns ({@code labels.size(1)}).
     *
     * @throws IllegalStateException if the labels have not been set
     */
    @Override
    public int numLabels() {
        if (this.labels == null) {
            throw new IllegalStateException("Cannot determine number of labels: labels have not been set " + layerId());
        }
        return (int) this.labels.size(1);
    }

    /**
     * Fits on every data set produced by the iterator by delegating to {@link #fit(DataSet)}.
     * In this class {@code fit(DataSet)} throws {@link UnsupportedOperationException}, so this
     * throws on the first element unless a subclass overrides it.
     */
    @Override
    public void fit(DataSetIterator iterator) {
        while (iterator.hasNext()) {
            fit(iterator.next());
        }
    }

    /**
     * Predicts a class index for each row of {@code features}.
     *
     * @param features input examples, one per row
     * @return for each row, the index selected by BLAS {@code iamax} over the output row
     */
    @Override
    public int[] predict(INDArray features) {
        INDArray output = activate(features, false, LayerWorkspaceMgr.noWorkspacesImmutable());
        int[] predictions = new int[features.rows()];
        for (int row = 0; row < predictions.length; row++) {
            // NOTE(review): iamax picks the largest *absolute* value. This equals argmax only when
            // the outputs are non-negative (e.g. softmax).
            predictions[row] = Nd4j.getBlasWrapper().iamax(output.getRow(row));
        }
        return predictions;
    }

    /**
     * Predicts a label name for each example in the data set, in example order.
     *
     * @param dataSet data set whose features are classified and whose label names are used
     * @return the predicted label name for each example
     */
    @Override
    public List<String> predict(DataSet dataSet) {
        int[] predictions = predict(dataSet.getFeatures());
        List<String> predictedLabels = new ArrayList<>(predictions.length);
        for (int classIdx : predictions) {
            // Append in example order. The class index selects the label name; it is not a list
            // position (using it as one threw IndexOutOfBoundsException / misordered results).
            predictedLabels.add(dataSet.getLabelName(classIdx));
        }
        return predictedLabels;
    }

    /** Not supported by output layers. */
    @Override
    public void fit(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Not supported by output layers. */
    @Override
    public void fit(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Not supported by output layers. */
    @Override
    public void fit(INDArray examples, int[] labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Clears the base-layer state plus the labels, solver, input mask and stored regularization term. */
    @Override
    public void clear() {
        super.clear();
        this.labels = null;
        this.solver = null;
        this.inputMaskArrayState = null;
        this.inputMaskArray = null;
        this.fullNetRegTerm = 0.0;
    }

    /** Not supported by output layers. */
    @Override
    public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Returns the labels currently set on this layer, or {@code null} if none. */
    @Override
    public INDArray getLabels() {
        return this.labels;
    }

    /** Sets the labels used for score and gradient computation. */
    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    /**
     * Returns the pre-activation output in 2d form. The default returns {@code preOutput}
     * unchanged; subclasses may override it.
     */
    public INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return preOutput(training, workspaceMgr);
    }

    /**
     * Multiplies {@code to} in place by the mask array.
     *
     * <p>A column-vector (or scalar) mask is applied per example. Any other mask must have exactly
     * the same shape as {@code to}. Assumes {@code maskArray} is non-null (TODO confirm callers
     * guarantee this).
     *
     * @throws IllegalStateException if a non-column-vector mask's shape differs from {@code to}'s
     */
    @Override
    public void applyMask(INDArray to) {
        if (this.maskArray.isColumnVectorOrScalar()) {
            to.muliColumnVector(this.maskArray.castTo(to.dataType()));
        } else {
            if (!Arrays.equals(to.shape(), this.maskArray.shape())) {
                throw new IllegalStateException("Invalid mask array: per-example masking should be a column vector, per output masking arrays should be the same shape as the output/labels arrays. Mask shape: " + Arrays.toString(this.maskArray.shape()) + ", output shape: " + Arrays.toString(to.shape()) + layerId());
            }
            to.muli(this.maskArray.castTo(to.dataType()));
        }
    }

    /**
     * Returns the labels in 2d form.
     *
     * @param workspaceMgr workspace manager
     * @param arrayType workspace array type to use for any array the implementation creates
     */
    protected abstract INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType);

    /** Output layers are never pretrain layers. */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Returns whether the layer configuration has a bias enabled. */
    @Override
    public boolean hasBias() {
        return layerConf().hasBias();
    }
}
