package org.deeplearning4j.nn.params;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/GravesLSTMParamInitializer.class */
/**
 * Parameter initializer for {@link GravesLSTM} layers.
 *
 * <p>All parameters of a layer are stored in a single flattened row (row 0 of the view passed in),
 * split into three consecutive segments in this order:
 * <ol>
 *   <li>{@link #INPUT_WEIGHT_KEY}: input weights, {@code nIn * 4 * nOut} values, reshaped to
 *       {@code [nIn, 4 * nOut]};</li>
 *   <li>{@link #RECURRENT_WEIGHT_KEY}: recurrent weights, {@code nOut * (4 * nOut + 3)} values,
 *       reshaped to {@code [nOut, 4 * nOut + 3]};</li>
 *   <li>{@link #BIAS_KEY}: biases, {@code 4 * nOut} values, kept as a row slice.</li>
 * </ol>
 * The flattened gradient vector uses the same layout.
 *
 * <p>NOTE(review): the 3 extra recurrent-weight columns are presumably the peephole connections of
 * the Graves LSTM variant — confirm against {@link GravesLSTM}.
 */
public class GravesLSTMParamInitializer implements ParamInitializer {
    private static final GravesLSTMParamInitializer INSTANCE = new GravesLSTMParamInitializer();

    /** Parameter-map key of the recurrent weight matrix. */
    public static final String RECURRENT_WEIGHT_KEY = "RW";
    /** Parameter-map key of the bias row. */
    public static final String BIAS_KEY = "b";
    /** Parameter-map key of the input weight matrix. */
    public static final String INPUT_WEIGHT_KEY = "W";

    /** Returns the shared instance; this class holds no per-instance state. */
    public static GravesLSTMParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the total number of parameters of the layer described by {@code conf}:
     * {@code nIn * 4 * nOut + nOut * (4 * nOut + 3) + 4 * nOut}.
     *
     * @param conf configuration whose layer must be a {@link GravesLSTM}
     * @return total parameter count
     * @throws ArithmeticException if the count does not fit in an {@code int}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public int numParams(NeuralNetConfiguration conf) {
        GravesLSTM layer = (GravesLSTM) conf.getLayer();
        int nOut = layer.getNOut();
        int weights = addChecked(inputWeightLength(layer.getNIn(), nOut), recurrentWeightLength(nOut));
        return addChecked(weights, biasLength(nOut));
    }

