package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for {@link ConvolutionLayer}s with a two-element kernel size.
 *
 * <p>The flattened parameter view is laid out as {@code [bias (nOut values) | weights]} when the
 * layer has a bias, and as {@code [weights]} only otherwise. Weights are exposed as a 4d array of
 * shape {@code [nOut, nIn, kernelSize[0], kernelSize[1]]} in {@code 'c'} order.
 */
public class ConvolutionParamInitializer implements ParamInitializer {
    private static final ConvolutionParamInitializer INSTANCE = new ConvolutionParamInitializer();

    /** Parameter key under which the convolution weights are stored. */
    public static final String WEIGHT_KEY = "W";
    /** Parameter key under which the convolution bias is stored. */
    public static final String BIAS_KEY = "b";

    /** Returns the shared instance; this class holds no per-instance state. */
    public static ConvolutionParamInitializer getInstance() {
        return INSTANCE;
    }

    /** Returns the parameter count of the configuration's layer; see {@link #numParams(Layer)}. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(NeuralNetConfiguration neuralNetConfiguration) {
        return numParams(neuralNetConfiguration.getLayer());
    }

    /**
     * Returns {@code nIn * nOut * kernelSize[0] * kernelSize[1]} weights, plus {@code nOut} bias
     * values when the layer has a bias.
     *
     * @param layer the layer configuration; must be a {@link ConvolutionLayer}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(Layer layer) {
        ConvolutionLayer convolutionLayer = (ConvolutionLayer) layer;
        int[] kernelSize = convolutionLayer.getKernelSize();
        long nIn = convolutionLayer.getNIn();
        long nOut = convolutionLayer.getNOut();
        // nIn/nOut are long, so the product is computed in long arithmetic.
        return (nIn * nOut * kernelSize[0] * kernelSize[1]) + (convolutionLayer.hasBias() ? nOut : 0L);
    }

    /** Returns the weight key, followed by the bias key when the layer has a bias. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> paramKeys(Layer layer) {
        return ((ConvolutionLayer) layer).hasBias() ? Arrays.asList(WEIGHT_KEY, BIAS_KEY) : weightKeys(layer);
    }

    /** Returns a single-element list containing {@link #WEIGHT_KEY}. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> weightKeys(Layer layer) {
        return Collections.singletonList(WEIGHT_KEY);
    }

    /** Returns {@link #BIAS_KEY} if the layer has a bias, otherwise an empty list. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> biasKeys(Layer layer) {
        return ((ConvolutionLayer) layer).hasBias()
                ? Collections.singletonList(BIAS_KEY)
                : Collections.<String>emptyList();
    }

    /** Returns true iff {@code str} is {@link #WEIGHT_KEY}. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isWeightParam(Layer layer, String str) {
        return WEIGHT_KEY.equals(str);
    }

    /** Returns true iff {@code str} is {@link #BIAS_KEY}. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isBiasParam(Layer layer, String str) {
        return BIAS_KEY.equals(str);
    }

    /**
     * Splits the flattened parameter view into bias and weight views, optionally initializes them,
     * and registers the parameter names as variables on the configuration.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link ConvolutionLayer}
     * @param iNDArray flattened parameter view; the bias (if any) occupies the first {@code nOut}
     *     columns and the weights the remainder
     * @param z whether to initialize the parameter values ({@code true}) or only wrap the existing
     *     values as views ({@code false})
     * @return a synchronized, insertion-ordered map from parameter key to parameter view
     * @throws IllegalArgumentException if the kernel size does not have exactly two elements
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> init(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        ConvolutionLayer convolutionLayer = (ConvolutionLayer) neuralNetConfiguration.getLayer();
        if (convolutionLayer.getKernelSize().length != 2) {
            throw new IllegalArgumentException("Filter size must be == 2");
        }
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        if (convolutionLayer.hasBias()) {
            long nOut = convolutionLayer.getNOut();
            // NOTE(review): interval(0, 0, inclusive) selects row 0, which assumes the view is a
            // single-row vector of length numParams — confirm against callers.
            INDArray biasView = iNDArray.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, nOut));
            INDArray weightView = iNDArray.get(NDArrayIndex.interval(0L, 0L, true),
                    NDArrayIndex.interval(nOut, numParams(neuralNetConfiguration)));
            params.put(BIAS_KEY, createBias(neuralNetConfiguration, biasView, z));
            params.put(WEIGHT_KEY, createWeightMatrix(neuralNetConfiguration, weightView, z));
            neuralNetConfiguration.addVariable(WEIGHT_KEY);
            neuralNetConfiguration.addVariable(BIAS_KEY);
        } else {
            // Without a bias the whole view is the weight matrix.
            params.put(WEIGHT_KEY, createWeightMatrix(neuralNetConfiguration, iNDArray, z));
            neuralNetConfiguration.addVariable(WEIGHT_KEY);
        }
        return params;
    }

    /**
     * Splits a flattened gradient view into per-parameter gradient views, using the same layout as
     * {@link #init}.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link ConvolutionLayer}
     * @param iNDArray flattened gradient view
     * @return insertion-ordered map from parameter key to gradient view; the weight gradient is
     *     reshaped to {@code [nOut, nIn, kernelSize[0], kernelSize[1]]} in {@code 'c'} order
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        ConvolutionLayer convolutionLayer = (ConvolutionLayer) neuralNetConfiguration.getLayer();
        int[] kernelSize = convolutionLayer.getKernelSize();
        long nIn = convolutionLayer.getNIn();
        long nOut = convolutionLayer.getNOut();
        Map<String, INDArray> gradients = new LinkedHashMap<String, INDArray>();
        if (convolutionLayer.hasBias()) {
            INDArray biasGradient = iNDArray.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, nOut));
            INDArray weightGradient = iNDArray
                    .get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(nOut, numParams(neuralNetConfiguration)))
                    .reshape('c', nOut, nIn, kernelSize[0], kernelSize[1]);
            gradients.put(BIAS_KEY, biasGradient);
            gradients.put(WEIGHT_KEY, weightGradient);
        } else {
            gradients.put(WEIGHT_KEY, iNDArray.reshape('c', nOut, nIn, kernelSize[0], kernelSize[1]));
        }
        return gradients;
    }

    /**
     * Returns the bias view, filling it with the layer's bias-init value when {@code z} is true.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link ConvolutionLayer}
     * @param iNDArray the bias view; modified in place when {@code z} is true
     * @param z whether to initialize the bias values
     * @return the same {@code iNDArray} instance
     */
    public INDArray createBias(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        ConvolutionLayer convolutionLayer = (ConvolutionLayer) neuralNetConfiguration.getLayer();
        if (z) {
            iNDArray.assign(Double.valueOf(convolutionLayer.getBiasInit()));
        }
        return iNDArray;
    }

