package org.nd4j.autodiff.samediff.ops;

import lombok.NonNull;
import org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

/* loaded from: input_file:org/nd4j/autodiff/samediff/ops/SDLoss.class */
/**
 * SameDiff loss-function operations.
 *
 * <p>Most loss methods here follow the same pattern:
 * <ol>
 *   <li>null-check the required arguments;</li>
 *   <li>validate that predictions are floating point and labels are numerical;</li>
 *   <li>substitute a scalar weight of 1.0 (with the predictions' data type) when no weights are
 *   given;</li>
 *   <li>create the op through {@code f()};</li>
 *   <li>pass the result and the requested name to {@code updateVariableNameAndReference};</li>
 *   <li>mark the returned variable as a loss via {@link SDVariable#markAsLoss()}.</li>
 * </ol>
 *
 * <p>Overloads that omit {@code lossReduce} use {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
 * The {@code name} argument may be {@code null} (the no-name overloads pass {@code null});
 * NOTE(review): presumably a null name means an auto-generated one — confirm in SDOps.
 *
 * <p>NOTE(review): {@code SameDiffOutputLayer.LABELS_KEY} is used only as a label string in
 * validation messages. This looks like a decompiler constant substitution that makes nd4j depend on
 * deeplearning4j — consider a local constant instead.
 */
