package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/GravesLSTMParamInitializer.class */
/**
 * {@link ParamInitializer} for {@link GravesLSTM} layers.
 *
 * <p>The layer's parameters live in one flattened view that is split, in this order, into:
 * <ol>
 *   <li>{@link #INPUT_WEIGHT_KEY} - {@code nIn * 4 * nOut} values, shaped {@code [nIn, 4 * nOut]};</li>
 *   <li>{@link #RECURRENT_WEIGHT_KEY} - {@code nOut * (4 * nOut + 3)} values, shaped
 *       {@code [nOut, 4 * nOut + 3]};</li>
 *   <li>{@link #BIAS_KEY} - {@code 4 * nOut} values, left as a slice of the flattened view.</li>
 * </ol>
 *
 * <p>The class holds no mutable state of its own; a shared instance is available via
 * {@link #getInstance()}.
 */
public class GravesLSTMParamInitializer implements ParamInitializer {
    private static final GravesLSTMParamInitializer INSTANCE = new GravesLSTMParamInitializer();

    /** Map key of the recurrent weight matrix. */
    public static final String RECURRENT_WEIGHT_KEY = "RW";

    /** Map key of the bias row. */
    public static final String BIAS_KEY = "b";

    /** Map key of the input weight matrix. */
    public static final String INPUT_WEIGHT_KEY = "W";

    /**
     * Number of {@code nOut}-wide column blocks in the input weights, recurrent weights and bias.
     * The forget-gate bias occupies the second of these blocks (see {@link #init}).
     */
    private static final long GATE_BLOCKS = 4;

    /**
     * Extra columns appended to the recurrent weight matrix beyond the {@code 4 * nOut} block.
     * NOTE(review): presumably the peephole connection weights - confirm against the layer
     * implementation.
     */
    private static final long EXTRA_RECURRENT_COLUMNS = 3;

    /**
     * Returns the shared initializer instance.
     *
     * @return the singleton {@code GravesLSTMParamInitializer}
     */
    public static GravesLSTMParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts the parameters of the layer held by the given configuration.
     *
     * @param neuralNetConfiguration configuration whose layer is a {@link GravesLSTM}
     * @return total number of parameters (input weights + recurrent weights + biases)
     * @throws ClassCastException if the configured layer is not a {@link GravesLSTM}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(NeuralNetConfiguration neuralNetConfiguration) {
        return numParams(neuralNetConfiguration.getLayer());
    }

