package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.class */
/**
 * Early-stopping score calculator that evaluates a network's loss on the data supplied by a
 * {@link DataSetIterator} or {@link MultiDataSetIterator}.
 *
 * <p>For each minibatch the network's score is multiplied by the minibatch size (the first
 * dimension of the first feature array) before being accumulated. This presumably turns a
 * per-example average score into a minibatch total, so the running sum is a total over all
 * examples (TODO confirm the network's {@code score(...)} is averaged over the minibatch).
 * When {@code average} is {@code true}, the final score is that sum divided by the example count.
 * Otherwise the sum itself is returned. Lower scores are better ({@link #minimizeScore()}).
 *
 * <p>Supported model types: {@link MultiLayerNetwork} and {@link ComputationGraph}; any other
 * model (including {@code null}) results in a {@link RuntimeException}.
 */
public class DataSetLossCalculator extends BaseScoreCalculator<Model> {

    /** If true, the final score is the accumulated loss divided by the number of examples. */
    @JsonProperty
    private boolean average;

    /**
     * Creates a calculator over single-input/single-output data.
     *
     * @param dataSetIterator iterator over the evaluation data
     * @param average         true to report the per-example average loss, false for the summed loss
     */
    public DataSetLossCalculator(DataSetIterator dataSetIterator, boolean average) {
        super(dataSetIterator);
        this.average = average;
    }

    /**
     * Creates a calculator over multi-input/multi-output data.
     *
     * @param multiDataSetIterator iterator over the evaluation data
     * @param average              true to report the per-example average loss, false for the summed loss
     */
    public DataSetLossCalculator(MultiDataSetIterator multiDataSetIterator, boolean average) {
        super(multiDataSetIterator);
        this.average = average;
    }

    @Override
    public String toString() {
        return "DataSetLossCalculator(average=" + this.average + ")";
    }

    /** Clears the accumulated score sum and the minibatch/example counters before a new pass. */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected void reset() {
        this.scoreSum = 0.0d;
        this.minibatchCount = 0;
        this.exampleCount = 0;
    }

    /**
     * Single-array convenience form: wraps each argument with {@code arr(...)} and returns the first
     * output of the array-based {@link #output(Model, INDArray[], INDArray[], INDArray[])}.
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected INDArray output(Model network, INDArray features, INDArray featuresMask, INDArray labelsMask) {
        return output(network, arr(features), arr(featuresMask), arr(labelsMask))[0];
    }

    /**
     * Runs inference (train = false) on the given network.
     *
     * <p>For a {@link MultiLayerNetwork} only the first feature array is used. The masks are passed
     * through {@code get0(...)}.
     *
     * @throws RuntimeException if {@code network} is neither a MultiLayerNetwork nor a ComputationGraph
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected INDArray[] output(Model network, INDArray[] features, INDArray[] featuresMasks, INDArray[] labelsMasks) {
        if (network instanceof MultiLayerNetwork) {
            MultiLayerNetwork mln = (MultiLayerNetwork) network;
            return new INDArray[]{mln.output(features[0], false, get0(featuresMasks), get0(labelsMasks))};
        }
        if (network instanceof ComputationGraph) {
            return ((ComputationGraph) network).output(false, features, featuresMasks, labelsMasks);
        }
        throw unknownModelType(network);
    }

    /**
     * Scores one minibatch in non-training mode and scales the result by the minibatch size.
     *
     * <p>The minibatch size is {@code features[0].size(0)}. The {@code output} argument is not used.
     *
     * @return the network's score for this minibatch multiplied by the number of examples in it
     * @throws RuntimeException if {@code network} is neither a MultiLayerNetwork nor a ComputationGraph
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] featuresMasks,
                                    INDArray[] labelsMasks, INDArray[] output) {
        if (network instanceof MultiLayerNetwork) {
            DataSet ds = new DataSet(get0(features), get0(labels), get0(featuresMasks), get0(labelsMasks));
            return ((MultiLayerNetwork) network).score(ds, false) * features[0].size(0);
        }
        if (network instanceof ComputationGraph) {
            MultiDataSet mds = new MultiDataSet(features, labels, featuresMasks, labelsMasks);
            // Explicit training=false, matching the MultiLayerNetwork branch above.
            return ((ComputationGraph) network).score(mds, false) * features[0].size(0);
        }
        throw unknownModelType(network);
    }

    /**
     * Produces the final score from the accumulated sum.
     *
     * <p>NOTE(review): with {@code average == true} and zero examples this evaluates
     * {@code 0.0 / 0} and therefore returns NaN.
     *
     * @param scoreSum       accumulated (example-weighted) score
     * @param minibatchCount number of minibatches seen (unused here)
     * @param exampleCount   number of examples seen; divisor when averaging
     * @return {@code scoreSum / exampleCount} if averaging, otherwise {@code scoreSum}
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
        return this.average ? scoreSum / exampleCount : scoreSum;
    }

    /** Loss is better when lower, so early stopping should minimize this score. */
    @Override // org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator
    public boolean minimizeScore() {
        return true;
    }

    /**
     * Builds the exception reported for unsupported model types.
     *
     * <p>This is null-safe: a {@code null} model yields a descriptive message rather than a
     * NullPointerException from {@code getClass()}.
     */
    private static RuntimeException unknownModelType(Model network) {
        return new RuntimeException("Unknown model type: " + (network == null ? "null" : network.getClass()));
    }
}
