package org.deeplearning4j.nn.transferlearning;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/deeplearning4j/nn/transferlearning/TransferLearningHelper.class */
/**
 * Splits a network into a frozen part and an unfrozen subset for transfer learning.
 *
 * <p>Data is "featurized" by running it forward through the original network up to the frozen
 * boundary. The unfrozen subset network is then trained on that featurized data, and after every
 * {@code fitFeaturized(...)} call its parameters are copied back into the original network.
 *
 * <p>Works with either a {@link ComputationGraph} or a {@link MultiLayerNetwork}. Graph-only methods
 * throw {@link IllegalArgumentException} when the helper wraps a MultiLayerNetwork, and vice versa.
 */
public class TransferLearningHelper {

    private boolean isGraph = true;
    private boolean applyFrozen = false;
    private ComputationGraph origGraph;
    private MultiLayerNetwork origMLN;
    /** MLN only: layers 0..frozenTill (inclusive) are frozen when {@code applyFrozen} is set. */
    private int frozenTill;
    /** Graph only: names of the vertices to freeze (together with everything feeding into them). */
    private String[] frozenOutputAt;
    private ComputationGraph unFrozenSubsetGraph;
    private MultiLayerNetwork unFrozenSubsetMLN;
    /** Frozen vertices that feed directly into an unfrozen vertex; they become subset-graph inputs. */
    Set<String> frozenInputVertices = new HashSet<>();
    /** Input names of the subset graph, sorted; featurize(...) emits feature arrays in this order. */
    List<String> graphInputs;
    /** MLN only: index of the last {@link FrozenLayer} in the original network (0 if none found). */
    int frozenInputLayer = 0;

    /**
     * Creates a helper for a computation graph, freezing the named vertices and every vertex that
     * feeds into them.
     *
     * <p>Note: this mutates {@code orig}: each frozen vertex that has a layer gets
     * {@code setLayerAsFrozen()} called on it.
     *
     * @param orig           the graph to split
     * @param frozenOutputAt names of the vertices at which the frozen part ends
     * @throws IllegalArgumentException if no frozen vertex feeds an unfrozen vertex
     */
    public TransferLearningHelper(ComputationGraph orig, String... frozenOutputAt) {
        this.origGraph = orig;
        this.frozenOutputAt = frozenOutputAt;
        this.applyFrozen = true;
        initHelperGraph();
    }

    /**
     * Creates a helper for a computation graph whose frozen part is given by the layers that are
     * already {@link FrozenLayer}s (plus every vertex feeding into them).
     *
     * @param orig the graph to split
     * @throws IllegalArgumentException if no frozen vertex feeds an unfrozen vertex
     */
    public TransferLearningHelper(ComputationGraph orig) {
        this.origGraph = orig;
        initHelperGraph();
    }

    /**
     * Creates a helper for a MultiLayerNetwork, freezing layers {@code 0..frozenTill} (inclusive).
     * Layers that are already frozen are left as they are.
     *
     * <p>Note: this mutates {@code orig}; its layer array is replaced via {@code setLayers}.
     *
     * @param orig       the network to split
     * @param frozenTill index of the last layer to freeze
     * @throws IllegalArgumentException if {@code frozenTill} is not a valid layer index
     */
    public TransferLearningHelper(MultiLayerNetwork orig, int frozenTill) {
        this.isGraph = false;
        this.frozenTill = frozenTill;
        this.applyFrozen = true;
        this.origMLN = orig;
        initHelperMLN();
    }

    /**
     * Creates a helper for a MultiLayerNetwork whose frozen part ends at its last
     * {@link FrozenLayer}.
     *
     * @param orig the network to split
     */
    public TransferLearningHelper(MultiLayerNetwork orig) {
        this.isGraph = false;
        this.origMLN = orig;
        initHelperMLN();
    }

    /**
     * Always throws, with a message describing which kind of network this helper wraps. Callers
     * invoke it when a graph-only or MLN-only method is used on the wrong kind of helper.
     *
     * @throws IllegalArgumentException always
     */
    public void errorIfGraphIfMLN() {
        if (isGraph) {
            throw new IllegalArgumentException("This instance was initialized with a computation graph. Cannot apply methods related to MLN");
        }
        throw new IllegalArgumentException("This instance was initialized with a MultiLayerNetwork. Cannot apply methods related to computation graphs");
    }

