package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.base.ClassifierMixin;
import ai.sklearn4j.core.libraries.Scipy;
import ai.sklearn4j.core.libraries.numpy.Numpy;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;

/* loaded from: input_file:ai/sklearn4j/naive_bayes/BaseNaiveBayes.class */
/**
 * Common prediction logic for naive Bayes classifiers.
 *
 * <p>Subclasses supply {@link #jointLogLikelihood(NumpyArray)}; every prediction method in
 * this class is derived from that single result.
 */
public abstract class BaseNaiveBayes extends ClassifierMixin {
    /**
     * Computes the joint log-likelihood for the given input.
     *
     * <p>NOTE(review): the result is presumably laid out as (samples, classes), since the
     * callers below reduce it along index 1 — confirm against the concrete subclasses.
     *
     * @param numpyArray the input to score
     * @return the joint log-likelihood values produced by the concrete model
     */
    protected abstract NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> numpyArray);

    /**
     * Returns the result of {@code Numpy.argmax} applied to the joint log-likelihood.
     *
     * <p>NOTE(review): the arguments {@code 1, false} are presumably the reduction axis and a
     * keep-dimensions flag — confirm against {@code Numpy.argmax}. The returned values are the
     * raw argmax output; no mapping to class labels is performed here.
     *
     * @param numpyArray the input to classify
     * @return the argmax of the joint log-likelihood
     */
    @Override
    public NumpyArray<Long> predict(NumpyArray<Double> numpyArray) {
        NumpyArray<Double> scores = jointLogLikelihood(numpyArray);
        return Numpy.argmax(scores, 1, false);
    }

    /**
     * Returns the joint log-likelihood minus its log-sum-exp term.
     *
     * <p>The log-sum-exp result is passed through {@code atLeast2D} and transposed before the
     * subtraction; presumably this turns it into a column so it lines up with each row of
     * the scores — TODO confirm the broadcasting rules of {@code Numpy.subtract}.
     *
     * @param numpyArray the input to score
     * @return the joint log-likelihood with the log-sum-exp term subtracted
     */
    @Override
    public NumpyArray<Double> predictLogProbabilities(NumpyArray<Double> numpyArray) {
        NumpyArray<Double> scores = jointLogLikelihood(numpyArray);
        return Numpy.subtract(
                scores,
                Numpy.atLeast2D(Scipy.logSumExponent(scores, 1)).transpose());
    }

    /**
     * Returns {@code Numpy.exp} applied to the output of
     * {@link #predictLogProbabilities(NumpyArray)}.
     *
     * @param numpyArray the input to score
     * @return the exponentiated log-probabilities
     */
    @Override
    public NumpyArray<Double> predictProbabilities(NumpyArray<Double> numpyArray) {
        NumpyArray<Double> logProbabilities = predictLogProbabilities(numpyArray);
        return Numpy.exp(logProbabilities);
    }
}
