package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.params.GravesLSTMParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/KerasLstm.class */
/**
 * Imports a Keras LSTM layer as a DeepLearning4J {@link GravesLSTM} layer.
 *
 * <p>The layer configuration supplies the gate (inner) activation, the forget-gate bias
 * initialization, the weight initialization and the {@code unroll} flag. {@link #setWeights}
 * converts the twelve per-gate Keras parameter arrays ({@code W_*}, {@code U_*}, {@code b_*})
 * into three column-concatenated DL4J parameter arrays.
 */
public class KerasLstm extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLstm.class);
    public static final String LAYER_FIELD_INNER_INIT = "inner_init";
    public static final String LAYER_FIELD_INNER_ACTIVATION = "inner_activation";
    public static final String LAYER_FIELD_FORGET_BIAS_INIT = "forget_bias_init";
    public static final String LAYER_FIELD_DROPOUT_U = "dropout_U";
    public static final String LAYER_FIELD_UNROLL = "unroll";
    public static final String LSTM_FORGET_BIAS_INIT_ZERO = "zero";
    public static final String LSTM_FORGET_BIAS_INIT_ONE = "one";
    public static final int NUM_TRAINABLE_PARAMS = 12;
    public static final String KERAS_PARAM_NAME_W_C = "W_c";
    public static final String KERAS_PARAM_NAME_W_F = "W_f";
    public static final String KERAS_PARAM_NAME_W_I = "W_i";
    public static final String KERAS_PARAM_NAME_W_O = "W_o";
    public static final String KERAS_PARAM_NAME_U_C = "U_c";
    public static final String KERAS_PARAM_NAME_U_F = "U_f";
    public static final String KERAS_PARAM_NAME_U_I = "U_i";
    public static final String KERAS_PARAM_NAME_U_O = "U_o";
    public static final String KERAS_PARAM_NAME_B_C = "b_c";
    public static final String KERAS_PARAM_NAME_B_F = "b_f";
    public static final String KERAS_PARAM_NAME_B_I = "b_i";
    public static final String KERAS_PARAM_NAME_B_O = "b_o";
    public static final int NUM_WEIGHTS_IN_KERAS_LSTM = 12;

    /** Every Keras parameter name consumed by {@link #setWeights}. */
    private static final String[] KERAS_PARAM_NAMES = {
        KERAS_PARAM_NAME_W_C, KERAS_PARAM_NAME_W_F, KERAS_PARAM_NAME_W_I, KERAS_PARAM_NAME_W_O,
        KERAS_PARAM_NAME_U_C, KERAS_PARAM_NAME_U_F, KERAS_PARAM_NAME_U_I, KERAS_PARAM_NAME_U_O,
        KERAS_PARAM_NAME_B_C, KERAS_PARAM_NAME_B_F, KERAS_PARAM_NAME_B_I, KERAS_PARAM_NAME_B_O
    };

    /**
     * Extra zero columns appended to the DL4J recurrent weight matrix that have no Keras
     * counterpart (presumably GravesLSTM peephole weights — TODO confirm against
     * GravesLSTMParamInitializer).
     */
    private static final int DL4J_EXTRA_RECURRENT_COLUMNS = 3;

    /** Value of the Keras {@code unroll} field read from the layer config. */
    protected boolean unroll;

    /**
     * Builds the layer, enforcing training-related configuration.
     *
     * @param layerConfig Keras layer configuration map
     * @throws InvalidKerasConfigurationException if a required config field is missing
     * @throws UnsupportedKerasConfigurationException if the config uses an unsupported feature
     */
    public KerasLstm(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    /**
     * Builds the layer.
     *
     * @param layerConfig Keras layer configuration map
     * @param enforceTrainingConfig when {@code true}, unsupported settings throw; when
     *     {@code false}, some of them only log a warning and fall back to a default
     * @throws InvalidKerasConfigurationException if a required config field is missing
     * @throws UnsupportedKerasConfigurationException if the config uses an unsupported feature
     */
    public KerasLstm(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);

        WeightInit weightInit = getWeightInitFromConfig(layerConfig, enforceTrainingConfig);
        WeightInit recurrentWeightInit = getRecurrentWeightInitFromConfig(layerConfig, enforceTrainingConfig);
        // GravesLSTM takes a single weight initializer for both input and recurrent weights.
        if (weightInit != recurrentWeightInit) {
            if (enforceTrainingConfig) {
                throw new UnsupportedKerasConfigurationException(
                        "Specifying different initialization for recurrent weights not supported.");
            }
            log.warn("Specifying different initialization for recurrent weights not supported.");
        }

        // Called only for validation: throws if nonzero recurrent dropout is requested.
        getRecurrentDropout(layerConfig);
        this.unroll = getUnrollRecurrentLayer(layerConfig);

        this.layer = new GravesLSTM.Builder()
                .gateActivationFunction(getGateActivationFromConfig(layerConfig))
                .forgetGateBiasInit(getForgetBiasInitFromConfig(layerConfig, enforceTrainingConfig))
                .name(this.layerName)
                .nOut(getNOutFromConfig(layerConfig))
                .dropOut(this.dropout)
                .activation(getActivationFromConfig(layerConfig))
                .weightInit(weightInit)
                .biasInit(0.0)
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization)
                .build();
    }

    /** Returns the underlying DL4J layer built by the constructor. */
    public GravesLSTM getGravesLSTMLayer() {
        return (GravesLSTM) this.layer;
    }

    /**
     * Computes the output type of this layer.
     *
     * @param inputType exactly one input type
     * @return output type reported by the underlying GravesLSTM layer
     * @throws InvalidKerasConfigurationException if not exactly one input type is given
     */
    @Override
    public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length != 1) {
            throw new InvalidKerasConfigurationException(
                    "Keras LSTM layer accepts only one input (received " + inputType.length + ")");
        }
        return getGravesLSTMLayer().getOutputType(-1, inputType[0]);
    }

    /** Returns the number of trainable Keras parameter arrays for this layer. */
    @Override
    public int getNumParams() {
        return NUM_TRAINABLE_PARAMS;
    }

    /**
     * Converts Keras per-gate LSTM parameters into DL4J GravesLSTM parameters.
     *
     * <p>For each of the input weights ({@code W_*}), recurrent weights ({@code U_*}) and biases
     * ({@code b_*}), the four gate arrays are concatenated column-wise in the order c, f, o, i.
     * The results are stored under {@code "W"}, {@link GravesLSTMParamInitializer#RECURRENT_WEIGHT_KEY}
     * and {@code "b"}. The recurrent matrix gets {@value #DL4J_EXTRA_RECURRENT_COLUMNS} trailing
     * zero columns. The supplied map is not modified.
     *
     * @param kerasWeights Keras parameter arrays keyed by Keras parameter name
     * @throws InvalidKerasConfigurationException if any of the twelve expected parameters is missing
     */
    @Override
    public void setWeights(Map<String, INDArray> kerasWeights) throws InvalidKerasConfigurationException {
        this.weights = new HashMap<>();

        INDArray wC = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_W_C);
        INDArray wF = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_W_F);
        INDArray wO = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_W_O);
        INDArray wI = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_W_I);
        this.weights.put("W", concatColumns(0, wC, wF, wO, wI));

        INDArray uC = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_U_C);
        INDArray uF = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_U_F);
        INDArray uO = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_U_O);
        INDArray uI = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_U_I);
        this.weights.put(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY,
                concatColumns(DL4J_EXTRA_RECURRENT_COLUMNS, uC, uF, uO, uI));

        INDArray bC = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_B_C);
        INDArray bF = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_B_F);
        INDArray bO = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_B_O);
        INDArray bI = getRequiredParam(kerasWeights, KERAS_PARAM_NAME_B_I);
        this.weights.put("b", concatColumns(0, bC, bF, bO, bI));

        if (kerasWeights.size() > NUM_WEIGHTS_IN_KERAS_LSTM) {
            // Copy first: keySet() is a live view, and removing from it would delete the
            // caller's entries.
            Set<String> unknownParams = new HashSet<>(kerasWeights.keySet());
            for (String paramName : KERAS_PARAM_NAMES) {
                unknownParams.remove(paramName);
            }
            String unknown = unknownParams.toString();
            log.warn("Attemping to set weights for unknown parameters: " + unknown.substring(1, unknown.length() - 1));
        }
    }

    /** Fetches a required Keras parameter, failing with a descriptive error if it is absent. */
    private static INDArray getRequiredParam(Map<String, INDArray> kerasWeights, String paramName)
            throws InvalidKerasConfigurationException {
        if (!kerasWeights.containsKey(paramName)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter " + paramName);
        }
        return kerasWeights.get(paramName);
    }

    /**
     * Concatenates arrays column-wise into a new zero-initialised matrix.
     *
     * <p>The result has the row count of the first array and {@code extraColumns} trailing zero
     * columns. All blocks are assumed to have the same number of rows.
     */
    private static INDArray concatColumns(int extraColumns, INDArray... blocks) {
        int rows = blocks[0].rows();
        int totalColumns = extraColumns;
        for (INDArray block : blocks) {
            totalColumns += block.columns();
        }
        INDArray result = Nd4j.zeros(rows, totalColumns);
        int offset = 0;
        for (INDArray block : blocks) {
            result.put(new INDArrayIndex[] {
                    NDArrayIndex.interval(0, rows),
                    NDArrayIndex.interval(offset, offset + block.columns())}, block);
            offset += block.columns();
        }
        return result;
    }

    /** Returns the Keras {@code unroll} flag read at construction time. */
    public boolean getUnroll() {
        return this.unroll;
    }

    /**
     * Reads the {@code unroll} flag from the inner layer config.
     *
     * @throws InvalidKerasConfigurationException if the field is missing
     */
    public static boolean getUnrollRecurrentLayer(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_UNROLL)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing unroll field");
        }
        return (Boolean) innerConfig.get(LAYER_FIELD_UNROLL);
    }

    /**
     * Maps the Keras {@code inner_init} (recurrent weight initializer) to a DL4J {@link WeightInit}.
     *
     * @param enforceTrainingConfig if {@code false}, an unknown initializer falls back to XAVIER
     *     with a warning instead of throwing
     * @throws InvalidKerasConfigurationException if the field is missing
     * @throws UnsupportedKerasConfigurationException if the initializer is unknown and enforcement is on
     */
    public static WeightInit getRecurrentWeightInitFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_INNER_INIT)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing inner_init field");
        }
        String kerasInit = (String) innerConfig.get(LAYER_FIELD_INNER_INIT);
        try {
            return mapWeightInitialization(kerasInit);
        } catch (UnsupportedKerasConfigurationException e) {
            if (enforceTrainingConfig) {
                throw e;
            }
            log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
            return WeightInit.XAVIER;
        }
    }

    /**
     * Returns the recurrent retain probability, {@code 1 - dropout_U}, or 1.0 if the field is absent.
     *
     * @throws UnsupportedKerasConfigurationException if {@code dropout_U} is greater than zero
     */
    public static double getRecurrentDropout(Map<String, Object> layerConfig)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
        double retainProbability = 1.0;
        if (innerConfig.containsKey(LAYER_FIELD_DROPOUT_U)) {
            // Read through Number: JSON parsers may yield Integer (e.g. for 0) rather than Double.
            retainProbability = 1.0 - ((Number) innerConfig.get(LAYER_FIELD_DROPOUT_U)).doubleValue();
        }
        if (retainProbability < 1.0) {
            throw new UnsupportedKerasConfigurationException("Dropout > 0 on LSTM recurrent connections not supported.");
        }
        return retainProbability;
    }

    /**
     * Maps the Keras {@code inner_activation} (gate activation) to a DL4J activation.
     *
     * @throws InvalidKerasConfigurationException if the field is missing
     */
    public static IActivation getGateActivationFromConfig(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_INNER_ACTIVATION)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing inner_activation field");
        }
        return mapActivation((String) innerConfig.get(LAYER_FIELD_INNER_ACTIVATION));
    }

    /**
     * Maps the Keras {@code forget_bias_init} value to a constant forget-gate bias.
     *
     * @param enforceTrainingConfig if {@code false}, an unsupported value falls back to 1.0 with a
     *     warning instead of throwing
     * @return 0.0 for {@code "zero"}, 1.0 for {@code "one"}
     * @throws InvalidKerasConfigurationException if the field is missing
     * @throws UnsupportedKerasConfigurationException if the value is unsupported and enforcement is on
     */
    public static double getForgetBiasInitFromConfig(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_FORGET_BIAS_INIT)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing forget_bias_init field");
        }
        String kerasBiasInit = (String) innerConfig.get(LAYER_FIELD_FORGET_BIAS_INIT);
        switch (kerasBiasInit) {
            case LSTM_FORGET_BIAS_INIT_ZERO:
                return 0.0;
            case LSTM_FORGET_BIAS_INIT_ONE:
                return 1.0;
            default:
                if (enforceTrainingConfig) {
                    throw new UnsupportedKerasConfigurationException(
                            "Unsupported LSTM forget gate bias initialization: " + kerasBiasInit);
                }
                log.warn("Unsupported LSTM forget gate bias initialization: " + kerasBiasInit + " (using 1 instead)");
                return 1.0;
        }
    }
}
