package org.deeplearning4j.nn.params;

import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/PretrainParamInitializer.class */
/**
 * Parameter initializer for layers whose configuration is a {@link BasePretrainNetwork}.
 *
 * <p>Extends {@link DefaultParamInitializer} with one additional parameter, stored under
 * {@link #VISIBLE_BIAS_KEY}, that holds {@code nIn} values. Inside the flattened parameter
 * view this parameter occupies the {@code nIn} entries of row 0 that start at offset
 * {@code nIn * nOut + nOut} (presumably right after the weights and biases laid out by the
 * superclass - confirm against {@code DefaultParamInitializer}).
 */
public class PretrainParamInitializer extends DefaultParamInitializer {
    /** Map key under which the visible bias parameter (and its gradient) is stored. */
    public static final String VISIBLE_BIAS_KEY = "vb";

    private static final PretrainParamInitializer INSTANCE = new PretrainParamInitializer();

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the shared {@code PretrainParamInitializer}
     */
    public static PretrainParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts the parameters of the layer: the superclass count plus {@code nIn} entries for
     * the visible bias.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link BasePretrainNetwork}
     * @return total number of parameters, visible bias included
     */
    @Override
    public int numParams(NeuralNetConfiguration neuralNetConfiguration) {
        int inheritedCount = super.numParams(neuralNetConfiguration);
        BasePretrainNetwork layerConf = (BasePretrainNetwork) neuralNetConfiguration.getLayer();
        return inheritedCount + layerConf.getNIn();
    }

    /**
     * Builds the parameter map of the superclass and adds the visible bias to it.
     *
     * <p>The visible bias is obtained with {@code get(...)} on the supplied array, using row 0
     * and the column interval {@code [nIn * nOut + nOut, nIn * nOut + nOut + nIn)}. The key is
     * also registered as a variable on the configuration.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link BasePretrainNetwork}
     * @param iNDArray               flattened parameter array the visible bias is sliced from
     * @param z                      forwarded to the superclass and to
     *                               {@link #createVisibleBias}; when {@code true} the visible
     *                               bias is filled with its configured initial value
     * @return the superclass parameter map with the {@link #VISIBLE_BIAS_KEY} entry added
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        Map<String, INDArray> params = super.init(neuralNetConfiguration, iNDArray, z);

        BasePretrainNetwork layerConf = (BasePretrainNetwork) neuralNetConfiguration.getLayer();
        int nIn = layerConf.getNIn();
        int nOut = layerConf.getNOut();

        // The visible bias starts after nIn*nOut + nOut entries and spans nIn entries.
        int visibleBiasStart = nIn * nOut + nOut;
        int visibleBiasEnd = visibleBiasStart + nIn;
        INDArrayIndex[] visibleBiasRegion = {
            NDArrayIndex.point(0),
            NDArrayIndex.interval(visibleBiasStart, visibleBiasEnd)
        };
        INDArray visibleBiasView = iNDArray.get(visibleBiasRegion);

        INDArray visibleBias = createVisibleBias(neuralNetConfiguration, visibleBiasView, z);
        params.put(VISIBLE_BIAS_KEY, visibleBias);
        neuralNetConfiguration.addVariable(VISIBLE_BIAS_KEY);
        return params;
    }

    /**
     * Optionally fills the visible bias array with the initial value taken from the layer
     * configuration and returns that same array.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link BasePretrainNetwork}
     * @param iNDArray               array that holds the visible bias; modified in place when
     *                               {@code z} is {@code true}
     * @param z                      whether to assign the configured initial value
     * @return {@code iNDArray} itself
     */
    protected INDArray createVisibleBias(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        // The cast is performed before the flag check, so a mismatching layer type fails
        // regardless of the flag.
        BasePretrainNetwork layerConf = (BasePretrainNetwork) neuralNetConfiguration.getLayer();
        if (!z) {
            return iNDArray;
        }
        INDArray initialValues = Nd4j.valueArrayOf(layerConf.getNIn(), layerConf.getVisibleBiasInit());
        iNDArray.assign(initialValues);
        return iNDArray;
    }

    /**
     * Builds the gradient map of the superclass and adds an entry for the visible bias.
     *
     * <p>NOTE(review): the visible bias entry is a newly allocated array of {@code nIn} zeros,
     * not a slice of {@code iNDArray}, unlike the parameter in {@link #init}. Writes to it
     * will therefore not show up in the flattened gradient array - confirm whether this is an
     * intentional placeholder.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link FeedForwardLayer}
     * @param iNDArray               flattened gradient array, passed on to the superclass
     * @return the superclass gradient map with the {@link #VISIBLE_BIAS_KEY} entry added
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        Map<String, INDArray> gradients = super.getGradientsFromFlattened(neuralNetConfiguration, iNDArray);
        FeedForwardLayer layerConf = (FeedForwardLayer) neuralNetConfiguration.getLayer();
        INDArray visibleBiasGradient = Nd4j.valueArrayOf(layerConf.getNIn(), 0.0d);
        gradients.put(VISIBLE_BIAS_KEY, visibleBiasGradient);
        return gradients;
    }
}