public class SDLoss extends SDOps {
    public SDLoss(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Absolute difference loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return absoluteDifference(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
    }

    /**
     * Absolute difference loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                         SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("absolute difference loss", "predictions", predictions);
        SDValidation.validateNumerical("absolute difference loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossAbsoluteDifference(label, predictions, w, lossReduce), name);
    }

    /**
     * Absolute difference loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                         @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return absoluteDifference(name, label, predictions, null, lossReduce);
    }

    /**
     * Cosine distance loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param dimension   forwarded to {@code f().lossCosineDistance}; presumably the dimension along
     *                    which the distance is computed — confirm
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, int dimension) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return cosineDistance(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension);
    }

    /**
     * Cosine distance loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @param dimension   forwarded to {@code f().lossCosineDistance}; presumably the dimension along
     *                    which the distance is computed — confirm
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                     SDVariable weights, @NonNull LossReduce lossReduce, int dimension) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("cosine distance loss", "predictions", predictions);
        SDValidation.validateNumerical("cosine distance loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossCosineDistance(label, predictions, w, lossReduce, dimension), name);
    }

    /**
     * Cosine distance loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @param dimension   forwarded to {@code f().lossCosineDistance}
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name object argument is null
     */
    public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                     @NonNull LossReduce lossReduce, int dimension) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return cosineDistance(name, label, predictions, null, lossReduce, dimension);
    }

    /**
     * Hinge loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return hingeLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
    }

    /**
     * Hinge loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("hinge loss", "predictions", predictions);
        SDValidation.validateNumerical("hinge loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossHinge(label, predictions, w, lossReduce), name);
    }

    /**
     * Hinge loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return hingeLoss(name, label, predictions, null, lossReduce);
    }

    /**
     * Huber loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param delta       Huber delta, forwarded to {@code f().lossHuber}
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, double delta) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return huberLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta);
    }

    /**
     * Huber loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @param delta       Huber delta, forwarded to {@code f().lossHuber}
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                SDVariable weights, @NonNull LossReduce lossReduce, double delta) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("huber loss", "predictions", predictions);
        SDValidation.validateNumerical("huber loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossHuber(label, predictions, w, lossReduce, delta), name);
    }

    /**
     * Huber loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @param delta       Huber delta, forwarded to {@code f().lossHuber}
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name object argument is null
     */
    public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                @NonNull LossReduce lossReduce, double delta) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return huberLoss(name, label, predictions, null, lossReduce, delta);
    }

    /**
     * L2 loss of {@code var}, with no explicit output name.
     *
     * @param var input; must be a numerical type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code var} is null
     */
    public SDVariable l2Loss(@NonNull SDVariable var) {
        requireNonNullArg(var, "var");
        return l2Loss(null, var);
    }

    /**
     * L2 loss of {@code var}.
     *
     * @param name name of the output variable (may be null)
     * @param var  input; must be a numerical type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code var} is null
     */
    public SDVariable l2Loss(String name, @NonNull SDVariable var) {
        requireNonNullArg(var, "var");
        SDValidation.validateNumerical("l2 loss", var);
        return registerLoss(f().lossL2(var), name);
    }

    /**
     * Log loss, unweighted, with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} and epsilon 1e-7.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return logLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 1.0E-7d);
    }

    /**
     * Log loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @param epsilon     forwarded to {@code f().lossLog}; presumably a numerical-stability epsilon
     *                    (the shorter overloads use 1e-7) — confirm
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                              SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("log loss", "predictions", predictions);
        SDValidation.validateNumerical("log loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossLog(label, predictions, w, lossReduce, epsilon), name);
    }

    /**
     * Log loss, unweighted, with the given reduction and epsilon 1e-7.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                              @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return logLoss(name, label, predictions, null, lossReduce, 1.0E-7d);
    }

    /**
     * Log Poisson loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return logPoisson(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
    }

    /**
     * Log Poisson loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                 SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("log poisson loss", "predictions", predictions);
        SDValidation.validateNumerical("log poisson loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossLogPoisson(label, predictions, w, lossReduce), name);
    }

    /**
     * Log Poisson loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                 @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return logPoisson(name, label, predictions, null, lossReduce);
    }

    /**
     * Full log Poisson loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return logPoissonFull(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
    }

    /**
     * Full log Poisson loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                     SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("log poisson (full) loss", "predictions", predictions);
        SDValidation.validateNumerical("log poisson (full) loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossLogPoissonFull(label, predictions, w, lossReduce), name);
    }

    /**
     * Full log Poisson loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                     @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return logPoissonFull(name, label, predictions, null, lossReduce);
    }

    /**
     * Mean pairwise squared error loss, unweighted, reduced with
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * <p>Added for consistency with the other losses, which all provide this default overload.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return meanPairwiseSquaredError(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
    }

    /**
     * Mean pairwise squared error loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                               @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return meanPairwiseSquaredError(name, label, predictions, null, lossReduce);
    }

    /**
     * Mean pairwise squared error loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                               SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        // Context string previously read "main pairwise ..." (typo); it now matches the labels check below.
        SDValidation.validateFloatingPoint("mean pairwise squared error loss", "predictions", predictions);
        SDValidation.validateNumerical("mean pairwise squared error loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossMeanPairwiseSquaredError(label, predictions, w, lossReduce), name);
    }

    /**
     * Mean squared error loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return meanSquaredError(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
    }

    /**
     * Mean squared error loss.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param weights     weights, or null for a scalar 1.0 of the predictions' data type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictions} or {@code lossReduce} is null
     */
    public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                       SDVariable weights, @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("mean squared error loss", "predictions", predictions);
        SDValidation.validateNumerical("mean squared error loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictions);
        return registerLoss(f().lossMeanSquaredError(label, predictions, w, lossReduce), name);
    }

    /**
     * Mean squared error loss, unweighted, with the given reduction.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions predictions; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                       @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return meanSquaredError(name, label, predictions, null, lossReduce);
    }

    /**
     * Sigmoid cross entropy loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT},
     * with a smoothing argument of 0.0.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions prediction logits; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return sigmoidCrossEntropy(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0d);
    }

    /**
     * Sigmoid cross entropy loss.
     *
     * @param name             name of the output variable (may be null)
     * @param label            labels; must be a numerical type
     * @param predictionLogits prediction logits; must be a floating point type
     * @param weights          weights, or null for a scalar 1.0 of the logits' data type
     * @param lossReduce       reduction to apply
     * @param labelSmoothing   forwarded to {@code f().lossSigmoidCrossEntropy}; presumably a
     *                         label-smoothing factor (the shorter overloads use 0.0) — confirm
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label}, {@code predictionLogits} or {@code lossReduce} is null
     */
    public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictionLogits,
                                          SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictionLogits, "predictionLogits");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits);
        SDValidation.validateNumerical("sigmoid cross entropy loss", SameDiffOutputLayer.LABELS_KEY, label);
        SDVariable w = defaultWeightsIfNull(weights, predictionLogits);
        return registerLoss(f().lossSigmoidCrossEntropy(label, predictionLogits, w, lossReduce, labelSmoothing), name);
    }

    /**
     * Sigmoid cross entropy loss, unweighted, with the given reduction and a smoothing argument of 0.0.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions prediction logits; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                          @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return sigmoidCrossEntropy(name, label, predictions, null, lossReduce, 0.0d);
    }

    /**
     * Softmax cross entropy loss, unweighted, reduced with {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT},
     * with a smoothing argument of 0.0.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions prediction logits; must be a floating point type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code label} or {@code predictions} is null
     */
    public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        return softmaxCrossEntropy(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0d);
    }

    /**
     * Softmax cross entropy loss.
     *
     * @param name             name of the output variable (may be null)
     * @param oneHotLabels     labels; must be a numerical type
     * @param logitPredictions prediction logits; must be a floating point type
     * @param weights          weights, or null for a scalar 1.0 of the logits' data type
     * @param lossReduce       reduction to apply
     * @param labelSmoothing   forwarded to {@code f().lossSoftmaxCrossEntropy}; presumably a
     *                         label-smoothing factor (the shorter overloads use 0.0) — confirm
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code oneHotLabels}, {@code logitPredictions} or {@code lossReduce} is null
     */
    public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable oneHotLabels, @NonNull SDVariable logitPredictions,
                                          SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
        requireNonNullArg(oneHotLabels, "oneHotLabels");
        requireNonNullArg(logitPredictions, "logitPredictions");
        requireNonNullArg(lossReduce, "lossReduce");
        SDValidation.validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions);
        SDValidation.validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels);
        SDVariable w = defaultWeightsIfNull(weights, logitPredictions);
        return registerLoss(f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, w, lossReduce, labelSmoothing), name);
    }

    /**
     * Softmax cross entropy loss, unweighted, with the given reduction and a smoothing argument of 0.0.
     *
     * @param name        name of the output variable (may be null)
     * @param label       labels; must be a numerical type
     * @param predictions prediction logits; must be a floating point type
     * @param lossReduce  reduction to apply
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if any non-name argument is null
     */
    public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions,
                                          @NonNull LossReduce lossReduce) {
        requireNonNullArg(label, "label");
        requireNonNullArg(predictions, "predictions");
        requireNonNullArg(lossReduce, "lossReduce");
        return softmaxCrossEntropy(name, label, predictions, null, lossReduce, 0.0d);
    }

    /**
     * Sparse softmax cross entropy, with no explicit output name.
     *
     * <p>Note the argument order: logits first, then labels (the reverse of the other losses).
     *
     * @param logits logits; must be a floating point type
     * @param labels labels; must be an integer type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code logits} or {@code labels} is null
     */
    public SDVariable sparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) {
        requireNonNullArg(logits, "logits");
        requireNonNullArg(labels, "labels");
        return sparseSoftmaxCrossEntropy(null, logits, labels);
    }

    /**
     * Sparse softmax cross entropy.
     *
     * <p>Note the argument order: logits first, then labels (the reverse of the other losses).
     *
     * @param name   name of the output variable (may be null)
     * @param logits logits; must be a floating point type
     * @param labels labels; must be an integer type
     * @return the loss variable, marked as a loss
     * @throws NullPointerException if {@code logits} or {@code labels} is null
     */
    public SDVariable sparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) {
        requireNonNullArg(logits, "logits");
        requireNonNullArg(labels, "labels");
        SDValidation.validateFloatingPoint("sparse softmax cross entropy", "logits (predictions)", logits);
        SDValidation.validateInteger("sparse softmax cross entropy", SameDiffOutputLayer.LABELS_KEY, labels);
        // Report the labels' data type. The message previously reported the logits variable by mistake.
        Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", labels.dataType());
        return registerLoss(f().lossSparseSoftmaxCrossEntropy(logits, labels), name);
    }

    /**
     * Weighted cross entropy with logits, with no explicit output name.
     *
     * @param targets targets; must be a numerical type
     * @param inputs  input logits; must be a floating point type
     * @param weights forwarded unchanged to the op (no default is substituted when null)
     * @return the loss variable, marked as a loss
     */
    public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) {
        return weightedCrossEntropyWithLogits(null, targets, inputs, weights);
    }

    /**
     * Weighted cross entropy with logits.
     *
     * <p>Unlike the other losses, this method performs no explicit null checks and does not default
     * {@code weights}.
     *
     * @param name    name of the output variable (may be null)
     * @param targets targets; must be a numerical type
     * @param inputs  input logits; must be a floating point type
     * @param weights forwarded unchanged to {@code f().weightedCrossEntropyWithLogits}
     * @return the loss variable, marked as a loss
     */
    public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs, SDVariable weights) {
        SDValidation.validateFloatingPoint("weighted cross entropy with logits", "inputs", inputs);
        SDValidation.validateNumerical("weighted cross entropy with logits", "targets", targets);
        return registerLoss(f().weightedCrossEntropyWithLogits(targets, inputs, weights), name);
    }

    /**
     * Throws a {@link NullPointerException} with this class's standard
     * "{@code <param> is marked @NonNull but is null}" message when {@code value} is null.
     * The check is explicit so it holds whether or not Lombok processes the {@code @NonNull} annotations.
     */
    private static void requireNonNullArg(Object value, String paramName) {
        if (value == null) {
            throw new NullPointerException(paramName + " is marked @NonNull but is null");
        }
    }

    /**
     * Returns {@code weights} if non-null; otherwise creates an unnamed scalar 1.0 with the same data
     * type as {@code predictions}.
     */
    private SDVariable defaultWeightsIfNull(SDVariable weights, SDVariable predictions) {
        if (weights != null) {
            return weights;
        }
        return this.sd.scalar(null, predictions.dataType(), Double.valueOf(1.0d));
    }

    /**
     * Passes {@code result} and {@code name} to {@code updateVariableNameAndReference}, marks the
     * returned variable as a loss, and returns it.
     */
    private SDVariable registerLoss(SDVariable result, String name) {
        SDVariable loss = updateVariableNameAndReference(result, name);
        loss.markAsLoss();
        return loss;
    }
}
