package org.deeplearning4j.nn.layers.feedforward.rbm;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelpers;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.util.Dropout;
import org.deeplearning4j.util.RBMUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/feedforward/rbm/RBM.class */
/**
 * Restricted Boltzmann machine layer.
 *
 * <p>Training uses k-step contrastive divergence ({@link #computeGradientAndScore()}): a positive
 * phase on the layer input followed by {@code k} alternating Gibbs steps. Parameters are the weight
 * matrix {@code "W"}, the hidden bias {@code "b"} and the visible bias
 * {@link PretrainParamInitializer#VISIBLE_BIAS_KEY}. How values are propagated and sampled is chosen
 * by the visible/hidden unit types of the {@link org.deeplearning4j.nn.conf.layers.RBM} layer
 * configuration.
 *
 * <p>NOTE(review): most sampling paths call {@code reseedRandomGenerator(seed)} with the same seed
 * on every invocation, so they presumably draw the same random stream on each call. The
 * rectified-hidden sampling path does not reseed. Confirm this is intended.
 */
public class RBM extends BasePretrainNetwork<org.deeplearning4j.nn.conf.layers.RBM> {
    /** Seed copied from the network configuration; used to reseed the sampling distributions. */
    private long seed;

    /**
     * Variance of the input along dimension 0 divided by the number of rows. It is only set by
     * {@link #iterate(INDArray)} for Gaussian visible units and is copied by {@link #transpose()}.
     */
    @Deprecated
    protected INDArray sigma;

    /** Not written by this class; only copied by {@link #transpose()}. */
    @Deprecated
    protected INDArray hiddenSigma;

    /**
     * Creates the layer without parameters.
     *
     * @param configuration network configuration; its seed drives the sampling distributions
     */
    public RBM(NeuralNetConfiguration configuration) {
        super(configuration);
        this.seed = configuration.getSeed();
    }

    /**
     * Creates the layer with an initial input.
     *
     * @param configuration network configuration; its seed drives the sampling distributions
     * @param input initial layer input
     */
    public RBM(NeuralNetConfiguration configuration, INDArray input) {
        super(configuration, input);
        this.seed = configuration.getSeed();
    }

    /** Returns this layer's configuration as the RBM layer configuration type. */
    private org.deeplearning4j.nn.conf.layers.RBM rbmConf() {
        return (org.deeplearning4j.nn.conf.layers.RBM) layerConf();
    }

    /**
     * Applies one in-place update {@code param -= gradient} to the visible bias, the hidden bias
     * and the weights, using the gradient returned by {@link #gradient()}.
     */
    @Deprecated
    public void contrastiveDivergence() {
        Gradient gradient = gradient();
        getParam(PretrainParamInitializer.VISIBLE_BIAS_KEY).subi(gradient.gradientForVariable().get(PretrainParamInitializer.VISIBLE_BIAS_KEY));
        getParam("b").subi(gradient.gradientForVariable().get("b"));
        getParam("W").subi(gradient.gradientForVariable().get("W"));
    }

    /**
     * Computes the CD-k gradient for the current input and stores it in {@code this.gradient}.
     *
     * <p>The gradients are negated when the network configuration is in pretrain mode. The score
     * is set from the last negative-phase visible sample via {@code setScoreWithZ}.
     *
     * @throws IllegalStateException if the configured number of Gibbs steps {@code k} is below 1
     */
    @Override
    public void computeGradientAndScore() {
        int k = rbmConf().getK();
        if (k < 1) {
            // Without at least one Gibbs step there are no negative-phase statistics to subtract.
            throw new IllegalStateException("RBM contrastive divergence requires k >= 1, but k = " + k);
        }

        // Positive phase: hidden activations (first) and hidden samples (second) given the data.
        Pair<INDArray, INDArray> positive = sampleHiddenGivenVisible(input());
        INDArray positiveHidden = positive.getFirst();

        // Negative phase: k alternating Gibbs steps h -> v -> h, starting from positiveHidden.
        INDArray negVisible = null;        // propDown output of the last step
        INDArray negVisibleSample = null;  // visible sample of the last step
        INDArray negHidden = null;         // propUp output of the last step
        INDArray chain = positiveHidden;
        for (int step = 0; step < k; step++) {
            Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbs = gibbhVh(chain);
            negVisible = gibbs.getFirst().getFirst();
            negVisibleSample = gibbs.getFirst().getSecond();
            negHidden = gibbs.getSecond().getFirst();
            // Later steps continue the chain from the hidden sample.
            chain = gibbs.getSecond().getSecond();
        }

        // transpose(), not transposei(): the in-place variant can reshape the layer's own input
        // array, which is read again below (this.input) for the visible-bias gradient.
        INDArray weightGradient = input().transpose().mmul(positiveHidden)
                .subi(negVisible.transpose().mmul(negHidden));

        double sparsity = rbmConf().getSparsity();
        INDArray hiddenBiasGradient;
        if (sparsity != 0.0d) {
            // Pull every hidden activation toward the configured sparsity target.
            hiddenBiasGradient = positiveHidden.rsub(sparsity).sum(0);
        } else {
            hiddenBiasGradient = positiveHidden.sub(negHidden).sum(0);
        }

        INDArray visibleBiasGradient = this.input.sub(negVisible).sum(0);

        if (this.conf.isPretrain()) {
            weightGradient.negi();
            hiddenBiasGradient.negi();
            visibleBiasGradient.negi();
        }
        this.gradient = createGradient(weightGradient, visibleBiasGradient, hiddenBiasGradient);
        setScoreWithZ(negVisibleSample);
    }