    /**
     * Splits {@code paramsView} into the input-weight, recurrent-weight and bias arrays and
     * registers their keys as variables on {@code conf}.
     *
     * @param conf configuration whose layer must be a {@link GravesLSTM}
     * @param paramsView flattened parameter buffer of length {@link #numParams}; the returned arrays
     *     are derived from row-0 slices of it (presumably views sharing its storage — confirm for
     *     the reshaped weight matrices)
     * @param initializeParams if {@code true}, the weight arrays are produced by
     *     {@code WeightInitUtil.initWeights} and the forget-gate slice of the bias is set to the
     *     layer's forget-gate bias init value; if {@code false}, existing values are only reshaped
     * @return synchronized, insertion-ordered map with keys W, RW, b (in that order)
     * @throws IllegalStateException if {@code paramsView} does not have exactly
     *     {@link #numParams} elements
     * @throws ArithmeticException if the parameter count does not fit in an {@code int}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        GravesLSTM layer = (GravesLSTM) conf.getLayer();
        double forgetGateBiasInit = layer.getForgetGateBiasInit();
        Distribution dist = Distributions.createDistribution(layer.getDist());
        int nOut = layer.getNOut();
        int nIn = layer.getNIn();

        conf.addVariable(INPUT_WEIGHT_KEY);
        conf.addVariable(RECURRENT_WEIGHT_KEY);
        conf.addVariable(BIAS_KEY);

        checkViewLength(conf, paramsView, "params");

        // End offsets (exclusive) of the three consecutive segments.
        int inputEnd = inputWeightLength(nIn, nOut);
        int recurrentEnd = addChecked(inputEnd, recurrentWeightLength(nOut));
        int biasEnd = addChecked(recurrentEnd, biasLength(nOut));

        INDArray inputWeightView = rowSegment(paramsView, 0, inputEnd);
        INDArray recurrentWeightView = rowSegment(paramsView, inputEnd, recurrentEnd);
        INDArray biasView = rowSegment(paramsView, recurrentEnd, biasEnd);

        int[] inputWeightShape = {nIn, 4 * nOut};
        int[] recurrentWeightShape = {nOut, 4 * nOut + 3};

        if (initializeParams) {
            // NOTE(review): the first two initWeights arguments are presumably fan-in / fan-out;
            // using nOut as fan-in for the input-weight matrix looks unusual — confirm intent.
            int fanIn = nOut;
            int fanOut = nIn + nOut;
            params.put(INPUT_WEIGHT_KEY, WeightInitUtil.initWeights(
                    fanIn, fanOut, inputWeightShape, layer.getWeightInit(), dist, inputWeightView));
            params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.initWeights(
                    fanIn, fanOut, recurrentWeightShape, layer.getWeightInit(), dist, recurrentWeightView));
            // Bias columns [nOut, 2 * nOut) are set to the forget-gate bias init value.
            // NOTE(review): other bias entries are not written here; presumably the buffer is zeroed upstream.
            biasView.put(
                    new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(nOut, 2 * nOut)},
                    Nd4j.valueArrayOf(1, nOut, forgetGateBiasInit));
        } else {
            params.put(INPUT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(inputWeightShape, inputWeightView));
            params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(recurrentWeightShape, recurrentWeightView));
        }
        params.put(BIAS_KEY, biasView);
        return params;
    }

    /**
     * Splits a flattened gradient vector into per-parameter gradients using the same layout as
     * {@link #init}. Weight gradients are reshaped in column-major ({@code 'f'}) order to
     * {@code [nIn, 4 * nOut]} and {@code [nOut, 4 * nOut + 3]}; the bias gradient is a row slice.
     *
     * @param conf configuration whose layer must be a {@link GravesLSTM}
     * @param gradientView flattened gradient buffer of length {@link #numParams}
     * @return insertion-ordered map with keys W, RW, b (not synchronized, unlike {@link #init})
     * @throws IllegalStateException if {@code gradientView} does not have exactly
     *     {@link #numParams} elements
     * @throws ArithmeticException if the parameter count does not fit in an {@code int}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        GravesLSTM layer = (GravesLSTM) conf.getLayer();
        int nOut = layer.getNOut();
        int nIn = layer.getNIn();

        checkViewLength(conf, gradientView, "gradient");

        int inputEnd = inputWeightLength(nIn, nOut);
        int recurrentEnd = addChecked(inputEnd, recurrentWeightLength(nOut));
        int biasEnd = addChecked(recurrentEnd, biasLength(nOut));

        Map<String, INDArray> gradients = new LinkedHashMap<String, INDArray>();
        gradients.put(INPUT_WEIGHT_KEY,
                rowSegment(gradientView, 0, inputEnd).reshape('f', nIn, 4 * nOut));
        gradients.put(RECURRENT_WEIGHT_KEY,
                rowSegment(gradientView, inputEnd, recurrentEnd).reshape('f', nOut, 4 * nOut + 3));
        gradients.put(BIAS_KEY, rowSegment(gradientView, recurrentEnd, biasEnd));
        return gradients;
    }

    /** Throws IllegalStateException unless {@code view} has exactly {@link #numParams} elements. */
    private void checkViewLength(NeuralNetConfiguration conf, INDArray view, String kind) {
        int expected = numParams(conf);
        if (view.length() != expected) {
            throw new IllegalStateException(
                    "Expected " + kind + " view of length " + expected + ", got length " + view.length());
        }
    }

    /** Returns the slice {@code [from, to)} of row 0 of {@code flat}. */
    private static INDArray rowSegment(INDArray flat, int from, int to) {
        return flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(from, to));
    }

    /** Number of input-weight parameters: {@code nIn * 4 * nOut}. */
    private static int inputWeightLength(int nIn, int nOut) {
        return mulChecked(mulChecked(nIn, 4), nOut);
    }

    /** Number of recurrent-weight parameters: {@code nOut * (4 * nOut + 3)}. */
    private static int recurrentWeightLength(int nOut) {
        return mulChecked(nOut, addChecked(mulChecked(4, nOut), 3));
    }

    /** Number of bias parameters: {@code 4 * nOut}. */
    private static int biasLength(int nOut) {
        return mulChecked(4, nOut);
    }

    /** Returns {@code a * b}, throwing ArithmeticException instead of silently overflowing. */
    private static int mulChecked(int a, int b) {
        long product = (long) a * b;
        if ((int) product != product) {
            throw new ArithmeticException("GravesLSTM parameter count overflows int: " + a + " * " + b);
        }
        return (int) product;
    }

    /** Returns {@code a + b}, throwing ArithmeticException instead of silently overflowing. */
    private static int addChecked(int a, int b) {
        long sum = (long) a + b;
        if ((int) sum != sum) {
            throw new ArithmeticException("GravesLSTM parameter count overflows int: " + a + " + " + b);
        }
        return (int) sum;
    }
}