    /**
     * Returns the unfrozen subset of the original graph.
     *
     * @throws IllegalArgumentException if this helper wraps a MultiLayerNetwork
     */
    public ComputationGraph unfrozenGraph() {
        if (!isGraph) {
            errorIfGraphIfMLN();
        }
        return unFrozenSubsetGraph;
    }

    /**
     * Returns the unfrozen subset of the original MultiLayerNetwork.
     *
     * @throws IllegalArgumentException if this helper wraps a computation graph
     */
    public MultiLayerNetwork unfrozenMLN() {
        if (isGraph) {
            errorIfGraphIfMLN();
        }
        return unFrozenSubsetMLN;
    }

    /**
     * Runs the unfrozen subset graph on already featurized inputs.
     *
     * @param input featurized inputs, ordered like the arrays produced by {@code featurize(...)}
     * @return the subset graph's outputs
     * @throws IllegalArgumentException if this helper wraps a MultiLayerNetwork
     */
    public INDArray[] outputFromFeaturized(INDArray[] input) {
        if (!isGraph) {
            errorIfGraphIfMLN();
        }
        return unFrozenSubsetGraph.output(input);
    }

    /**
     * Runs the unfrozen subset network on a single featurized input.
     *
     * @param input featurized input
     * @return the subset network's output (for a graph, its only output)
     * @throws IllegalArgumentException if the subset graph has more than one output
     */
    public INDArray outputFromFeaturized(INDArray input) {
        if (!isGraph) {
            return unFrozenSubsetMLN.output(input);
        }
        if (unFrozenSubsetGraph.getNumOutputArrays() > 1) {
            throw new IllegalArgumentException("Graph has more than one output. Expecting an input array with outputFromFeaturized method call");
        }
        return unFrozenSubsetGraph.output(input)[0];
    }

    /**
     * Determines the frozen part of {@code origGraph}, freezing the requested vertices if needed.
     * Then builds the unfrozen subset graph and initializes its parameters from the original.
     */
    private void initHelperGraph() {
        int[] backPropOrder = origGraph.topologicalSortOrder().clone();
        // Reverse topological order visits every vertex before the vertices feeding into it, so
        // freezing propagates from the requested outputs back towards the network inputs.
        ArrayUtils.reverse(backPropOrder);

        Set<String> allFrozen = collectFrozenVertices(backPropOrder);
        collectFrozenInputVertices(backPropOrder, allFrozen);
        // Fail before doing any (wasted) subset-graph construction.
        if (frozenInputVertices.isEmpty()) {
            throw new IllegalArgumentException("No frozen layers found");
        }
        buildUnfrozenSubsetGraph(allFrozen);
    }

    /**
     * Returns the names of all frozen vertices, freezing the requested ones along the way.
     *
     * <p>A vertex is frozen if it was requested via {@code frozenOutputAt}, or its layer is already a
     * {@link FrozenLayer}, or it feeds into a frozen vertex.
     */
    private Set<String> collectFrozenVertices(int[] backPropOrder) {
        Set<String> allFrozen = new HashSet<>();
        if (applyFrozen) {
            Collections.addAll(allFrozen, frozenOutputAt);
        }
        GraphVertex[] vertices = origGraph.getVertices();
        for (int vertexIdx : backPropOrder) {
            GraphVertex vertex = vertices[vertexIdx];
            if (applyFrozen && allFrozen.contains(vertex.getVertexName())) {
                if (vertex.hasLayer()) {
                    freezeVertexLayer(vertex);
                }
                addInputVertexNames(vertex, allFrozen);
            } else if (vertex.hasLayer() && vertex.getLayer() instanceof FrozenLayer) {
                allFrozen.add(vertex.getVertexName());
                addInputVertexNames(vertex, allFrozen);
            }
        }
        return allFrozen;
    }

