package org.deeplearning4j.nn.params;

import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/DeconvolutionParamInitializer.class */
/**
 * Parameter initializer for {@link Deconvolution2D} layers.
 *
 * <p>Weights are viewed as a 'c'-order array of shape
 * {@code [nIn, nOut, kernelHeight, kernelWidth]}. When the layer has a bias, the flattened
 * gradient view holds the {@code nOut} bias entries first, followed by the weights.
 */
public class DeconvolutionParamInitializer extends ConvolutionParamInitializer {
    private static final DeconvolutionParamInitializer INSTANCE = new DeconvolutionParamInitializer();

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the static {@code DeconvolutionParamInitializer} instance
     */
    public static DeconvolutionParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Builds the weight array shape {@code [nIn, nOut, kernelHeight, kernelWidth]} for a layer.
     *
     * @param layer the deconvolution layer configuration
     * @return a new shape array
     */
    private static long[] weightShape(Deconvolution2D layer) {
        int[] kernel = layer.getKernelSize();
        return new long[] {layer.getNIn(), layer.getNOut(), kernel[0], kernel[1]};
    }

    /**
     * Creates the weight array for a deconvolution layer on top of {@code weightView}.
     *
     * @param conf             configuration whose layer must be a {@link Deconvolution2D}
     * @param weightView       flattened view that backs the weights
     * @param initializeParams if {@code true}, fill the view using the layer's weight-init scheme;
     *                         otherwise only reshape the existing view
     * @return the weights in 'c' order with shape {@code [nIn, nOut, kernelHeight, kernelWidth]}
     */
    @Override
    protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) {
        Deconvolution2D layer = (Deconvolution2D) conf.getLayer();
        long[] shape = weightShape(layer);
        if (!initializeParams) {
            // No initialization scheme is applied here; the existing view is only given the weight shape.
            return WeightInitUtil.reshapeWeights(shape, weightView, 'c');
        }

        Distribution dist = Distributions.createDistribution(layer.getDist());
        int[] kernel = layer.getKernelSize();
        int[] stride = layer.getStride();

        double receptiveField = (double) kernel[0] * kernel[1];
        double fanIn = layer.getNIn() * receptiveField;
        // The fan-out is divided by the stride product. Floating-point arithmetic avoids the
        // truncation that integer division would cause for strides that don't divide evenly.
        double fanOut = layer.getNOut() * receptiveField / ((double) stride[0] * stride[1]);

        return WeightInitUtil.initWeights(fanIn, fanOut, shape, layer.getWeightInit(), dist, 'c', weightView);
    }

    /**
     * Splits a flattened gradient view into per-parameter views.
     *
     * <p>With a bias, columns {@code [0, nOut)} of row 0 become the {@code "b"} gradient, and
     * columns {@code [nOut, numParams)} of row 0 become the {@code "W"} gradient. Without a bias,
     * the whole view is reshaped into the {@code "W"} gradient. Weight gradients use 'c' order
     * and shape {@code [nIn, nOut, kernelHeight, kernelWidth]}.
     *
     * @param conf         configuration whose layer must be a {@link Deconvolution2D}
     * @param gradientView flattened gradient view (assumed to be a single row when a bias is present)
     * @return an insertion-ordered map from parameter key to gradient view
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        Deconvolution2D layer = (Deconvolution2D) conf.getLayer();
        long[] shape = weightShape(layer);
        long nOut = layer.getNOut();

        Map<String, INDArray> gradients = new LinkedHashMap<>();
        if (layer.hasBias()) {
            INDArray biasGradient = gradientView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, nOut));
            INDArray weightGradient = gradientView
                    .get(NDArrayIndex.point(0L), NDArrayIndex.interval(nOut, numParams(conf)))
                    .reshape('c', shape);
            gradients.put("b", biasGradient);
            gradients.put("W", weightGradient);
        } else {
            gradients.put("W", gradientView.reshape('c', shape));
        }
        return gradients;
    }
}
