package org.deeplearning4j.nn.layers.ocnn;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Parameter initializer for the OCNN output layer
 * ({@link org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer}).
 *
 * <p>The layer's flattened parameter (and gradient) array is split into three consecutive
 * segments, all of which are treated as weights; there are no bias parameters:
 * <pre>
 *   [ W : hiddenSize | V : nIn * hiddenSize | R : 1 ]
 * </pre>
 * {@code W} is reshaped to {@code [1, hiddenSize]}, {@code V} to {@code [nIn, hiddenSize]} in
 * Fortran ('f') order, and {@code R} is the single final element of the array.
 */
public class OCNNParamInitializer extends DefaultParamInitializer {
    // NU_KEY / K_KEY are not referenced within this class; kept for external users.
    public static final String NU_KEY = "nu";
    public static final String K_KEY = "k";
    private static final OCNNParamInitializer INSTANCE = new OCNNParamInitializer();
    /** Key of the {@code [1, hiddenSize]} weight segment. */
    public static final String W_KEY = "w";
    /** Key of the {@code [nIn, hiddenSize]} weight segment. */
    public static final String V_KEY = "v";
    /** Key of the single trailing parameter. */
    public static final String R_KEY = "r";

    // Unmodifiable because these lists are handed out directly by paramKeys/weightKeys on a
    // shared singleton; a plain Arrays.asList view would still let callers overwrite entries
    // via List.set and corrupt every layer using this initializer.
    private static final List<String> WEIGHT_KEYS =
            Collections.unmodifiableList(Arrays.asList(W_KEY, V_KEY, R_KEY));
    private static final List<String> PARAM_KEYS =
            Collections.unmodifiableList(Arrays.asList(W_KEY, V_KEY, R_KEY));

    /**
     * Returns the shared initializer instance.
     *
     * @return the singleton {@code OCNNParamInitializer}
     */
    public static OCNNParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts the parameters of the layer held by {@code conf}.
     *
     * @param conf configuration whose layer must be an OCNN output layer
     * @return {@code hiddenSize + nIn * hiddenSize + 1}
     */
    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /**
     * Counts the parameters of an OCNN output layer: the W segment, the V segment and R.
     *
     * @param layer layer configuration; must be an OCNN output layer (cast unchecked)
     * @return {@code hiddenSize + nIn * hiddenSize + 1}
     */
    @Override
    public long numParams(Layer layer) {
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnLayer = asOcnnLayer(layer);
        long nIn = ocnnLayer.getNIn();
        int hiddenSize = ocnnLayer.getHiddenSize();
        // W (hiddenSize) + V (nIn * hiddenSize) + R (one element).
        return hiddenSize + (nIn * hiddenSize) + 1;
    }

    /** @return the unmodifiable list {@code [w, v, r]} */
    @Override
    public List<String> paramKeys(Layer layer) {
        return PARAM_KEYS;
    }

    /** @return the unmodifiable list {@code [w, v, r]}; every parameter is a weight */
    @Override
    public List<String> weightKeys(Layer layer) {
        return WEIGHT_KEYS;
    }

    /** @return an empty list; this layer has no bias parameters */
    @Override
    public List<String> biasKeys(Layer layer) {
        return Collections.emptyList();
    }

    /** @return {@code true} iff {@code key} is one of {@code w}, {@code v}, {@code r} */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEYS.contains(key);
    }

    /** @return always {@code false}; this layer has no bias parameters */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return false;
    }

    /**
     * Slices the flattened parameter array into the W, V and R segments.
     * Each segment is passed through {@link #createWeightMatrix} and registered as a
     * variable on {@code conf}, in the order W, V, R.
     *
     * @param conf             configuration whose layer must be an OCNN output layer
     * @param paramsView       flattened parameters; indexed with {@code point(0)} on the first
     *                         dimension, so presumably a {@code [1, numParams]} row vector —
     *                         TODO confirm against callers
     * @param initializeParams if {@code true}, fill the segments using the layer's weight init;
     *                         otherwise only reshape the existing values
     * @return insertion-ordered, synchronized map {@code w -> W, v -> V, r -> R}
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnLayer = asOcnnLayer(conf.getLayer());
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        long nIn = ocnnLayer.getNIn();
        int hiddenSize = ocnnLayer.getHiddenSize();
        long vLength = nIn * hiddenSize;

        INDArray wView = paramsView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0, hiddenSize))
                .reshape(1L, hiddenSize);
        INDArray vView = paramsView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(hiddenSize, hiddenSize + vLength))
                .reshape('f', nIn, hiddenSize);
        // R is the last element of the flattened array.
        INDArray rView = paramsView.get(NDArrayIndex.point(0L), NDArrayIndex.point(paramsView.length() - 1));

        params.put(W_KEY, createWeightMatrix(conf, wView, initializeParams));
        conf.addVariable(W_KEY);
        params.put(V_KEY, createWeightMatrix(conf, vView, initializeParams));
        conf.addVariable(V_KEY);
        params.put(R_KEY, createWeightMatrix(conf, rView, initializeParams));
        conf.addVariable(R_KEY);
        return params;
    }

    /**
     * Splits a flattened gradient array into per-parameter pieces.
     * It uses the same {@code [W | V | R]} layout as {@link #init}.
     *
     * @param conf          configuration whose layer must be an OCNN output layer
     * @param gradientView  flattened gradients, indexed like the parameter array
     * @return insertion-ordered, synchronized map {@code w -> W grad, v -> V grad, r -> R grad};
     *         W is reshaped in 'f' order to {@code [1, hiddenSize]}, V to {@code [nIn, hiddenSize]}
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnLayer = asOcnnLayer(conf.getLayer());
        Map<String, INDArray> gradients = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        long nIn = ocnnLayer.getNIn();
        int hiddenSize = ocnnLayer.getHiddenSize();
        long vLength = nIn * hiddenSize;

        INDArray wGrad = gradientView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0, hiddenSize))
                .reshape('f', 1, hiddenSize);
        INDArray vGrad = gradientView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(hiddenSize, hiddenSize + vLength))
                .reshape('f', nIn, hiddenSize);
        gradients.put(W_KEY, wGrad);
        gradients.put(V_KEY, vGrad);
        gradients.put(R_KEY, gradientView.get(NDArrayIndex.point(0L), NDArrayIndex.point(gradientView.length() - 1)));
        return gradients;
    }

    /**
     * Either initializes or reshapes one weight segment.
     *
     * <p>Kept {@code public} (the decompiler widened it from {@code protected}) so existing
     * callers keep compiling.
     *
     * @param conf                 configuration whose layer supplies weight init and distribution
     * @param weightView           the segment to fill or reshape; its {@code size(0)} and
     *                             {@code size(1)} are used as fan-in / fan-out
     * @param initializeParameters {@code true} to run {@link WeightInitUtil#initWeights},
     *                             {@code false} to only call {@link WeightInitUtil#reshapeWeights}
     * @return the initialized or reshaped array
     */
    @Override
    public INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParameters) {
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnLayer = asOcnnLayer(conf.getLayer());
        if (!initializeParameters) {
            return WeightInitUtil.reshapeWeights(weightView.shape(), weightView);
        }
        return WeightInitUtil.initWeights(
                weightView.size(0),
                weightView.size(1),
                weightView.shape(),
                ocnnLayer.getWeightInit(),
                Distributions.createDistribution(ocnnLayer.getDist()),
                weightView);
    }

    /** Casts a generic layer configuration to the OCNN output layer configuration (unchecked). */
    private static org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer asOcnnLayer(Layer layer) {
        return (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) layer;
    }
}