    /**
     * Calls {@code setLayerAsFrozen()} on the vertex, then replaces the vertex's previous layer in
     * the array returned by {@code origGraph.getLayers()} with the vertex's new layer.
     */
    private void freezeVertexLayer(GraphVertex vertex) {
        Layer previous = vertex.getLayer();
        vertex.setLayerAsFrozen();
        // NOTE(review): relies on getLayers() returning the graph's live array, not a copy.
        Layer[] layers = origGraph.getLayers();
        for (int i = 0; i < layers.length; i++) {
            if (layers[i] == previous) {
                layers[i] = vertex.getLayer();
                break;
            }
        }
    }

    /** Adds the names of all vertices feeding into {@code vertex} to {@code names}. */
    private void addInputVertexNames(GraphVertex vertex, Set<String> names) {
        VertexIndices[] inputs = vertex.getInputVertices();
        if (inputs == null) {
            return;
        }
        GraphVertex[] vertices = origGraph.getVertices();
        for (VertexIndices input : inputs) {
            names.add(vertices[input.getVertexIndex()].getVertexName());
        }
    }

    /** Records every frozen vertex that directly feeds a non-input, unfrozen vertex. */
    private void collectFrozenInputVertices(int[] backPropOrder, Set<String> allFrozen) {
        GraphVertex[] vertices = origGraph.getVertices();
        for (int vertexIdx : backPropOrder) {
            GraphVertex vertex = vertices[vertexIdx];
            if (allFrozen.contains(vertex.getVertexName()) || vertex.isInputVertex()) {
                continue;
            }
            for (VertexIndices input : vertex.getInputVertices()) {
                String inputName = vertices[input.getVertexIndex()].getVertexName();
                if (allFrozen.contains(inputName)) {
                    frozenInputVertices.add(inputName);
                }
            }
        }
    }

    /**
     * Builds {@code unFrozenSubsetGraph} from {@code origGraph}.
     *
     * <p>Frozen vertices are removed. The frozen boundary vertices and the surviving network inputs
     * are then registered as the subset's inputs, in the sorted order stored in {@code graphInputs}.
     */
    private void buildUnfrozenSubsetGraph(Set<String> allFrozen) {
        TransferLearning.GraphBuilder builder = new TransferLearning.GraphBuilder(origGraph);
        for (String frozen : allFrozen) {
            if (frozenInputVertices.contains(frozen)) {
                // Keep its outgoing connections: it is re-added below as an input of the subset.
                builder.removeVertexKeepConnections(frozen);
            } else {
                builder.removeVertexAndConnections(frozen);
            }
        }

        Set<String> inputs = new HashSet<>(origGraph.getConfiguration().getNetworkInputs());
        inputs.removeAll(allFrozen);
        // Remove surviving network inputs so that all inputs can be re-added in one known order.
        for (String existingInput : inputs) {
            builder.removeVertexKeepConnections(existingInput);
        }
        inputs.addAll(frozenInputVertices);

        graphInputs = new ArrayList<>(inputs);
        Collections.sort(graphInputs);
        // Register inputs in the same sorted order featurize(...) uses for its feature arrays, so
        // featurized array i is fed to subset-graph input i.
        for (String input : graphInputs) {
            builder.addInputs(input);
        }
        unFrozenSubsetGraph = builder.build();
        copyOrigParamsToSubsetGraph();
    }

