package org.deeplearning4j.earlystopping.trainer;

import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Early-stopping trainer for {@link ComputationGraph} models.
 *
 * <p>Training data may be supplied either as a {@link DataSetIterator} (only allowed when the
 * graph has exactly one input array and one output array) or as a {@link MultiDataSetIterator}.
 * Each minibatch is dispatched to {@code ComputationGraph.fit(...)} when the graph configuration
 * enables backprop, or to {@code ComputationGraph.pretrain(...)} (wrapping the single minibatch in
 * a singleton iterator) when only pretraining is enabled.
 */
public class EarlyStoppingGraphTrainer extends BaseEarlyStoppingTrainer<ComputationGraph> {

    /** Error raised when the configuration enables neither backprop nor pretraining. */
    private static final String NO_TRAINING_MODE_MESSAGE =
            "Cannot train - network configuration has both isBackprop == false and isPretrain == false";

    /**
     * The graph being trained (the same instance passed to the base trainer). Deliberately not
     * named {@code net}: a variable with that name would obscure a top-level {@code net} package
     * in qualified names (JLS 6.4.2).
     */
    private final ComputationGraph graph;

    /**
     * Creates a trainer fed by a {@link DataSetIterator}, with no early-stopping listener.
     *
     * @throws IllegalStateException if the graph does not have exactly 1 input and 1 output array
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration, ComputationGraph computationGraph, DataSetIterator dataSetIterator) {
        this(earlyStoppingConfiguration, computationGraph, dataSetIterator, null);
    }

    /**
     * Creates a trainer fed by a {@link DataSetIterator}.
     *
     * @param earlyStoppingListener listener to notify; may be {@code null} (the 3-arg constructor
     *     passes {@code null})
     * @throws IllegalStateException if the graph does not have exactly 1 input and 1 output array
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration, ComputationGraph computationGraph, DataSetIterator dataSetIterator, EarlyStoppingListener<ComputationGraph> earlyStoppingListener) {
        // Validate inside the super(...) argument list so an unsuitable graph is rejected before
        // the base trainer is initialized with it.
        super(earlyStoppingConfiguration, requireSingleInputOutput(computationGraph), dataSetIterator, null, earlyStoppingListener);
        this.graph = computationGraph;
    }

    /**
     * Creates a trainer fed by a {@link MultiDataSetIterator}; graphs with any number of inputs
     * and outputs are accepted here.
     *
     * @param earlyStoppingListener listener to notify; forwarded to the base trainer as-is
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> earlyStoppingConfiguration, ComputationGraph computationGraph, MultiDataSetIterator multiDataSetIterator, EarlyStoppingListener<ComputationGraph> earlyStoppingListener) {
        super(earlyStoppingConfiguration, computationGraph, null, multiDataSetIterator, earlyStoppingListener);
        this.graph = computationGraph;
    }

    /**
     * Trains on one {@link DataSet} minibatch, via backprop or pretraining depending on the
     * graph configuration.
     *
     * @throws IllegalStateException if neither backprop nor pretraining is enabled
     */
    @Override // org.deeplearning4j.earlystopping.trainer.BaseEarlyStoppingTrainer
    protected void fit(DataSet dataSet) {
        if (useBackprop()) {
            this.graph.fit(dataSet);
        } else {
            this.graph.pretrain(new SingletonDataSetIterator(dataSet));
        }
    }

    /**
     * Trains on one {@link MultiDataSet} minibatch, via backprop or pretraining depending on the
     * graph configuration.
     *
     * @throws IllegalStateException if neither backprop nor pretraining is enabled
     */
    @Override // org.deeplearning4j.earlystopping.trainer.BaseEarlyStoppingTrainer
    protected void fit(MultiDataSet multiDataSet) {
        if (useBackprop()) {
            this.graph.fit(multiDataSet);
        } else {
            this.graph.pretrain(new SingletonMultiDataSetIterator(multiDataSet));
        }
    }

    /**
     * Decides which training mode to use for the next minibatch.
     *
     * @return {@code true} if the configuration enables backprop, {@code false} if it only
     *     enables pretraining
     * @throws IllegalStateException if the configuration enables neither
     */
    private boolean useBackprop() {
        if (this.graph.getConfiguration().isBackprop()) {
            return true;
        }
        if (!this.graph.getConfiguration().isPretrain()) {
            throw new IllegalStateException(NO_TRAINING_MODE_MESSAGE);
        }
        return false;
    }

    /**
     * Ensures a graph can be driven by a plain {@link DataSetIterator}, which supplies exactly one
     * features array and one labels array per minibatch.
     *
     * @return {@code computationGraph}, unchanged, so this can be used inline in {@code super(...)}
     * @throws IllegalStateException if the graph does not have exactly 1 input and 1 output array
     */
    private static ComputationGraph requireSingleInputOutput(ComputationGraph computationGraph) {
        if (computationGraph.getNumInputArrays() != 1 || computationGraph.getNumOutputArrays() != 1) {
            throw new IllegalStateException("Cannot do early stopping training on ComputationGraph with DataSetIterator: graph does not have 1 input and 1 output array");
        }
        return computationGraph;
    }
}