    /**
     * Runs one Gibbs step starting from hidden values.
     *
     * @param hidden hidden values to start from
     * @return pair of (visible activation/sample given {@code hidden},
     *     hidden activation/sample given that visible activation)
     */
    public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray hidden) {
        Pair<INDArray, INDArray> visible = sampleVisibleGivenHidden(hidden);
        return new Pair<>(visible, sampleHiddenGivenVisible(visible.getFirst()));
    }

    /**
     * Propagates visible values up and samples the hidden layer according to the hidden unit type.
     *
     * <p>NOTE(review): for rectified units the noise is drawn from {@code createNormal(hiddenUp, 1.0)}
     * without reseeding. Assuming the first argument is the mean, the noise is centered on
     * {@code hiddenUp} rather than zero. Confirm this is intended.
     *
     * @param visible visible values
     * @return pair of (output of {@link #propUp(INDArray)}, sample derived from it)
     * @throws IllegalStateException if the hidden unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible) {
        INDArray hiddenUp = propUp(visible);
        INDArray hiddenSample;
        switch (rbmConf().getHiddenUnit()) {
            case IDENTITY:
                hiddenSample = hiddenUp;
                break;
            case BINARY: {
                Distribution binomial = Nd4j.getDistributions().createBinomial(1, hiddenUp);
                binomial.reseedRandomGenerator(this.seed);
                hiddenSample = binomial.sample(hiddenUp.shape());
                break;
            }
            case GAUSSIAN: {
                Distribution normal = Nd4j.getDistributions().createNormal(hiddenUp, 1.0d);
                normal.reseedRandomGenerator(this.seed);
                hiddenSample = normal.sample(hiddenUp.shape());
                break;
            }
            case RECTIFIED: {
                INDArray noiseScale = Transforms.sqrt(Transforms.sigmoid(hiddenUp));
                INDArray noise = Nd4j.getDistributions().createNormal(hiddenUp, 1.0d).sample(hiddenUp.shape());
                noise.muli(noiseScale);
                hiddenSample = Transforms.max(hiddenUp.add(noise), 0.0d);
                break;
            }
            case SOFTMAX:
                hiddenSample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", hiddenUp));
                break;
            default:
                throw new IllegalStateException("Hidden unit type must either be Binary, Gaussian, SoftMax or Rectified");
        }
        return new Pair<>(hiddenUp, hiddenSample);
    }

    /**
     * Propagates hidden values down and samples the visible layer according to the visible unit type.
     *
     * @param hidden hidden values
     * @return pair of (output of {@link #propDown(INDArray)}, sample derived from it)
     * @throws IllegalStateException if the visible unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden) {
        INDArray visibleDown = propDown(hidden);
        INDArray visibleSample;
        switch (rbmConf().getVisibleUnit()) {
            case IDENTITY:
                visibleSample = visibleDown;
                break;
            case BINARY: {
                Distribution binomial = Nd4j.getDistributions().createBinomial(1, visibleDown);
                binomial.reseedRandomGenerator(this.seed);
                visibleSample = binomial.sample(visibleDown.shape());
                break;
            }
            case GAUSSIAN:
            case LINEAR: {
                // Gaussian and linear visible units are both sampled from a unit-variance normal.
                Distribution normal = Nd4j.getDistributions().createNormal(visibleDown, 1.0d);
                normal.reseedRandomGenerator(this.seed);
                visibleSample = normal.sample(visibleDown.shape());
                break;
            }
            case SOFTMAX:
                visibleSample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", visibleDown));
                break;
            default:
                throw new IllegalStateException("Visible type must be one of Binary, Gaussian, SoftMax or Linear");
        }
        return new Pair<>(visibleDown, visibleSample);
    }

    /**
     * Computes the hidden pre-activation {@code x * W + b}.
     *
     * @param x visible values
     * @param training whether drop-connect may be applied to the weights
     * @return hidden pre-activations
     */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        INDArray hiddenBias = getParam("b");
        INDArray weights = getParam("W");
        if (training && this.conf.isUseDropConnect() && this.conf.getLayer().getDropOut() > 0.0d) {
            weights = Dropout.applyDropConnect(this, "W");
        }
        return x.mmul(weights).addiRowVector(hiddenBias);
    }

    /**
     * Computes hidden activations in training mode; equivalent to {@code propUp(v, true)}.
     *
     * @param v visible values
     * @return hidden activations
     */
    public INDArray propUp(INDArray v) {
        return propUp(v, true);
    }

    /**
     * Computes hidden activations from {@link #preOutput(INDArray, boolean)} followed by the hidden
     * unit transform: identity, sigmoid, a normal sample around the pre-activation, ReLU, or softmax.
     *
     * @param v visible values
     * @param training forwarded to {@link #preOutput(INDArray, boolean)}
     * @return hidden activations
     * @throws IllegalStateException if the hidden unit type is not supported
     */
    public INDArray propUp(INDArray v, boolean training) {
        INDArray z = preOutput(v, training);
        switch (rbmConf().getHiddenUnit()) {
            case IDENTITY:
                return z;
            case BINARY:
                return Transforms.sigmoid(z);
            case GAUSSIAN: {
                Distribution normal = Nd4j.getDistributions().createNormal(z, 1.0d);
                normal.reseedRandomGenerator(this.seed);
                return normal.sample(z.shape());
            }
            case RECTIFIED:
                return Transforms.max(z, 0.0d);
            case SOFTMAX:
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", z));
            default:
                throw new IllegalStateException("Hidden unit type should either be binary, gaussian, or rectified linear");
        }
    }

    /**
     * Evaluates the derivative of the hidden unit transform at pre-activation {@code z}.
     *
     * <p>NOTE(review): the Gaussian branch returns {@code -2 * z * sample}, where {@code sample} is
     * drawn from a normal distribution built around {@code z}. This is not an analytic derivative
     * of the Gaussian {@link #propUp(INDArray, boolean)} path. Confirm it is intended.
     *
     * @param z hidden pre-activations
     * @return element-wise derivative values
     * @throws IllegalStateException if the hidden unit type is not supported
     */
    public INDArray propUpDerivative(INDArray z) {
        switch (rbmConf().getHiddenUnit()) {
            case IDENTITY:
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("identity", z).derivative());
            case BINARY:
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(LSTMHelpers.SIGMOID, z).derivative());
            case GAUSSIAN: {
                Distribution normal = Nd4j.getDistributions().createNormal(z, 1.0d);
                normal.reseedRandomGenerator(this.seed);
                return z.mul(-2).mul(normal.sample(z.shape()));
            }
            case RECTIFIED:
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("relu", z).derivative());
            case SOFTMAX:
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", z).derivative());
            default:
                throw new IllegalStateException("Hidden unit type should either be binary, gaussian, or rectified linear");
        }
    }

    /**
     * Computes visible activations {@code h * W^T + visibleBias}, followed by the visible unit
     * transform: identity, sigmoid, a normal sample around the pre-activation, linear (identity),
     * or softmax.
     *
     * @param h hidden values
     * @return visible activations
     * @throws IllegalStateException if the visible unit type is not supported
     */
    public INDArray propDown(INDArray h) {
        INDArray weightsT = getParam("W").transpose();
        INDArray z = h.mmul(weightsT).addiRowVector(getParam(PretrainParamInitializer.VISIBLE_BIAS_KEY));
        switch (rbmConf().getVisibleUnit()) {
            case IDENTITY:
                return z;
            case BINARY:
                return Transforms.sigmoid(z);
            case GAUSSIAN: {
                Distribution normal = Nd4j.getDistributions().createNormal(z, 1.0d);
                normal.reseedRandomGenerator(this.seed);
                return normal.sample(z.shape());
            }
            case LINEAR:
                return z;
            case SOFTMAX:
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", z));
            default:
                throw new IllegalStateException("Visible unit type should either be binary or gaussian");
        }
    }

    /**
     * Returns the hidden activations for the current input.
     *
     * <p>When {@code training} is true and dropout is configured, {@code Dropout.applyDropout} is
     * called on {@code this.input}. Its return value is ignored, so it presumably modifies the
     * input in place.
     *
     * @param training whether dropout/drop-connect may be applied
     * @return hidden activations
     */
    @Override
    public INDArray activate(boolean training) {
        if (training && this.conf.getLayer().getDropOut() > 0.0d) {
            Dropout.applyDropout(this.input, this.conf.getLayer().getDropOut());
        }
        return propUp(this.input, training);
    }

    /**
     * Backpropagates through the hidden unit transform.
     *
     * <p>The steps are:
     * <ul>
     *   <li>{@code delta = epsilon * f'(z)}, written into {@code epsilon} in place and optionally
     *       scaled per row by the mask.</li>
     *   <li>The weight gradient {@code input^T * delta} and the hidden-bias gradient (column sums of
     *       {@code delta}) are written into the gradient views.</li>
     *   <li>The visible-bias view is returned unchanged.</li>
     * </ul>
     *
     * @param epsilon gradient with respect to this layer's output; modified in place
     * @return pair of (parameter gradients, {@code delta * W^T} for the layer below)
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray delta = epsilon.muli(propUpDerivative(preOutput(this.input, true)));
        if (this.maskArray != null) {
            delta.muliColumnVector(this.maskArray);
        }

        INDArray weightGrad = this.gradientViews.get("W");
        Nd4j.gemm(this.input, delta, weightGrad, true, false, 1.0d, 0.0d);
        INDArray hiddenBiasGrad = this.gradientViews.get("b");
        hiddenBiasGrad.assign(delta.sum(0));
        INDArray visibleBiasGrad = this.gradientViews.get(PretrainParamInitializer.VISIBLE_BIAS_KEY);

        DefaultGradient result = new DefaultGradient();
        result.gradientForVariable().put("W", weightGrad);
        result.gradientForVariable().put("b", hiddenBiasGrad);
        result.gradientForVariable().put(PretrainParamInitializer.VISIBLE_BIAS_KEY, visibleBiasGrad);
        return new Pair<>(result, this.params.get("W").mmul(delta.transpose()).transpose());
    }

    /**
     * Runs one deprecated contrastive-divergence update on a copy of {@code data}.
     *
     * <p>For Gaussian visible units it first records {@link #sigma}. It then applies dropout if
     * necessary and calls {@link #contrastiveDivergence()}.
     *
     * @param data training batch; copied into {@code this.input}
     */
    @Override
    @Deprecated
    public void iterate(INDArray data) {
        if (rbmConf().getVisibleUnit() == org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN) {
            this.sigma = data.var(0).divi(data.rows());
        }
        this.input = data.dup();
        applyDropOutIfNecessary(true);
        contrastiveDivergence();
    }

    /**
     * Returns a transposed copy of this layer (from {@code super.transpose()}) adjusted as follows:
     * <ul>
     *   <li>The new hidden unit is {@code RBMUtil.inverse(visibleUnit)}, or this layer's hidden unit
     *       when there is no inverse.</li>
     *   <li>The new visible unit is {@code RBMUtil.inverse(hiddenUnit)}, or this layer's visible unit
     *       when there is no inverse.</li>
     *   <li>The two bias vectors are swapped (as copies).</li>
     *   <li>{@link #sigma} and {@link #hiddenSigma} are shared with the copy.</li>
     * </ul>
     *
     * @return the transposed layer
     */
    @Override
    @Deprecated
    public Layer transpose() {
        RBM transposed = (RBM) super.transpose();
        org.deeplearning4j.nn.conf.layers.RBM config = rbmConf();

        // Read every value from this layer before writing anything to the transposed layer.
        org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit newHidden = RBMUtil.inverse(config.getVisibleUnit());
        org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit newVisible = RBMUtil.inverse(config.getHiddenUnit());
        if (newHidden == null) {
            newHidden = config.getHiddenUnit();
        }
        if (newVisible == null) {
            newVisible = config.getVisibleUnit();
        }
        transposed.rbmConf().setHiddenUnit(newHidden);
        transposed.rbmConf().setVisibleUnit(newVisible);

        INDArray hiddenBiasCopy = getParam("b").dup();
        INDArray visibleBiasCopy = getParam(PretrainParamInitializer.VISIBLE_BIAS_KEY).dup();
        transposed.setParam(PretrainParamInitializer.VISIBLE_BIAS_KEY, hiddenBiasCopy);
        transposed.setParam("b", visibleBiasCopy);

        transposed.sigma = this.sigma;
        transposed.hiddenSigma = this.hiddenSigma;
        return transposed;
    }
}