    /**
     * Freezes layers 0..frozenTill of {@code origMLN} if requested, and locates the last frozen layer.
     * Then builds the unfrozen subset MLN from the layers after it, copying their parameters over.
     */
    private void initHelperMLN() {
        if (applyFrozen) {
            Layer[] layers = origMLN.getLayers();
            if (frozenTill < 0 || frozenTill >= layers.length) {
                throw new IllegalArgumentException("Cannot freeze layers 0 to " + frozenTill
                                + ": network has " + layers.length + " layers");
            }
            for (int i = frozenTill; i >= 0; i--) {
                if (!(layers[i] instanceof FrozenLayer)) {
                    layers[i] = new FrozenLayer(layers[i]);
                }
            }
            origMLN.setLayers(layers);
        }

        int nLayers = origMLN.getnLayers();
        // NOTE(review): frozenInputLayer stays 0 when no layer is frozen, so layer 0 is then treated
        // as frozen (the graph path throws "No frozen layers found" in the equivalent situation).
        for (int i = 0; i < nLayers; i++) {
            if (origMLN.getLayer(i) instanceof FrozenLayer) {
                frozenInputLayer = i;
            }
        }

        int firstUnfrozen = frozenInputLayer + 1;
        // Raw list: elements are the configurations returned by Layer.conf().
        List confs = new ArrayList();
        for (int i = firstUnfrozen; i < nLayers; i++) {
            confs.add(origMLN.getLayer(i).conf());
        }

        MultiLayerConfiguration c = origMLN.getLayerWiseConfigurations();
        unFrozenSubsetMLN = new MultiLayerNetwork(new MultiLayerConfiguration.Builder()
                        .backprop(c.isBackprop())
                        // Subset layer j is original layer j + firstUnfrozen, so preprocessor keys
                        // must be shifted the same way; frozen layers' preprocessors are dropped.
                        .inputPreProcessors(reindexFrom(c.getInputPreProcessors(), firstUnfrozen))
                        .pretrain(c.isPretrain())
                        .backpropType(c.getBackpropType())
                        .tBPTTForwardLength(c.getTbpttFwdLength())
                        .tBPTTBackwardLength(c.getTbpttBackLength())
                        .confs(confs)
                        .build());
        unFrozenSubsetMLN.init();
        for (int i = firstUnfrozen; i < nLayers; i++) {
            unFrozenSubsetMLN.getLayer(i - firstUnfrozen).setParams(origMLN.getLayer(i).params());
        }
    }

    /**
     * Returns a new map containing the entries of {@code byLayer} whose key is at least
     * {@code firstLayer}, re-keyed to {@code key - firstLayer}. A {@code null} map yields an empty one.
     */
    private static <V> Map<Integer, V> reindexFrom(Map<Integer, V> byLayer, int firstLayer) {
        Map<Integer, V> shifted = new HashMap<>();
        if (byLayer == null) {
            return shifted;
        }
        for (Map.Entry<Integer, V> entry : byLayer.entrySet()) {
            int layerIdx = entry.getKey();
            if (layerIdx >= firstLayer) {
                shifted.put(layerIdx - firstLayer, entry.getValue());
            }
        }
        return shifted;
    }

    /**
     * Featurizes a MultiDataSet by running its features through the original graph (not in
     * training mode).
     *
     * @param input data with no feature masks
     * @return a MultiDataSet with one feature array per entry of {@code graphInputs} (sorted by
     *         name), the original labels and label masks, and no feature masks
     * @throws IllegalArgumentException if this helper wraps a MultiLayerNetwork, or if
     *                                  {@code input} has feature masks
     */
    public MultiDataSet featurize(MultiDataSet input) {
        if (!isGraph) {
            throw new IllegalArgumentException("Cannot use multidatasets with MultiLayerNetworks.");
        }
        if (input.getFeaturesMaskArrays() != null) {
            throw new IllegalArgumentException("Currently cannot support featurizing datasets with feature masks");
        }
        Map<String, INDArray> activations = origGraph.feedForward(input.getFeatures(), false);
        List<String> networkInputs = origGraph.getConfiguration().getNetworkInputs();
        INDArray[] featurized = new INDArray[graphInputs.size()];
        for (int i = 0; i < featurized.length; i++) {
            String name = graphInputs.get(i);
            featurized[i] = origGraph.getVertex(name).isInputVertex()
                            ? origGraph.getInput(networkInputs.indexOf(name))
                            : activations.get(name);
        }
        return new MultiDataSet(featurized, input.getLabels(), (INDArray[]) null, input.getLabelsMaskArrays());
    }