    /**
     * Counts the parameters of the given layer.
     *
     * @param layer a {@link GravesLSTM} layer configuration
     * @return {@code nIn * 4 * nOut + nOut * (4 * nOut + 3) + 4 * nOut}
     * @throws ClassCastException if {@code layer} is not a {@link GravesLSTM}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(Layer layer) {
        GravesLSTM lstm = (GravesLSTM) layer;
        long nOut = lstm.getNOut();
        long nIn = lstm.getNIn();
        return inputWeightCount(nIn, nOut) + recurrentWeightCount(nOut) + biasCount(nOut);
    }

    /**
     * Lists every parameter key of the layer, in flattened-view order.
     *
     * @param layer ignored
     * @return the input weight, recurrent weight and bias keys
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> paramKeys(Layer layer) {
        return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY);
    }

    /**
     * Lists the keys that refer to weight parameters.
     *
     * @param layer ignored
     * @return the input weight and recurrent weight keys
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> weightKeys(Layer layer) {
        return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY);
    }

    /**
     * Lists the keys that refer to bias parameters.
     *
     * @param layer ignored
     * @return a single-element list containing the bias key
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> biasKeys(Layer layer) {
        return Collections.singletonList(BIAS_KEY);
    }

    /**
     * Tells whether the key names a weight parameter.
     *
     * @param layer ignored
     * @param str parameter key to test; may be {@code null}
     * @return {@code true} for the recurrent or input weight key, {@code false} otherwise
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isWeightParam(Layer layer, String str) {
        return RECURRENT_WEIGHT_KEY.equals(str) || INPUT_WEIGHT_KEY.equals(str);
    }

    /**
     * Tells whether the key names a bias parameter.
     *
     * @param layer ignored
     * @param str parameter key to test; may be {@code null}
     * @return {@code true} only for the bias key
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isBiasParam(Layer layer, String str) {
        return BIAS_KEY.equals(str);
    }

    /**
     * Splits the flattened parameter view into the layer's three parameter arrays and, when
     * requested, fills them with initial values.
     *
     * <p>The three parameter keys are registered on the configuration via {@code addVariable}
     * before the view length is validated.
     *
     * @param neuralNetConfiguration configuration whose layer is a {@link GravesLSTM}
     * @param iNDArray flattened parameter view; it is indexed as row 0 of a 2-d array and its
     *     length must equal {@link #numParams(NeuralNetConfiguration)}
     * @param z {@code true} to initialize the weights and the forget-gate bias; {@code false} to
     *     only reshape the existing contents of the view
     * @return a synchronized, insertion-ordered map holding the input weights, recurrent weights
     *     and biases under their respective keys
     * @throws IllegalStateException if the view length does not match the expected parameter count
     * @throws ClassCastException if the configured layer is not a {@link GravesLSTM}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> init(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        GravesLSTM lstm = (GravesLSTM) neuralNetConfiguration.getLayer();
        double forgetGateBiasInit = lstm.getForgetGateBiasInit();
        Distribution inputDist = Distributions.createDistribution(lstm.getDist());
        long nOut = lstm.getNOut();
        long nIn = lstm.getNIn();

        neuralNetConfiguration.addVariable(INPUT_WEIGHT_KEY);
        neuralNetConfiguration.addVariable(RECURRENT_WEIGHT_KEY);
        neuralNetConfiguration.addVariable(BIAS_KEY);

        requireLength(iNDArray, numParams(neuralNetConfiguration), "params");

        long inputEnd = inputWeightCount(nIn, nOut);
        long recurrentEnd = inputEnd + recurrentWeightCount(nOut);
        long biasEnd = recurrentEnd + biasCount(nOut);
        INDArray inputWeightView = segment(iNDArray, 0L, inputEnd);
        INDArray recurrentWeightView = segment(iNDArray, inputEnd, recurrentEnd);
        INDArray biasView = segment(iNDArray, recurrentEnd, biasEnd);

        long[] inputShape = {nIn, GATE_BLOCKS * nOut};
        long[] recurrentShape = {nOut, recurrentColumns(nOut)};

        if (!z) {
            // Values already live in the view; only give the weight slices their 2-d shapes.
            params.put(INPUT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(inputShape, inputWeightView));
            params.put(RECURRENT_WEIGHT_KEY, WeightInitUtil.reshapeWeights(recurrentShape, recurrentWeightView));
            params.put(BIAS_KEY, biasView);
            return params;
        }

        // First two initWeights arguments: fanIn = nOut, fanOut = nIn + nOut, per the
        // WeightInitUtil.initWeights(fanIn, fanOut, shape, scheme, dist, view) signature.
        long fanIn = nOut;
        long fanOut = nIn + nOut;

        // Recurrent weights use their own scheme (and distribution) when one is configured,
        // otherwise they fall back to the layer-wide scheme and distribution.
        WeightInit recurrentInit = lstm.getWeightInitRecurrent();
        Distribution recurrentDist = inputDist;
        if (recurrentInit == null) {
            recurrentInit = lstm.getWeightInit();
        } else if (lstm.getDistRecurrent() != null) {
            recurrentDist = Distributions.createDistribution(lstm.getDistRecurrent());
        }

        params.put(INPUT_WEIGHT_KEY,
                WeightInitUtil.initWeights(fanIn, fanOut, inputShape, lstm.getWeightInit(), inputDist, inputWeightView));
        params.put(RECURRENT_WEIGHT_KEY,
                WeightInitUtil.initWeights(fanIn, fanOut, recurrentShape, recurrentInit, recurrentDist, recurrentWeightView));

        // Write the configured forget-gate bias into columns [nOut, 2 * nOut) of the bias row;
        // the remaining bias entries are left as they are in the view.
        biasView.put(
                new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(nOut, 2 * nOut)},
                Nd4j.valueArrayOf(new long[]{1, nOut}, forgetGateBiasInit));
        params.put(BIAS_KEY, biasView);
        return params;
    }

    /**
     * Splits a flattened gradient view into per-parameter gradient arrays.
     *
     * <p>The weight gradients are reshaped in {@code 'f'} order to the shapes of the corresponding
     * parameters; the bias gradient is returned as a slice of the flattened view.
     *
     * @param neuralNetConfiguration configuration whose layer is a {@link GravesLSTM}
     * @param iNDArray flattened gradient view; it is indexed as row 0 of a 2-d array and its
     *     length must equal {@link #numParams(NeuralNetConfiguration)}
     * @return an insertion-ordered map holding the input weight, recurrent weight and bias gradients
     * @throws IllegalStateException if the view length does not match the expected parameter count
     * @throws ClassCastException if the configured layer is not a {@link GravesLSTM}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        GravesLSTM lstm = (GravesLSTM) neuralNetConfiguration.getLayer();
        long nOut = lstm.getNOut();
        long nIn = lstm.getNIn();

        requireLength(iNDArray, numParams(neuralNetConfiguration), "gradient");

        long inputEnd = inputWeightCount(nIn, nOut);
        long recurrentEnd = inputEnd + recurrentWeightCount(nOut);
        long biasEnd = recurrentEnd + biasCount(nOut);

        INDArray inputWeightGradient = segment(iNDArray, 0L, inputEnd)
                .reshape('f', nIn, GATE_BLOCKS * nOut);
        INDArray recurrentWeightGradient = segment(iNDArray, inputEnd, recurrentEnd)
                .reshape('f', nOut, recurrentColumns(nOut));
        INDArray biasGradient = segment(iNDArray, recurrentEnd, biasEnd);

        Map<String, INDArray> gradients = new LinkedHashMap<String, INDArray>();
        gradients.put(INPUT_WEIGHT_KEY, inputWeightGradient);
        gradients.put(RECURRENT_WEIGHT_KEY, recurrentWeightGradient);
        gradients.put(BIAS_KEY, biasGradient);
        return gradients;
    }

    /** Number of input weight values: {@code nIn * 4 * nOut}. */
    private static long inputWeightCount(long nIn, long nOut) {
        return nIn * GATE_BLOCKS * nOut;
    }

    /** Number of columns of the recurrent weight matrix: {@code 4 * nOut + 3}. */
    private static long recurrentColumns(long nOut) {
        return (GATE_BLOCKS * nOut) + EXTRA_RECURRENT_COLUMNS;
    }

    /** Number of recurrent weight values: {@code nOut * (4 * nOut + 3)}. */
    private static long recurrentWeightCount(long nOut) {
        return nOut * recurrentColumns(nOut);
    }

    /** Number of bias values: {@code 4 * nOut}. */
    private static long biasCount(long nOut) {
        return GATE_BLOCKS * nOut;
    }

    /** Returns columns {@code [from, to)} of row 0 of the flattened view. */
    private static INDArray segment(INDArray flatView, long from, long to) {
        return flatView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(from, to));
    }

    /** Throws {@link IllegalStateException} unless the view holds exactly {@code expected} values. */
    private static void requireLength(INDArray flatView, long expected, String viewName) {
        if (flatView.length() != expected) {
            throw new IllegalStateException("Expected " + viewName + " view of length " + expected + ", got length " + flatView.length());
        }
    }
}
