package org.deeplearning4j.nn.modelimport.keras.utils;

import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.class */
/**
 * Helpers for translating Keras loss-function identifiers into DL4J
 * {@link LossFunctions.LossFunction} values during Keras model import.
 */
public class KerasLossUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasLossUtils.class);

    /**
     * Resolves a Keras loss-function name to the matching DL4J loss function.
     *
     * <p>The name is compared (exactly, via {@link String#equals}) against the loss-name strings
     * supplied by {@code kerasLayerConfiguration}. Several losses accept either of two configured
     * names, e.g. the {@code MEAN_SQUARED_ERROR} name or the {@code MSE} name.
     *
     * @param str                     the loss name as it appears in the Keras model configuration
     * @param kerasLayerConfiguration configuration providing the recognised Keras loss names
     * @return the DL4J loss function corresponding to {@code str}
     * @throws UnsupportedKerasConfigurationException if {@code str} names sparse categorical
     *         crossentropy (recognised but not mapped here) or matches no known loss name
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public static LossFunctions.LossFunction mapLossFunction(String str, KerasLayerConfiguration kerasLayerConfiguration) throws UnsupportedKerasConfigurationException {
        final KerasLayerConfiguration conf = kerasLayerConfiguration;

        // Comparisons call str.equals(...) so that a null configured name never causes an NPE;
        // the checks are evaluated in a fixed order and the first match wins.
        if (str.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) || str.equals(conf.getKERAS_LOSS_MSE())) {
            return LossFunctions.LossFunction.SQUARED_LOSS;
        }
        if (str.equals(conf.getKERAS_LOSS_MEAN_ABSOLUTE_ERROR()) || str.equals(conf.getKERAS_LOSS_MAE())) {
            return LossFunctions.LossFunction.MEAN_ABSOLUTE_ERROR;
        }
        if (str.equals(conf.getKERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR()) || str.equals(conf.getKERAS_LOSS_MAPE())) {
            return LossFunctions.LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR;
        }
        if (str.equals(conf.getKERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR()) || str.equals(conf.getKERAS_LOSS_MSLE())) {
            return LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR;
        }
        if (str.equals(conf.getKERAS_LOSS_SQUARED_HINGE())) {
            return LossFunctions.LossFunction.SQUARED_HINGE;
        }
        if (str.equals(conf.getKERAS_LOSS_HINGE())) {
            return LossFunctions.LossFunction.HINGE;
        }

        // Recognised as a Keras loss, but this importer provides no DL4J equivalent for it.
        if (str.equals(conf.getKERAS_LOSS_SPARSE_CATEGORICAL_CROSSENTROPY())) {
            throw new UnsupportedKerasConfigurationException("Loss function " + str + " not supported yet.");
        }

        if (str.equals(conf.getKERAS_LOSS_BINARY_CROSSENTROPY())) {
            return LossFunctions.LossFunction.XENT;
        }
        if (str.equals(conf.getKERAS_LOSS_CATEGORICAL_CROSSENTROPY())) {
            return LossFunctions.LossFunction.MCXENT;
        }
        if (str.equals(conf.getKERAS_LOSS_KULLBACK_LEIBLER_DIVERGENCE()) || str.equals(conf.getKERAS_LOSS_KLD())) {
            return LossFunctions.LossFunction.KL_DIVERGENCE;
        }
        if (str.equals(conf.getKERAS_LOSS_POISSON())) {
            return LossFunctions.LossFunction.POISSON;
        }
        if (str.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
            return LossFunctions.LossFunction.COSINE_PROXIMITY;
        }

        // No configured name matched.
        throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + str);
    }
}