    /**
     * Featurizes a DataSet by running its features through the frozen part of the original network
     * (not in training mode).
     *
     * @param input data with no feature mask
     * @return a DataSet with the featurized features, the original labels and labels mask, and no
     *         features mask
     * @throws UnsupportedOperationException for a MultiLayerNetwork if {@code input} has a features
     *                                       mask
     * @throws IllegalArgumentException      for a graph with more than one input or output, or if
     *                                       {@code input} has a features mask
     */
    public DataSet featurize(DataSet input) {
        if (!isGraph) {
            if (input.getFeaturesMaskArray() != null) {
                throw new UnsupportedOperationException("Feature masks not supported with featurizing currently");
            }
            // feedForwardToLayer's list starts with the raw input, so index k + 1 holds layer k's output.
            INDArray featurized = origMLN.feedForwardToLayer(frozenInputLayer + 1, input.getFeatures(), false)
                            .get(frozenInputLayer + 1);
            return new DataSet(featurized, input.getLabels(), null, input.getLabelsMaskArray());
        }
        if (origGraph.getNumInputArrays() > 1 || origGraph.getNumOutputArrays() > 1) {
            throw new IllegalArgumentException("Input or output size to a computation graph is greater than one. Requires use of a MultiDataSet.");
        }
        if (input.getFeaturesMaskArray() != null) {
            throw new IllegalArgumentException("Currently cannot support featurizing datasets with feature masks");
        }
        MultiDataSet wrapped = new MultiDataSet(new INDArray[] {input.getFeatures()},
                        new INDArray[] {input.getLabels()}, (INDArray[]) null,
                        new INDArray[] {input.getLabelsMaskArray()});
        MultiDataSet featurized = featurize(wrapped);
        // Features mask is always null here (rejected above); the labels mask belongs in slot 4 only.
        return new DataSet(featurized.getFeatures()[0], input.getLabels(), null, input.getLabelsMaskArray());
    }

    /**
     * Fits the unfrozen subset graph on featurized data, then copies its parameters back into the
     * original graph.
     *
     * @throws IllegalArgumentException if this helper wraps a MultiLayerNetwork
     */
    public void fitFeaturized(MultiDataSetIterator iter) {
        if (!isGraph) {
            errorIfGraphIfMLN();
        }
        unFrozenSubsetGraph.fit(iter);
        copyParamsFromSubsetGraphToOrig();
    }

    /**
     * Fits the unfrozen subset graph on a featurized MultiDataSet, then copies its parameters back
     * into the original graph.
     *
     * @throws IllegalArgumentException if this helper wraps a MultiLayerNetwork
     */
    public void fitFeaturized(MultiDataSet input) {
        if (!isGraph) {
            errorIfGraphIfMLN();
        }
        unFrozenSubsetGraph.fit(input);
        copyParamsFromSubsetGraphToOrig();
    }

    /**
     * Fits the unfrozen subset network on a featurized DataSet, then copies its parameters back into
     * the original network.
     */
    public void fitFeaturized(DataSet input) {
        if (isGraph) {
            unFrozenSubsetGraph.fit(input);
            copyParamsFromSubsetGraphToOrig();
        } else {
            unFrozenSubsetMLN.fit(input);
            copyParamsFromSubsetMLNToOrig();
        }
    }

    /**
     * Fits the unfrozen subset network on featurized data, then copies its parameters back into the
     * original network.
     */
    public void fitFeaturized(DataSetIterator iter) {
        if (isGraph) {
            unFrozenSubsetGraph.fit(iter);
            copyParamsFromSubsetGraphToOrig();
        } else {
            unFrozenSubsetMLN.fit(iter);
            copyParamsFromSubsetMLNToOrig();
        }
    }

    /** Copies every layer's parameters from the subset graph to the same-named original vertex. */
    private void copyParamsFromSubsetGraphToOrig() {
        for (GraphVertex vertex : unFrozenSubsetGraph.getVertices()) {
            if (vertex.hasLayer()) {
                origGraph.getVertex(vertex.getVertexName()).getLayer().setParams(vertex.getLayer().params());
            }
        }
    }

    /** Copies every layer's parameters from the original graph into the same-named subset vertex. */
    private void copyOrigParamsToSubsetGraph() {
        for (GraphVertex vertex : unFrozenSubsetGraph.getVertices()) {
            if (vertex.hasLayer()) {
                vertex.getLayer().setParams(origGraph.getLayer(vertex.getVertexName()).params());
            }
        }
    }

    /** Copies subset layer j's parameters into original layer j + frozenInputLayer + 1. */
    private void copyParamsFromSubsetMLNToOrig() {
        int firstUnfrozen = frozenInputLayer + 1;
        for (int i = firstUnfrozen; i < origMLN.getnLayers(); i++) {
            origMLN.getLayer(i).setParams(unFrozenSubsetMLN.getLayer(i - firstUnfrozen).params());
        }
    }
}
