package org.deeplearning4j.nn.layers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/BasePretrainNetwork.class */
/**
 * Base class for layers that, besides the "W" and "b" entries of the parameter table, also
 * manage a visible-bias parameter stored under {@link PretrainParamInitializer#VISIBLE_BIAS_KEY}.
 *
 * <p>Subclasses supply the two sampling operations
 * ({@link #sampleHiddenGivenVisible(INDArray)} and {@link #sampleVisibleGivenHidden(INDArray)}).
 * This class handles gradient bookkeeping, parameter flattening and the regularization
 * terms for the additional visible-bias parameter.
 *
 * @param <LayerConfT> the layer configuration type
 */
public abstract class BasePretrainNetwork<LayerConfT extends org.deeplearning4j.nn.conf.layers.BasePretrainNetwork> extends BaseLayer<LayerConfT> {

    /**
     * Creates the layer from its configuration.
     *
     * @param neuralNetConfiguration configuration forwarded to the superclass constructor
     */
    public BasePretrainNetwork(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /**
     * Creates the layer from its configuration and an array.
     *
     * @param neuralNetConfiguration configuration forwarded to the superclass constructor
     * @param iNDArray               array forwarded to the superclass constructor
     *                               (presumably the layer input; confirm against {@code BaseLayer})
     */
    public BasePretrainNetwork(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Produces a randomly masked version of the given array.
     *
     * <p>A mask with the same shape as {@code iNDArray} is sampled from a binomial distribution
     * with one trial and success probability {@code 1 - d}; the mask is then multiplied in place
     * by {@code iNDArray} and returned.
     *
     * @param iNDArray the array to corrupt
     * @param d        corruption level; {@code 1 - d} is the success probability of the mask
     * @return the sampled mask multiplied element-wise by {@code iNDArray}
     */
    public INDArray getCorruptedInput(INDArray iNDArray, double d) {
        INDArray mask = Nd4j.getDistributions().createBinomial(1, 1.0d - d).sample(iNDArray.shape());
        mask.muli(iNDArray);
        return mask;
    }

    /**
     * Copies the supplied gradients into this layer's gradient views and wraps them in a
     * {@link Gradient} backed by the flattened gradient array.
     *
     * @param iNDArray  gradient assigned to the "W" view
     * @param iNDArray2 gradient assigned to the visible-bias view
     * @param iNDArray3 gradient assigned to the "b" view
     * @return a gradient whose variable map holds the "W", "b" and visible-bias views
     */
    public Gradient createGradient(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        DefaultGradient gradient = new DefaultGradient(this.gradientsFlattened);

        INDArray weightView = this.gradientViews.get("W");
        weightView.assign(iNDArray);
        INDArray biasView = this.gradientViews.get("b");
        biasView.assign(iNDArray3);
        INDArray visibleBiasView = this.gradientViews.get(PretrainParamInitializer.VISIBLE_BIAS_KEY);
        visibleBiasView.assign(iNDArray2);

        Map<String, INDArray> byVariable = gradient.gradientForVariable();
        byVariable.put("W", weightView);
        byVariable.put("b", biasView);
        byVariable.put(PretrainParamInitializer.VISIBLE_BIAS_KEY, visibleBiasView);
        return gradient;
    }

    /**
     * Delegates to the superclass implementation.
     *
     * @param z flag forwarded unchanged to {@code super.numParams(boolean)}
     * @return the value reported by the superclass
     */
    @Override
    public long numParams(boolean z) {
        return super.numParams(z);
    }

    /**
     * Samples the hidden units given the visible units.
     *
     * @param iNDArray the visible-unit values
     * @return a pair of arrays; the meaning of each element is defined by the implementing subclass
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray iNDArray);

    /**
     * Samples the visible units given the hidden units.
     *
     * @param iNDArray the hidden-unit values
     * @return a pair of arrays; the meaning of each element is defined by the implementing subclass
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray iNDArray);

    /**
     * Computes and stores the layer score from the given array.
     *
     * <p>The score is the configured loss function's {@code computeScore} over the current input
     * and {@code iNDArray}, plus the L1 and L2 terms over all parameters, divided by the input
     * mini-batch size.
     *
     * @param iNDArray the array passed to the loss function together with the current input
     * @throws IllegalStateException if the layer input or {@code iNDArray} is {@code null}
     */
    @Override
    public void setScoreWithZ(INDArray iNDArray) {
        if (this.input == null || iNDArray == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + layerId());
        }
        org.deeplearning4j.nn.conf.layers.BasePretrainNetwork layerConfig =
                (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) layerConf();
        double lossScore = layerConfig.getLossFunction().getILossFunction()
                .computeScore(this.input, iNDArray, layerConfig.getActivationFn(), this.maskArray, false);
        double regularization = calcL1(false) + calcL2(false);
        this.score = (lossScore + regularization) / getInputMiniBatchSize();
    }

    /**
     * Returns the parameter table.
     *
     * @param z if {@code false}, the layer's own parameter map is returned as-is (not a copy);
     *          if {@code true}, a new map containing only the "W" and "b" entries is returned
     * @return the parameter map as described above
     */
    @Override
    public Map<String, INDArray> paramTable(boolean z) {
        if (!z) {
            return this.params;
        }
        Map<String, INDArray> weightsAndBias = new LinkedHashMap<>();
        weightsAndBias.put("W", this.params.get("W"));
        weightsAndBias.put("b", this.params.get("b"));
        return weightsAndBias;
    }

    /**
     * Flattens every array in the parameter map, in the map's iteration order, into a single
     * array using 'f' ordering.
     *
     * @return the flattened parameters
     */
    @Override
    public INDArray params() {
        Collection<INDArray> arrays = new ArrayList<>(this.params.values());
        return Nd4j.toFlattened('f', arrays);
    }

    /**
     * Returns the total number of elements over all arrays in the parameter map.
     *
     * <p>The sum is accumulated as a {@code long} so that totals larger than
     * {@link Integer#MAX_VALUE} are not truncated.
     *
     * @return the summed length of all parameter arrays
     */
    @Override
    public long numParams() {
        long total = 0L;
        for (INDArray array : this.params.values()) {
            total += array.length();
        }
        return total;
    }

    /**
     * Overwrites the flattened parameters with the contents of the given array.
     *
     * <p>Passing the layer's own flattened parameter array is a no-op.
     *
     * @param iNDArray the new parameter values
     * @throws IllegalArgumentException if the length of {@code iNDArray} differs from the summed
     *                                  length of the parameters named by the configuration's variables
     */
    @Override
    public void setParams(INDArray iNDArray) {
        if (iNDArray == this.paramsFlattened) {
            // Same instance: nothing to copy.
            return;
        }
        long expectedLength = 0L;
        for (String variable : this.conf.variables()) {
            expectedLength += getParam(variable).length();
        }
        if (iNDArray.length() != expectedLength) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + expectedLength + ", got params of length " + iNDArray.length() + " " + layerId());
        }
        this.paramsFlattened.assign(iNDArray);
    }

    /**
     * Runs the superclass backprop and then registers a zeroed visible-bias gradient.
     *
     * <p>The returned gradient is pointed at this layer's flattened gradient array, the
     * visible-bias gradient view is added to its variable map and set to zero, and the
     * weight-noise parameter cache is cleared.
     *
     * @param iNDArray          array forwarded to the superclass backprop
     * @param layerWorkspaceMgr workspace manager forwarded to the superclass backprop
     * @return the pair produced by the superclass, with the gradient amended as described
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        Pair<Gradient, INDArray> result = super.backpropGradient(iNDArray, layerWorkspaceMgr);
        ((DefaultGradient) result.getFirst()).setFlattenedGradient(this.gradientsFlattened);

        INDArray visibleBiasGradient = this.gradientViews.get(PretrainParamInitializer.VISIBLE_BIAS_KEY);
        result.getFirst().gradientForVariable().put(PretrainParamInitializer.VISIBLE_BIAS_KEY, visibleBiasGradient);
        visibleBiasGradient.assign((Number) 0);

        this.weightNoiseParams.clear();
        return result;
    }

    /**
     * Computes the L2 regularization term.
     *
     * @param z if {@code true}, only the superclass term (computed with {@code true}) is returned;
     *          if {@code false}, the visible-bias term
     *          {@code 0.5 * coefficient * norm2^2} is added when its coefficient is positive
     * @return the L2 term
     */
    @Override
    public double calcL2(boolean z) {
        double l2Sum = super.calcL2(true);
        if (z) {
            return l2Sum;
        }
        double coefficient = ((org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) layerConf())
                .getL2ByParam(PretrainParamInitializer.VISIBLE_BIAS_KEY);
        if (coefficient > 0.0d) {
            double norm2 = getParam(PretrainParamInitializer.VISIBLE_BIAS_KEY).norm2Number().doubleValue();
            l2Sum += 0.5d * coefficient * norm2 * norm2;
        }
        return l2Sum;
    }

    /**
     * Computes the L1 regularization term.
     *
     * <p>Mirrors {@link #calcL2(boolean)}: the visible-bias contribution is only included when
     * {@code z} is {@code false}.
     *
     * @param z if {@code true}, only the superclass term (computed with {@code true}) is returned;
     *          if {@code false}, the visible-bias term {@code coefficient * norm1} is added when
     *          its coefficient is positive
     * @return the L1 term
     */
    @Override
    public double calcL1(boolean z) {
        double l1Sum = super.calcL1(true);
        if (z) {
            return l1Sum;
        }
        double coefficient = ((org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) layerConf())
                .getL1ByParam(PretrainParamInitializer.VISIBLE_BIAS_KEY);
        if (coefficient > 0.0d) {
            l1Sum += coefficient * getParam(PretrainParamInitializer.VISIBLE_BIAS_KEY).norm1Number().doubleValue();
        }
        return l1Sum;
    }
}
