package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.ArrayList;
import java.util.Map;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Pseudo-layer that attaches a loss to an imported Keras model.
 *
 * <p>Keras does not model the loss as a layer, so the importer builds one: it wraps a DL4J
 * {@link LossLayer} configured with the loss function mapped from the Keras loss name. The layer
 * has exactly one inbound layer.
 */
public class KerasLoss extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLoss.class);

    /** Class name reported for this synthetic layer. */
    public static final String KERAS_CLASS_NAME_LOSS = "Loss";

    /**
     * Creates a loss layer and rejects unsupported Keras loss functions.
     *
     * @param layerName        name to give the created layer
     * @param inboundLayerName name of the single layer feeding this loss layer
     * @param kerasLoss        Keras loss function name
     * @throws UnsupportedKerasConfigurationException if {@code kerasLoss} cannot be mapped
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss)
            throws UnsupportedKerasConfigurationException {
        this(layerName, inboundLayerName, kerasLoss, true);
    }

    /**
     * Creates a loss layer.
     *
     * @param layerName             name to give the created layer
     * @param inboundLayerName      name of the single layer feeding this loss layer
     * @param kerasLoss             Keras loss function name
     * @param enforceTrainingConfig when {@code true}, an unsupported loss name is an error; when
     *                              {@code false}, it is logged and replaced by
     *                              {@link LossFunctions.LossFunction#SQUARED_LOSS}
     * @throws UnsupportedKerasConfigurationException if {@code kerasLoss} cannot be mapped and
     *                                                {@code enforceTrainingConfig} is {@code true}
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss,
                     boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException {
        this.className = KERAS_CLASS_NAME_LOSS;
        this.layerName = layerName;
        this.inputShape = null;
        this.dimOrder = KerasLayer.DimOrder.NONE;
        this.inboundLayerNames = new ArrayList<>();
        this.inboundLayerNames.add(inboundLayerName);

        LossFunctions.LossFunction lossFunction;
        try {
            lossFunction = mapLossFunction(kerasLoss);
        } catch (UnsupportedKerasConfigurationException e) {
            if (enforceTrainingConfig) {
                throw e;
            }
            // Lenient import: keep the model loadable by substituting squared loss.
            log.warn("Unsupported Keras loss function {}. Replacing with MSE.", kerasLoss);
            lossFunction = LossFunctions.LossFunction.SQUARED_LOSS;
        }
        this.layer = ((LossLayer.Builder) new LossLayer.Builder(lossFunction).name(layerName)).build();
    }

    /**
     * Private map-based constructor. Its body is empty and it is not invoked from this class.
     * NOTE(review): looks like decompiler residue; it performs no initialization.
     */
    private KerasLoss(Map<String, Object> map) {
    }

    /**
     * Private map-based constructor. Its body is empty and it is not invoked from this class.
     * NOTE(review): looks like decompiler residue; it performs no initialization.
     */
    private KerasLoss(Map<String, Object> map, boolean enforceTrainingConfig) {
    }

    /**
     * Returns the wrapped DL4J loss layer.
     *
     * @return the {@link LossLayer} built by the constructor
     */
    public LossLayer getLossLayer() {
        return (LossLayer) this.layer;
    }

    /**
     * Computes the output type of this layer from its single input type.
     *
     * @param inputType exactly one input type
     * @return the output type reported by the wrapped {@link LossLayer}
     * @throws InvalidKerasConfigurationException if zero or more than one input type is supplied
     */
    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException {
        // Require exactly one input; an empty array would otherwise fail on inputType[0].
        int received = inputType == null ? 0 : inputType.length;
        if (received != 1) {
            throw new InvalidKerasConfigurationException(
                    "Keras Loss layer accepts only one input (received " + received + ")");
        }
        return getLossLayer().getOutputType(-1, inputType[0]);
    }
}