    /**
     * Builds the 4d weight array of shape {@code [nOut, nIn, kernelSize[0], kernelSize[1]]} in
     * {@code 'c'} order over the given view.
     *
     * <p>When {@code z} is true the layer's weight-init function is applied, with
     * {@code fanIn = nIn * k0 * k1} and {@code fanOut = nOut * k0 * k1 / (stride[0] * stride[1])}.
     * Fan-out is computed in floating point so that it is not truncated when the stride product
     * does not divide evenly. When {@code z} is false the existing values are only reshaped.
     *
     * @param neuralNetConfiguration configuration whose layer must be a {@link ConvolutionLayer}
     * @param iNDArray flattened weight view
     * @param z whether to initialize the weight values
     * @return the weight array backed by {@code iNDArray}
     */
    protected INDArray createWeightMatrix(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        ConvolutionLayer convolutionLayer = (ConvolutionLayer) neuralNetConfiguration.getLayer();
        int[] kernelSize = convolutionLayer.getKernelSize();
        long nIn = convolutionLayer.getNIn();
        long nOut = convolutionLayer.getNOut();
        long[] weightShape = new long[] {nOut, nIn, kernelSize[0], kernelSize[1]};
        if (!z) {
            return WeightInitUtil.reshapeWeights(weightShape, iNDArray, 'c');
        }
        int[] stride = convolutionLayer.getStride();
        double fanIn = (double) (nIn * kernelSize[0] * kernelSize[1]);
        double fanOut = (double) (nOut * kernelSize[0] * kernelSize[1]) / ((double) stride[0] * stride[1]);
        return convolutionLayer.getWeightInitFn().init(fanIn, fanOut, weightShape, 'c', iNDArray);
    }
}
